fix: enhance backup service to support restoration from WAL files and add corresponding tests
This commit is contained in:
@@ -610,56 +610,99 @@ func (s *BackupService) extractDatabaseFromBackup(zipPath string) (string, error
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
var dbEntry *zip.File
|
||||
var walEntry *zip.File
|
||||
var shmEntry *zip.File
|
||||
for _, file := range r.File {
|
||||
if filepath.Clean(file.Name) != s.DatabaseName {
|
||||
continue
|
||||
switch filepath.Clean(file.Name) {
|
||||
case s.DatabaseName:
|
||||
dbEntry = file
|
||||
case s.DatabaseName + "-wal":
|
||||
walEntry = file
|
||||
case s.DatabaseName + "-shm":
|
||||
shmEntry = file
|
||||
}
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "charon-restore-db-*.sqlite")
|
||||
if dbEntry == nil {
|
||||
return "", fmt.Errorf("database entry %s not found in backup archive", s.DatabaseName)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "charon-restore-db-*.sqlite")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create restore snapshot file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("close restore snapshot file: %w", err)
|
||||
}
|
||||
|
||||
extractToPath := func(file *zip.File, destinationPath string) error {
|
||||
outFile, err := os.OpenFile(destinationPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create restore snapshot file: %w", err)
|
||||
return fmt.Errorf("open destination file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer func() {
|
||||
_ = outFile.Close()
|
||||
}()
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("open database entry in backup archive: %w", err)
|
||||
return fmt.Errorf("open archive entry: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = rc.Close()
|
||||
}()
|
||||
|
||||
const maxDecompressedSize = 100 * 1024 * 1024 // 100MB
|
||||
limitedReader := io.LimitReader(rc, maxDecompressedSize+1)
|
||||
written, err := io.Copy(tmpFile, limitedReader)
|
||||
written, err := io.Copy(outFile, limitedReader)
|
||||
if err != nil {
|
||||
_ = rc.Close()
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("copy database entry from backup archive: %w", err)
|
||||
return fmt.Errorf("copy archive entry: %w", err)
|
||||
}
|
||||
|
||||
if written > maxDecompressedSize {
|
||||
_ = rc.Close()
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("database entry %s exceeded decompression limit (%d bytes), potential decompression bomb", file.Name, maxDecompressedSize)
|
||||
return fmt.Errorf("archive entry %s exceeded decompression limit (%d bytes), potential decompression bomb", file.Name, maxDecompressedSize)
|
||||
}
|
||||
if err := outFile.Sync(); err != nil {
|
||||
return fmt.Errorf("sync destination file: %w", err)
|
||||
}
|
||||
|
||||
if err := rc.Close(); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("close database entry reader: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("close restore snapshot file: %w", err)
|
||||
}
|
||||
|
||||
return tmpPath, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("database entry %s not found in backup archive", s.DatabaseName)
|
||||
if err := extractToPath(dbEntry, tmpPath); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("extract database entry from backup archive: %w", err)
|
||||
}
|
||||
|
||||
if walEntry != nil {
|
||||
walPath := tmpPath + "-wal"
|
||||
if err := extractToPath(walEntry, walPath); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
_ = os.Remove(walPath)
|
||||
return "", fmt.Errorf("extract wal entry from backup archive: %w", err)
|
||||
}
|
||||
|
||||
if shmEntry != nil {
|
||||
shmPath := tmpPath + "-shm"
|
||||
if err := extractToPath(shmEntry, shmPath); err != nil {
|
||||
logger.Log().WithError(err).Warn("failed to extract sqlite shm entry from backup archive")
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkpointSQLiteDatabase(tmpPath); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
_ = os.Remove(walPath)
|
||||
_ = os.Remove(tmpPath + "-shm")
|
||||
return "", fmt.Errorf("checkpoint extracted sqlite wal: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Remove(walPath)
|
||||
_ = os.Remove(tmpPath + "-shm")
|
||||
}
|
||||
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) unzip(src, dest string) error {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -55,3 +57,66 @@ func TestBackupService_RehydrateLiveDatabase(t *testing.T) {
|
||||
require.Len(t, restoredUsers, 1)
|
||||
assert.Equal(t, "restore-user@example.com", restoredUsers[0].Email)
|
||||
}
|
||||
|
||||
func TestBackupService_RehydrateLiveDatabase_FromBackupWithWAL(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dataDir := filepath.Join(tmpDir, "data")
|
||||
require.NoError(t, os.MkdirAll(dataDir, 0o700))
|
||||
|
||||
dbPath := filepath.Join(dataDir, "charon.db")
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Exec("PRAGMA journal_mode=WAL").Error)
|
||||
require.NoError(t, db.Exec("PRAGMA wal_autocheckpoint=0").Error)
|
||||
require.NoError(t, db.AutoMigrate(&models.User{}))
|
||||
|
||||
seedUser := models.User{
|
||||
UUID: uuid.NewString(),
|
||||
Email: "restore-from-wal@example.com",
|
||||
Name: "Restore From WAL",
|
||||
Role: "user",
|
||||
Enabled: true,
|
||||
APIKey: uuid.NewString(),
|
||||
}
|
||||
require.NoError(t, db.Create(&seedUser).Error)
|
||||
|
||||
walPath := dbPath + "-wal"
|
||||
_, err = os.Stat(walPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := NewBackupService(&config.Config{DatabasePath: dbPath})
|
||||
defer svc.Stop()
|
||||
|
||||
backupName := "backup_with_wal.zip"
|
||||
backupPath := filepath.Join(svc.BackupDir, backupName)
|
||||
backupFile, err := os.Create(backupPath)
|
||||
require.NoError(t, err)
|
||||
zipWriter := zip.NewWriter(backupFile)
|
||||
|
||||
addFileToZip := func(sourcePath, zipEntryName string) {
|
||||
sourceFile, openErr := os.Open(sourcePath)
|
||||
require.NoError(t, openErr)
|
||||
defer func() {
|
||||
_ = sourceFile.Close()
|
||||
}()
|
||||
|
||||
zipEntry, createErr := zipWriter.Create(zipEntryName)
|
||||
require.NoError(t, createErr)
|
||||
_, copyErr := io.Copy(zipEntry, sourceFile)
|
||||
require.NoError(t, copyErr)
|
||||
}
|
||||
|
||||
addFileToZip(dbPath, svc.DatabaseName)
|
||||
addFileToZip(walPath, svc.DatabaseName+"-wal")
|
||||
require.NoError(t, zipWriter.Close())
|
||||
require.NoError(t, backupFile.Close())
|
||||
|
||||
require.NoError(t, db.Where("1 = 1").Delete(&models.User{}).Error)
|
||||
require.NoError(t, svc.RestoreBackup(backupName))
|
||||
require.NoError(t, svc.RehydrateLiveDatabase(db))
|
||||
|
||||
var restoredUsers []models.User
|
||||
require.NoError(t, db.Find(&restoredUsers).Error)
|
||||
require.Len(t, restoredUsers, 1)
|
||||
assert.Equal(t, "restore-from-wal@example.com", restoredUsers[0].Email)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user