fix: enhance backup service to support restoration from WAL files and add corresponding tests

This commit is contained in:
GitHub Actions
2026-02-13 08:06:59 +00:00
parent a572a68537
commit d0334ddd40
2 changed files with 139 additions and 31 deletions

View File

@@ -610,56 +610,99 @@ func (s *BackupService) extractDatabaseFromBackup(zipPath string) (string, error
_ = r.Close()
}()
var dbEntry *zip.File
var walEntry *zip.File
var shmEntry *zip.File
for _, file := range r.File {
if filepath.Clean(file.Name) != s.DatabaseName {
continue
switch filepath.Clean(file.Name) {
case s.DatabaseName:
dbEntry = file
case s.DatabaseName + "-wal":
walEntry = file
case s.DatabaseName + "-shm":
shmEntry = file
}
}
tmpFile, err := os.CreateTemp("", "charon-restore-db-*.sqlite")
if dbEntry == nil {
return "", fmt.Errorf("database entry %s not found in backup archive", s.DatabaseName)
}
tmpFile, err := os.CreateTemp("", "charon-restore-db-*.sqlite")
if err != nil {
return "", fmt.Errorf("create restore snapshot file: %w", err)
}
tmpPath := tmpFile.Name()
if err := tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Errorf("close restore snapshot file: %w", err)
}
extractToPath := func(file *zip.File, destinationPath string) error {
outFile, err := os.OpenFile(destinationPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return "", fmt.Errorf("create restore snapshot file: %w", err)
return fmt.Errorf("open destination file: %w", err)
}
tmpPath := tmpFile.Name()
defer func() {
_ = outFile.Close()
}()
rc, err := file.Open()
if err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Errorf("open database entry in backup archive: %w", err)
return fmt.Errorf("open archive entry: %w", err)
}
defer func() {
_ = rc.Close()
}()
const maxDecompressedSize = 100 * 1024 * 1024 // 100MB
limitedReader := io.LimitReader(rc, maxDecompressedSize+1)
written, err := io.Copy(tmpFile, limitedReader)
written, err := io.Copy(outFile, limitedReader)
if err != nil {
_ = rc.Close()
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Errorf("copy database entry from backup archive: %w", err)
return fmt.Errorf("copy archive entry: %w", err)
}
if written > maxDecompressedSize {
_ = rc.Close()
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Errorf("database entry %s exceeded decompression limit (%d bytes), potential decompression bomb", file.Name, maxDecompressedSize)
return fmt.Errorf("archive entry %s exceeded decompression limit (%d bytes), potential decompression bomb", file.Name, maxDecompressedSize)
}
if err := outFile.Sync(); err != nil {
return fmt.Errorf("sync destination file: %w", err)
}
if err := rc.Close(); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Errorf("close database entry reader: %w", err)
}
if err := tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Errorf("close restore snapshot file: %w", err)
}
return tmpPath, nil
return nil
}
return "", fmt.Errorf("database entry %s not found in backup archive", s.DatabaseName)
if err := extractToPath(dbEntry, tmpPath); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Errorf("extract database entry from backup archive: %w", err)
}
if walEntry != nil {
walPath := tmpPath + "-wal"
if err := extractToPath(walEntry, walPath); err != nil {
_ = os.Remove(tmpPath)
_ = os.Remove(walPath)
return "", fmt.Errorf("extract wal entry from backup archive: %w", err)
}
if shmEntry != nil {
shmPath := tmpPath + "-shm"
if err := extractToPath(shmEntry, shmPath); err != nil {
logger.Log().WithError(err).Warn("failed to extract sqlite shm entry from backup archive")
}
}
if err := checkpointSQLiteDatabase(tmpPath); err != nil {
_ = os.Remove(tmpPath)
_ = os.Remove(walPath)
_ = os.Remove(tmpPath + "-shm")
return "", fmt.Errorf("checkpoint extracted sqlite wal: %w", err)
}
_ = os.Remove(walPath)
_ = os.Remove(tmpPath + "-shm")
}
return tmpPath, nil
}
func (s *BackupService) unzip(src, dest string) error {

View File

@@ -1,6 +1,8 @@
package services
import (
"archive/zip"
"io"
"os"
"path/filepath"
"testing"
@@ -55,3 +57,66 @@ func TestBackupService_RehydrateLiveDatabase(t *testing.T) {
require.Len(t, restoredUsers, 1)
assert.Equal(t, "restore-user@example.com", restoredUsers[0].Email)
}
// TestBackupService_RehydrateLiveDatabase_FromBackupWithWAL verifies that a
// backup archive containing both the main SQLite database file and its -wal
// sidecar can be restored and rehydrated, with the seeded row recovered from
// the WAL contents.
func TestBackupService_RehydrateLiveDatabase_FromBackupWithWAL(t *testing.T) {
	workDir := t.TempDir()
	dbDir := filepath.Join(workDir, "data")
	require.NoError(t, os.MkdirAll(dbDir, 0o700))

	// Open the database in WAL mode and disable auto-checkpointing so the
	// seeded row stays in the -wal file instead of being folded back into
	// the main database file.
	databasePath := filepath.Join(dbDir, "charon.db")
	conn, openErr := gorm.Open(sqlite.Open(databasePath), &gorm.Config{})
	require.NoError(t, openErr)
	require.NoError(t, conn.Exec("PRAGMA journal_mode=WAL").Error)
	require.NoError(t, conn.Exec("PRAGMA wal_autocheckpoint=0").Error)
	require.NoError(t, conn.AutoMigrate(&models.User{}))

	seeded := models.User{
		UUID:    uuid.NewString(),
		Email:   "restore-from-wal@example.com",
		Name:    "Restore From WAL",
		Role:    "user",
		Enabled: true,
		APIKey:  uuid.NewString(),
	}
	require.NoError(t, conn.Create(&seeded).Error)

	// The WAL sidecar must exist on disk for this scenario to be meaningful.
	walSidecar := databasePath + "-wal"
	_, statErr := os.Stat(walSidecar)
	require.NoError(t, statErr)

	service := NewBackupService(&config.Config{DatabasePath: databasePath})
	defer service.Stop()

	// Build the backup archive by hand so it contains both the database
	// entry and the WAL entry.
	archiveName := "backup_with_wal.zip"
	archiveFile, createErr := os.Create(filepath.Join(service.BackupDir, archiveName))
	require.NoError(t, createErr)
	archive := zip.NewWriter(archiveFile)
	writeEntry := func(srcPath, entryName string) {
		src, err := os.Open(srcPath)
		require.NoError(t, err)
		defer func() {
			_ = src.Close()
		}()
		dst, err := archive.Create(entryName)
		require.NoError(t, err)
		_, err = io.Copy(dst, src)
		require.NoError(t, err)
	}
	writeEntry(databasePath, service.DatabaseName)
	writeEntry(walSidecar, service.DatabaseName+"-wal")
	require.NoError(t, archive.Close())
	require.NoError(t, archiveFile.Close())

	// Wipe the live table, then restore from the archive and rehydrate.
	require.NoError(t, conn.Where("1 = 1").Delete(&models.User{}).Error)
	require.NoError(t, service.RestoreBackup(archiveName))
	require.NoError(t, service.RehydrateLiveDatabase(conn))

	// The seeded user — which lived only in the WAL — must come back.
	var recovered []models.User
	require.NoError(t, conn.Find(&recovered).Error)
	require.Len(t, recovered, 1)
	assert.Equal(t, "restore-from-wal@example.com", recovered[0].Email)
}