chore: remove cached
This commit is contained in:
@@ -1,151 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/config"
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AuthService struct {
|
||||
db *gorm.DB
|
||||
config config.Config
|
||||
}
|
||||
|
||||
func NewAuthService(db *gorm.DB, cfg config.Config) *AuthService {
|
||||
return &AuthService{db: db, config: cfg}
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func (s *AuthService) Register(email, password, name string) (*models.User, error) {
|
||||
email = strings.ToLower(email)
|
||||
var count int64
|
||||
s.db.Model(&models.User{}).Count(&count)
|
||||
|
||||
role := "user"
|
||||
if count == 0 {
|
||||
role = "admin" // First user is admin
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
UUID: uuid.New().String(),
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: role,
|
||||
APIKey: uuid.New().String(),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := user.SetPassword(password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.db.Create(user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) Login(email, password string) (string, error) {
|
||||
email = strings.ToLower(email)
|
||||
var user models.User
|
||||
if err := s.db.Where("email = ?", email).First(&user).Error; err != nil {
|
||||
return "", errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
if !user.Enabled {
|
||||
return "", errors.New("account disabled")
|
||||
}
|
||||
|
||||
if user.LockedUntil != nil && user.LockedUntil.After(time.Now()) {
|
||||
return "", errors.New("account locked")
|
||||
}
|
||||
|
||||
if !user.CheckPassword(password) {
|
||||
user.FailedLoginAttempts++
|
||||
if user.FailedLoginAttempts >= 5 {
|
||||
lockTime := time.Now().Add(15 * time.Minute)
|
||||
user.LockedUntil = &lockTime
|
||||
}
|
||||
s.db.Save(&user)
|
||||
return "", errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Reset failed attempts
|
||||
user.FailedLoginAttempts = 0
|
||||
user.LockedUntil = nil
|
||||
now := time.Now()
|
||||
user.LastLogin = &now
|
||||
s.db.Save(&user)
|
||||
|
||||
return s.GenerateToken(&user)
|
||||
}
|
||||
|
||||
func (s *AuthService) GenerateToken(user *models.User) (string, error) {
|
||||
expirationTime := time.Now().Add(24 * time.Hour)
|
||||
claims := &Claims{
|
||||
UserID: user.ID,
|
||||
Role: user.Role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expirationTime),
|
||||
Issuer: "cpmp",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.config.JWTSecret))
|
||||
}
|
||||
|
||||
func (s *AuthService) ChangePassword(userID uint, oldPassword, newPassword string) error {
|
||||
var user models.User
|
||||
if err := s.db.First(&user, userID).Error; err != nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
if !user.CheckPassword(oldPassword) {
|
||||
return errors.New("invalid current password")
|
||||
}
|
||||
|
||||
if err := user.SetPassword(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Save(&user).Error
|
||||
}
|
||||
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(s.config.JWTSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) GetUserByID(id uint) (*models.User, error) {
|
||||
var user models.User
|
||||
if err := s.db.First(&user, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/config"
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupAuthTestDB(t *testing.T) *gorm.DB {
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&models.User{}))
|
||||
return db
|
||||
}
|
||||
|
||||
func TestAuthService_Register(t *testing.T) {
|
||||
db := setupAuthTestDB(t)
|
||||
cfg := config.Config{JWTSecret: "test-secret"}
|
||||
service := NewAuthService(db, cfg)
|
||||
|
||||
// Test 1: First user should be admin
|
||||
admin, err := service.Register("admin@example.com", "password123", "Admin User")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "admin", admin.Role)
|
||||
assert.NotEmpty(t, admin.PasswordHash)
|
||||
assert.NotEqual(t, "password123", admin.PasswordHash)
|
||||
|
||||
// Test 2: Second user should be regular user
|
||||
user, err := service.Register("user@example.com", "password123", "Regular User")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "user", user.Role)
|
||||
}
|
||||
|
||||
func TestAuthService_Login(t *testing.T) {
|
||||
db := setupAuthTestDB(t)
|
||||
cfg := config.Config{JWTSecret: "test-secret"}
|
||||
service := NewAuthService(db, cfg)
|
||||
|
||||
// Setup user
|
||||
_, err := service.Register("test@example.com", "password123", "Test User")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test 1: Successful login
|
||||
token, err := service.Login("test@example.com", "password123")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Test 2: Invalid password
|
||||
token, err = service.Login("test@example.com", "wrongpassword")
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, token)
|
||||
assert.Equal(t, "invalid credentials", err.Error())
|
||||
|
||||
// Test 3: Account locking
|
||||
// Fail 4 more times (total 5)
|
||||
for i := 0; i < 4; i++ {
|
||||
_, err = service.Login("test@example.com", "wrongpassword")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Check if locked
|
||||
var user models.User
|
||||
db.Where("email = ?", "test@example.com").First(&user)
|
||||
assert.Equal(t, 5, user.FailedLoginAttempts)
|
||||
assert.NotNil(t, user.LockedUntil)
|
||||
assert.True(t, user.LockedUntil.After(time.Now()))
|
||||
|
||||
// Try login with correct password while locked
|
||||
token, err = service.Login("test@example.com", "password123")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "account locked", err.Error())
|
||||
}
|
||||
|
||||
func TestAuthService_ChangePassword(t *testing.T) {
|
||||
db := setupAuthTestDB(t)
|
||||
cfg := config.Config{JWTSecret: "test-secret"}
|
||||
service := NewAuthService(db, cfg)
|
||||
|
||||
user, err := service.Register("test@example.com", "password123", "Test User")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Success
|
||||
err = service.ChangePassword(user.ID, "password123", "newpassword")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify login with new password
|
||||
_, err = service.Login("test@example.com", "newpassword")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Fail with old password
|
||||
_, err = service.Login("test@example.com", "password123")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Fail with wrong current password
|
||||
err = service.ChangePassword(user.ID, "wrong", "another")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "invalid current password", err.Error())
|
||||
|
||||
// Fail with non-existent user
|
||||
err = service.ChangePassword(999, "password", "new")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateToken(t *testing.T) {
|
||||
db := setupAuthTestDB(t)
|
||||
cfg := config.Config{JWTSecret: "test-secret"}
|
||||
service := NewAuthService(db, cfg)
|
||||
|
||||
user, err := service.Register("test@example.com", "password123", "Test User")
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := service.Login("test@example.com", "password123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Valid token
|
||||
claims, err := service.ValidateToken(token)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, user.ID, claims.UserID)
|
||||
|
||||
// Invalid token
|
||||
_, err = service.ValidateToken("invalid.token.string")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_GetUserByID(t *testing.T) {
|
||||
db := setupAuthTestDB(t)
|
||||
cfg := config.Config{JWTSecret: "test-secret"}
|
||||
service := NewAuthService(db, cfg)
|
||||
|
||||
// Setup user
|
||||
user, err := service.Register("test@example.com", "password123", "Test User")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test 1: Get existing user
|
||||
foundUser, err := service.GetUserByID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, foundUser.ID)
|
||||
assert.Equal(t, user.Email, foundUser.Email)
|
||||
|
||||
// Test 2: Get non-existent user
|
||||
_, err = service.GetUserByID(999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -1,255 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/config"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
type BackupService struct {
|
||||
DataDir string
|
||||
BackupDir string
|
||||
Cron *cron.Cron
|
||||
}
|
||||
|
||||
type BackupFile struct {
|
||||
Filename string `json:"filename"`
|
||||
Size int64 `json:"size"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
func NewBackupService(cfg *config.Config) *BackupService {
|
||||
// Ensure backup directory exists
|
||||
backupDir := filepath.Join(filepath.Dir(cfg.DatabasePath), "backups")
|
||||
if err := os.MkdirAll(backupDir, 0755); err != nil {
|
||||
fmt.Printf("Failed to create backup directory: %v\n", err)
|
||||
}
|
||||
|
||||
s := &BackupService{
|
||||
DataDir: filepath.Dir(cfg.DatabasePath), // e.g. /app/data
|
||||
BackupDir: backupDir,
|
||||
Cron: cron.New(),
|
||||
}
|
||||
|
||||
// Schedule daily backup at 3 AM
|
||||
_, err := s.Cron.AddFunc("0 3 * * *", s.RunScheduledBackup)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to schedule backup: %v\n", err)
|
||||
}
|
||||
s.Cron.Start()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *BackupService) RunScheduledBackup() {
|
||||
fmt.Println("Starting scheduled backup...")
|
||||
if name, err := s.CreateBackup(); err != nil {
|
||||
fmt.Printf("Scheduled backup failed: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("Scheduled backup created: %s\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
// ListBackups returns all backup files sorted by time (newest first)
|
||||
func (s *BackupService) ListBackups() ([]BackupFile, error) {
|
||||
entries, err := os.ReadDir(s.BackupDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backups []BackupFile
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".zip") {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
backups = append(backups, BackupFile{
|
||||
Filename: entry.Name(),
|
||||
Size: info.Size(),
|
||||
Time: info.ModTime(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort newest first
|
||||
sort.Slice(backups, func(i, j int) bool {
|
||||
return backups[i].Time.After(backups[j].Time)
|
||||
})
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
// CreateBackup creates a zip archive of the database and caddy data
|
||||
func (s *BackupService) CreateBackup() (string, error) {
|
||||
timestamp := time.Now().Format("2006-01-02_15-04-05")
|
||||
filename := fmt.Sprintf("backup_%s.zip", timestamp)
|
||||
zipPath := filepath.Join(s.BackupDir, filename)
|
||||
|
||||
outFile, err := os.Create(zipPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
w := zip.NewWriter(outFile)
|
||||
defer w.Close()
|
||||
|
||||
// Files/Dirs to backup
|
||||
// 1. Database
|
||||
dbPath := filepath.Join(s.DataDir, "cpm.db")
|
||||
if err := s.addToZip(w, dbPath, "cpm.db"); err != nil {
|
||||
return "", fmt.Errorf("backup db: %w", err)
|
||||
}
|
||||
|
||||
// 2. Caddy Data (Certificates, etc)
|
||||
// We walk the 'caddy' subdirectory
|
||||
caddyDir := filepath.Join(s.DataDir, "caddy")
|
||||
if err := s.addDirToZip(w, caddyDir, "caddy"); err != nil {
|
||||
// It's possible caddy dir doesn't exist yet, which is fine
|
||||
fmt.Printf("Warning: could not backup caddy dir: %v\n", err)
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) addToZip(w *zip.Writer, srcPath, zipPath string) error {
|
||||
file, err := os.Open(srcPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
f, err := w.Create(zipPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(f, file)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *BackupService) addDirToZip(w *zip.Writer, srcDir, zipBase string) error {
|
||||
return filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(srcDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
zipPath := filepath.Join(zipBase, relPath)
|
||||
return s.addToZip(w, path, zipPath)
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteBackup removes a backup file
|
||||
func (s *BackupService) DeleteBackup(filename string) error {
|
||||
cleanName := filepath.Base(filename)
|
||||
if filename != cleanName {
|
||||
return fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
path := filepath.Join(s.BackupDir, cleanName)
|
||||
if !strings.HasPrefix(path, filepath.Clean(s.BackupDir)) {
|
||||
return fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// GetBackupPath returns the full path to a backup file (for downloading)
|
||||
func (s *BackupService) GetBackupPath(filename string) (string, error) {
|
||||
cleanName := filepath.Base(filename)
|
||||
if filename != cleanName {
|
||||
return "", fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
path := filepath.Join(s.BackupDir, cleanName)
|
||||
if !strings.HasPrefix(path, filepath.Clean(s.BackupDir)) {
|
||||
return "", fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// RestoreBackup restores the database and caddy data from a zip archive
|
||||
func (s *BackupService) RestoreBackup(filename string) error {
|
||||
cleanName := filepath.Base(filename)
|
||||
if filename != cleanName {
|
||||
return fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
// 1. Verify backup exists
|
||||
srcPath := filepath.Join(s.BackupDir, cleanName)
|
||||
if !strings.HasPrefix(srcPath, filepath.Clean(s.BackupDir)) {
|
||||
return fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
if _, err := os.Stat(srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Unzip to DataDir (overwriting)
|
||||
return s.unzip(srcPath, s.DataDir)
|
||||
}
|
||||
|
||||
func (s *BackupService) unzip(src, dest string) error {
|
||||
r, err := zip.OpenReader(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
for _, f := range r.File {
|
||||
fpath := filepath.Join(dest, f.Name)
|
||||
|
||||
// Check for ZipSlip
|
||||
if !strings.HasPrefix(fpath, filepath.Clean(dest)+string(os.PathSeparator)) {
|
||||
return fmt.Errorf("illegal file path: %s", fpath)
|
||||
}
|
||||
|
||||
if f.FileInfo().IsDir() {
|
||||
os.MkdirAll(fpath, os.ModePerm)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
_ = outFile.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(outFile, rc)
|
||||
|
||||
// Check for close errors on writable file
|
||||
if closeErr := outFile.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
rc.Close()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBackupService_CreateAndList(t *testing.T) {
|
||||
// Setup temp dirs
|
||||
tmpDir, err := os.MkdirTemp("", "cpm-backup-service-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dataDir := filepath.Join(tmpDir, "data")
|
||||
err = os.MkdirAll(dataDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create dummy DB
|
||||
dbPath := filepath.Join(dataDir, "cpm.db")
|
||||
err = os.WriteFile(dbPath, []byte("dummy db"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create dummy caddy dir
|
||||
caddyDir := filepath.Join(dataDir, "caddy")
|
||||
err = os.MkdirAll(caddyDir, 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(caddyDir, "caddy.json"), []byte("{}"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{DatabasePath: dbPath}
|
||||
service := NewBackupService(cfg)
|
||||
|
||||
// Test Create
|
||||
filename, err := service.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, filename)
|
||||
assert.FileExists(t, filepath.Join(service.BackupDir, filename))
|
||||
|
||||
// Test List
|
||||
backups, err := service.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, filename, backups[0].Filename)
|
||||
assert.True(t, backups[0].Size > 0)
|
||||
|
||||
// Test GetBackupPath
|
||||
path, err := service.GetBackupPath(filename)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(service.BackupDir, filename), path)
|
||||
|
||||
// Test Restore
|
||||
// Modify DB to verify restore
|
||||
err = os.WriteFile(dbPath, []byte("modified db"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = service.RestoreBackup(filename)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify DB content restored
|
||||
content, err := os.ReadFile(dbPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "dummy db", string(content))
|
||||
|
||||
// Test Delete
|
||||
err = service.DeleteBackup(filename)
|
||||
require.NoError(t, err)
|
||||
assert.NoFileExists(t, filepath.Join(service.BackupDir, filename))
|
||||
|
||||
// Test Delete Non-existent
|
||||
err = service.DeleteBackup("non-existent.zip")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_Restore_ZipSlip(t *testing.T) {
|
||||
// Setup temp dirs
|
||||
tmpDir := t.TempDir()
|
||||
service := &BackupService{
|
||||
DataDir: filepath.Join(tmpDir, "data"),
|
||||
BackupDir: filepath.Join(tmpDir, "backups"),
|
||||
}
|
||||
os.MkdirAll(service.BackupDir, 0755)
|
||||
|
||||
// Create malicious zip
|
||||
zipPath := filepath.Join(service.BackupDir, "malicious.zip")
|
||||
zipFile, err := os.Create(zipPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := zip.NewWriter(zipFile)
|
||||
f, err := w.Create("../../../evil.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = f.Write([]byte("evil"))
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
zipFile.Close()
|
||||
|
||||
// Attempt restore
|
||||
err = service.RestoreBackup("malicious.zip")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "illegal file path")
|
||||
}
|
||||
|
||||
func TestBackupService_PathTraversal(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
service := &BackupService{
|
||||
DataDir: filepath.Join(tmpDir, "data"),
|
||||
BackupDir: filepath.Join(tmpDir, "backups"),
|
||||
}
|
||||
os.MkdirAll(service.BackupDir, 0755)
|
||||
|
||||
// Test GetBackupPath with traversal
|
||||
// Should return error
|
||||
_, err := service.GetBackupPath("../../etc/passwd")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid filename")
|
||||
|
||||
// Test DeleteBackup with traversal
|
||||
// Should return error
|
||||
err = service.DeleteBackup("../../etc/passwd")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid filename")
|
||||
}
|
||||
|
||||
func TestBackupService_RunScheduledBackup(t *testing.T) {
|
||||
// Setup temp dirs
|
||||
tmpDir := t.TempDir()
|
||||
dataDir := filepath.Join(tmpDir, "data")
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
|
||||
// Create dummy DB
|
||||
dbPath := filepath.Join(dataDir, "cpm.db")
|
||||
os.WriteFile(dbPath, []byte("dummy db"), 0644)
|
||||
|
||||
cfg := &config.Config{DatabasePath: dbPath}
|
||||
service := NewBackupService(cfg)
|
||||
|
||||
// Run scheduled backup manually
|
||||
service.RunScheduledBackup()
|
||||
|
||||
// Verify backup created
|
||||
backups, err := service.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
}
|
||||
@@ -1,275 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
)
|
||||
|
||||
// CertificateInfo represents parsed certificate details.
|
||||
type CertificateInfo struct {
|
||||
ID uint `json:"id,omitempty"`
|
||||
UUID string `json:"uuid,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Domain string `json:"domain"`
|
||||
Issuer string `json:"issuer"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Status string `json:"status"` // "valid", "expiring", "expired"
|
||||
Provider string `json:"provider"` // "letsencrypt", "custom"
|
||||
}
|
||||
|
||||
// CertificateService manages certificate retrieval and parsing.
|
||||
type CertificateService struct {
|
||||
dataDir string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewCertificateService creates a new certificate service.
|
||||
func NewCertificateService(dataDir string, db *gorm.DB) *CertificateService {
|
||||
return &CertificateService{
|
||||
dataDir: dataDir,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// ListCertificates returns both auto-generated and custom certificates.
|
||||
func (s *CertificateService) ListCertificates() ([]CertificateInfo, error) {
|
||||
// First, scan Caddy data directory for auto-generated certificates and persist them.
|
||||
certRoot := filepath.Join(s.dataDir, "certificates")
|
||||
log.Printf("CertificateService: scanning cert directory: %s\n", certRoot)
|
||||
|
||||
foundDomains := map[string]struct{}{}
|
||||
|
||||
// If the cert root does not exist, skip scanning but still return DB entries below
|
||||
if _, err := os.Stat(certRoot); err == nil {
|
||||
_ = filepath.Walk(certRoot, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Printf("CertificateService: walk error for %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".crt") {
|
||||
log.Printf("CertificateService: found cert file: %s\n", path)
|
||||
certData, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
log.Printf("CertificateService: failed to read cert file %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(certData)
|
||||
if block == nil {
|
||||
log.Printf("CertificateService: pem decode failed for %s\n", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("CertificateService: failed to parse cert %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
domain := cert.Subject.CommonName
|
||||
if domain == "" && len(cert.DNSNames) > 0 {
|
||||
domain = cert.DNSNames[0]
|
||||
}
|
||||
if domain == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
foundDomains[domain] = struct{}{}
|
||||
|
||||
// Determine expiry
|
||||
expiresAt := cert.NotAfter
|
||||
|
||||
// Upsert into DB for provider 'letsencrypt'
|
||||
var existing models.SSLCertificate
|
||||
res := s.db.Where("provider = ? AND domains = ?", "letsencrypt", domain).First(&existing)
|
||||
if res.Error != nil {
|
||||
if res.Error == gorm.ErrRecordNotFound {
|
||||
// Create new record
|
||||
now := time.Now()
|
||||
newCert := models.SSLCertificate{
|
||||
UUID: uuid.New().String(),
|
||||
Name: domain,
|
||||
Provider: "letsencrypt",
|
||||
Domains: domain,
|
||||
Certificate: string(certData),
|
||||
PrivateKey: "",
|
||||
ExpiresAt: &expiresAt,
|
||||
AutoRenew: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := s.db.Create(&newCert).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to create DB cert for %s: %v\n", domain, err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("CertificateService: db error querying cert %s: %v\n", domain, res.Error)
|
||||
}
|
||||
} else {
|
||||
// Update expiry/certificate content if changed
|
||||
updated := false
|
||||
existing.ExpiresAt = &expiresAt
|
||||
if existing.Certificate != string(certData) {
|
||||
existing.Certificate = string(certData)
|
||||
updated = true
|
||||
}
|
||||
if updated {
|
||||
existing.UpdatedAt = time.Now()
|
||||
if err := s.db.Save(&existing).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to update DB cert for %s: %v\n", domain, err)
|
||||
}
|
||||
} else {
|
||||
// still update ExpiresAt if needed
|
||||
if err := s.db.Model(&existing).Update("expires_at", &expiresAt).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to update expiry for %s: %v\n", domain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
if os.IsNotExist(err) {
|
||||
log.Printf("CertificateService: cert directory does not exist: %s\n", certRoot)
|
||||
} else {
|
||||
log.Printf("CertificateService: failed to stat cert directory: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete stale DB entries for provider 'letsencrypt' not found on disk
|
||||
var acmeCerts []models.SSLCertificate
|
||||
if err := s.db.Where("provider = ?", "letsencrypt").Find(&acmeCerts).Error; err == nil {
|
||||
for _, c := range acmeCerts {
|
||||
if _, ok := foundDomains[c.Domains]; !ok {
|
||||
// remove stale record
|
||||
if err := s.db.Delete(&models.SSLCertificate{}, "id = ?", c.ID).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to delete stale cert %s: %v\n", c.Domains, err)
|
||||
} else {
|
||||
log.Printf("CertificateService: removed stale DB cert for %s\n", c.Domains)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, fetch all certificates from DB to build the response (includes custom and persisted ACME)
|
||||
certs := []CertificateInfo{}
|
||||
var dbCerts []models.SSLCertificate
|
||||
if err := s.db.Find(&dbCerts).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch certs from DB: %w", err)
|
||||
}
|
||||
|
||||
for _, c := range dbCerts {
|
||||
status := "valid"
|
||||
if c.ExpiresAt != nil {
|
||||
if time.Now().After(*c.ExpiresAt) {
|
||||
status = "expired"
|
||||
} else if time.Now().AddDate(0, 0, 30).After(*c.ExpiresAt) {
|
||||
status = "expiring"
|
||||
}
|
||||
}
|
||||
|
||||
expires := time.Time{}
|
||||
if c.ExpiresAt != nil {
|
||||
expires = *c.ExpiresAt
|
||||
}
|
||||
|
||||
certs = append(certs, CertificateInfo{
|
||||
ID: c.ID,
|
||||
UUID: c.UUID,
|
||||
Name: c.Name,
|
||||
Domain: c.Domains,
|
||||
Issuer: c.Provider,
|
||||
ExpiresAt: expires,
|
||||
Status: status,
|
||||
Provider: c.Provider,
|
||||
})
|
||||
}
|
||||
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// UploadCertificate saves a new custom certificate.
|
||||
func (s *CertificateService) UploadCertificate(name, certPEM, keyPEM string) (*models.SSLCertificate, error) {
|
||||
// Validate PEM
|
||||
block, _ := pem.Decode([]byte(certPEM))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("invalid certificate PEM")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
// Create DB entry
|
||||
sslCert := &models.SSLCertificate{
|
||||
UUID: uuid.New().String(),
|
||||
Name: name,
|
||||
Provider: "custom",
|
||||
Domains: cert.Subject.CommonName, // Or SANs
|
||||
Certificate: certPEM,
|
||||
PrivateKey: keyPEM,
|
||||
ExpiresAt: &cert.NotAfter,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Handle SANs if present
|
||||
if len(cert.DNSNames) > 0 {
|
||||
sslCert.Domains = strings.Join(cert.DNSNames, ",")
|
||||
}
|
||||
|
||||
if err := s.db.Create(sslCert).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sslCert, nil
|
||||
}
|
||||
|
||||
// DeleteCertificate removes a certificate.
|
||||
func (s *CertificateService) DeleteCertificate(id uint) error {
|
||||
var cert models.SSLCertificate
|
||||
if err := s.db.First(&cert, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cert.Provider == "letsencrypt" {
|
||||
// Best-effort file deletion
|
||||
certRoot := filepath.Join(s.dataDir, "certificates")
|
||||
_ = filepath.Walk(certRoot, func(path string, info os.FileInfo, err error) error {
|
||||
if err == nil && !info.IsDir() && strings.HasSuffix(info.Name(), ".crt") {
|
||||
if info.Name() == cert.Domains+".crt" {
|
||||
// Found it
|
||||
log.Printf("CertificateService: deleting ACME cert file %s", path)
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Printf("CertificateService: failed to delete cert file: %v", err)
|
||||
}
|
||||
// Try to delete key as well
|
||||
keyPath := strings.TrimSuffix(path, ".crt") + ".key"
|
||||
if _, err := os.Stat(keyPath); err == nil {
|
||||
os.Remove(keyPath)
|
||||
}
|
||||
// Also try to delete the json meta file
|
||||
jsonPath := strings.TrimSuffix(path, ".crt") + ".json"
|
||||
if _, err := os.Stat(jsonPath); err == nil {
|
||||
os.Remove(jsonPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return s.db.Delete(&models.SSLCertificate{}, "id = ?", id).Error
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
)
|
||||
|
||||
func generateTestCert(t *testing.T, domain string, expiry time.Time) []byte {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: domain,
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: expiry,
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
}
|
||||
|
||||
func TestCertificateService_GetCertificateInfo(t *testing.T) {
|
||||
// Create temp dir
|
||||
tmpDir, err := os.MkdirTemp("", "cert-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Setup in-memory DB
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&models.SSLCertificate{}); err != nil {
|
||||
t.Fatalf("Failed to migrate database: %v", err)
|
||||
}
|
||||
|
||||
cs := NewCertificateService(tmpDir, db)
|
||||
|
||||
// Case 1: Valid Certificate
|
||||
domain := "example.com"
|
||||
expiry := time.Now().Add(24 * time.Hour * 60) // 60 days
|
||||
certPEM := generateTestCert(t, domain, expiry)
|
||||
|
||||
// Create cert directory
|
||||
certDir := filepath.Join(tmpDir, "certificates", "acme-v02.api.letsencrypt.org-directory", domain)
|
||||
err = os.MkdirAll(certDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cert dir: %v", err)
|
||||
}
|
||||
|
||||
certPath := filepath.Join(certDir, domain+".crt")
|
||||
err = os.WriteFile(certPath, certPEM, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write cert file: %v", err)
|
||||
}
|
||||
|
||||
// List Certificates
|
||||
certs, err := cs.ListCertificates()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, certs, 1)
|
||||
if len(certs) > 0 {
|
||||
assert.Equal(t, domain, certs[0].Domain)
|
||||
assert.Equal(t, "valid", certs[0].Status)
|
||||
// Check expiry within a margin
|
||||
assert.WithinDuration(t, expiry, certs[0].ExpiresAt, time.Second)
|
||||
}
|
||||
|
||||
// Case 2: Expired Certificate
|
||||
expiredDomain := "expired.com"
|
||||
expiredExpiry := time.Now().Add(-24 * time.Hour) // Yesterday
|
||||
expiredCertPEM := generateTestCert(t, expiredDomain, expiredExpiry)
|
||||
|
||||
expiredCertDir := filepath.Join(tmpDir, "certificates", "other", expiredDomain)
|
||||
err = os.MkdirAll(expiredCertDir, 0755)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expiredCertPath := filepath.Join(expiredCertDir, expiredDomain+".crt")
|
||||
err = os.WriteFile(expiredCertPath, expiredCertPEM, 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
certs, err = cs.ListCertificates()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, certs, 2)
|
||||
|
||||
// Find the expired one
|
||||
var foundExpired bool
|
||||
for _, c := range certs {
|
||||
if c.Domain == expiredDomain {
|
||||
assert.Equal(t, "expired", c.Status)
|
||||
foundExpired = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundExpired, "Should find expired certificate")
|
||||
}
|
||||
|
||||
func TestCertificateService_UploadAndDelete(t *testing.T) {
|
||||
// Setup
|
||||
tmpDir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&models.SSLCertificate{}))
|
||||
|
||||
cs := NewCertificateService(tmpDir, db)
|
||||
|
||||
// Generate Cert
|
||||
domain := "custom.example.com"
|
||||
expiry := time.Now().Add(24 * time.Hour)
|
||||
certPEM := generateTestCert(t, domain, expiry)
|
||||
keyPEM := []byte("FAKE PRIVATE KEY")
|
||||
|
||||
// Test Upload
|
||||
cert, err := cs.UploadCertificate("My Custom Cert", string(certPEM), string(keyPEM))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cert)
|
||||
assert.Equal(t, "My Custom Cert", cert.Name)
|
||||
assert.Equal(t, "custom", cert.Provider)
|
||||
assert.Equal(t, domain, cert.Domains)
|
||||
|
||||
// Verify it's in List
|
||||
certs, err := cs.ListCertificates()
|
||||
require.NoError(t, err)
|
||||
var found bool
|
||||
for _, c := range certs {
|
||||
if c.ID == cert.ID {
|
||||
found = true
|
||||
assert.Equal(t, "custom", c.Provider)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
|
||||
// Test Delete
|
||||
err = cs.DeleteCertificate(cert.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone
|
||||
certs, err = cs.ListCertificates()
|
||||
require.NoError(t, err)
|
||||
found = false
|
||||
for _, c := range certs {
|
||||
if c.ID == cert.ID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.False(t, found)
|
||||
}
|
||||
|
||||
func TestCertificateService_Persistence(t *testing.T) {
|
||||
// Setup
|
||||
tmpDir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&models.SSLCertificate{}))
|
||||
|
||||
cs := NewCertificateService(tmpDir, db)
|
||||
|
||||
// 1. Create a fake ACME cert file
|
||||
domain := "persist.example.com"
|
||||
expiry := time.Now().Add(24 * time.Hour)
|
||||
certPEM := generateTestCert(t, domain, expiry)
|
||||
|
||||
certDir := filepath.Join(tmpDir, "certificates", "acme-v02.api.letsencrypt.org-directory", domain)
|
||||
err = os.MkdirAll(certDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPath := filepath.Join(certDir, domain+".crt")
|
||||
err = os.WriteFile(certPath, certPEM, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Call ListCertificates to trigger scan and persistence
|
||||
certs, err := cs.ListCertificates()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's in the returned list
|
||||
var foundInList bool
|
||||
for _, c := range certs {
|
||||
if c.Domain == domain {
|
||||
foundInList = true
|
||||
assert.Equal(t, "letsencrypt", c.Provider)
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInList, "Certificate should be in the returned list")
|
||||
|
||||
// 3. Verify it's in the DB
|
||||
var dbCert models.SSLCertificate
|
||||
err = db.Where("domains = ? AND provider = ?", domain, "letsencrypt").First(&dbCert).Error
|
||||
assert.NoError(t, err, "Certificate should be persisted to DB")
|
||||
assert.Equal(t, domain, dbCert.Name)
|
||||
assert.Equal(t, string(certPEM), dbCert.Certificate)
|
||||
|
||||
// 4. Delete the certificate via Service (which should delete the file)
|
||||
err = cs.DeleteCertificate(dbCert.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify file is gone
|
||||
_, err = os.Stat(certPath)
|
||||
assert.True(t, os.IsNotExist(err), "Cert file should be deleted")
|
||||
|
||||
// 5. Call ListCertificates again to trigger cleanup (though DB row is already gone)
|
||||
certs, err = cs.ListCertificates()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's NOT in the returned list
|
||||
foundInList = false
|
||||
for _, c := range certs {
|
||||
if c.Domain == domain {
|
||||
foundInList = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.False(t, foundInList, "Certificate should NOT be in the returned list after deletion")
|
||||
|
||||
// 6. Verify it's gone from the DB
|
||||
err = db.Where("domains = ? AND provider = ?", domain, "letsencrypt").First(&dbCert).Error
|
||||
assert.Error(t, err, "Certificate should be removed from DB")
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
)
|
||||
|
||||
type DockerPort struct {
|
||||
PrivatePort uint16 `json:"private_port"`
|
||||
PublicPort uint16 `json:"public_port"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type DockerContainer struct {
|
||||
ID string `json:"id"`
|
||||
Names []string `json:"names"`
|
||||
Image string `json:"image"`
|
||||
State string `json:"state"`
|
||||
Status string `json:"status"`
|
||||
Network string `json:"network"`
|
||||
IP string `json:"ip"`
|
||||
Ports []DockerPort `json:"ports"`
|
||||
}
|
||||
|
||||
type DockerService struct {
|
||||
client *client.Client
|
||||
}
|
||||
|
||||
func NewDockerService() (*DockerService, error) {
|
||||
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create docker client: %w", err)
|
||||
}
|
||||
return &DockerService{client: cli}, nil
|
||||
}
|
||||
|
||||
func (s *DockerService) ListContainers(ctx context.Context, host string) ([]DockerContainer, error) {
|
||||
var cli *client.Client
|
||||
var err error
|
||||
|
||||
if host == "" || host == "local" {
|
||||
cli = s.client
|
||||
} else {
|
||||
cli, err = client.NewClientWithOpts(client.WithHost(host), client.WithAPIVersionNegotiation())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create remote client: %w", err)
|
||||
}
|
||||
defer cli.Close()
|
||||
}
|
||||
|
||||
containers, err := cli.ContainerList(ctx, container.ListOptions{All: false})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list containers: %w", err)
|
||||
}
|
||||
|
||||
var result []DockerContainer
|
||||
for _, c := range containers {
|
||||
// Get the first network's IP address if available
|
||||
networkName := ""
|
||||
ipAddress := ""
|
||||
if c.NetworkSettings != nil && len(c.NetworkSettings.Networks) > 0 {
|
||||
for name, net := range c.NetworkSettings.Networks {
|
||||
networkName = name
|
||||
ipAddress = net.IPAddress
|
||||
break // Just take the first one for now
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up names (remove leading slash)
|
||||
names := make([]string, len(c.Names))
|
||||
for i, name := range c.Names {
|
||||
names[i] = strings.TrimPrefix(name, "/")
|
||||
}
|
||||
|
||||
// Map ports
|
||||
var ports []DockerPort
|
||||
for _, p := range c.Ports {
|
||||
ports = append(ports, DockerPort{
|
||||
PrivatePort: p.PrivatePort,
|
||||
PublicPort: p.PublicPort,
|
||||
Type: p.Type,
|
||||
})
|
||||
}
|
||||
|
||||
result = append(result, DockerContainer{
|
||||
ID: c.ID[:12], // Short ID
|
||||
Names: names,
|
||||
Image: c.Image,
|
||||
State: c.State,
|
||||
Status: c.Status,
|
||||
Network: networkName,
|
||||
IP: ipAddress,
|
||||
Ports: ports,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDockerService_New(t *testing.T) {
|
||||
// This test might fail if docker socket is not available in the build environment
|
||||
// So we just check if it returns error or not, but don't fail the test if it's just "socket not found"
|
||||
// In a real CI environment with Docker-in-Docker, this would work.
|
||||
svc, err := NewDockerService()
|
||||
if err != nil {
|
||||
t.Logf("Skipping DockerService test: %v", err)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, svc)
|
||||
}
|
||||
|
||||
func TestDockerService_ListContainers(t *testing.T) {
|
||||
svc, err := NewDockerService()
|
||||
if err != nil {
|
||||
t.Logf("Skipping DockerService test: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Test local listing
|
||||
containers, err := svc.ListContainers(context.Background(), "")
|
||||
// If we can't connect to docker daemon, this will fail.
|
||||
// We should probably mock the client, but the docker client is an interface?
|
||||
// The official client struct is concrete.
|
||||
// For now, we just assert that if err is nil, containers is a slice.
|
||||
if err == nil {
|
||||
assert.IsType(t, []DockerContainer{}, containers)
|
||||
}
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/config"
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
)
|
||||
|
||||
type LogService struct {
|
||||
LogDir string
|
||||
}
|
||||
|
||||
func NewLogService(cfg *config.Config) *LogService {
|
||||
// Assuming logs are in data/logs relative to app root
|
||||
logDir := filepath.Join(filepath.Dir(cfg.DatabasePath), "logs")
|
||||
return &LogService{LogDir: logDir}
|
||||
}
|
||||
|
||||
type LogFile struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
ModTime string `json:"mod_time"`
|
||||
}
|
||||
|
||||
func (s *LogService) ListLogs() ([]LogFile, error) {
|
||||
entries, err := os.ReadDir(s.LogDir)
|
||||
if err != nil {
|
||||
// If directory doesn't exist, return empty list instead of error
|
||||
if os.IsNotExist(err) {
|
||||
return []LogFile{}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var logs []LogFile
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && (strings.HasSuffix(entry.Name(), ".log") || strings.Contains(entry.Name(), ".log.")) {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logs = append(logs, LogFile{
|
||||
Name: entry.Name(),
|
||||
Size: info.Size(),
|
||||
ModTime: info.ModTime().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// GetLogPath returns the absolute path to a log file if it exists and is valid
|
||||
func (s *LogService) GetLogPath(filename string) (string, error) {
|
||||
cleanName := filepath.Base(filename)
|
||||
if filename != cleanName {
|
||||
return "", fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
path := filepath.Join(s.LogDir, cleanName)
|
||||
if !strings.HasPrefix(path, filepath.Clean(s.LogDir)) {
|
||||
return "", fmt.Errorf("invalid filename: path traversal attempt detected")
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// QueryLogs parses and filters logs from a specific file
|
||||
func (s *LogService) QueryLogs(filename string, filter models.LogFilter) ([]models.CaddyAccessLog, int64, error) {
|
||||
path, err := s.GetLogPath(filename)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var logs []models.CaddyAccessLog
|
||||
var totalMatches int64 = 0
|
||||
|
||||
// Read file line by line
|
||||
// TODO: For large files, reading from end or indexing would be better
|
||||
// Current implementation reads all lines, filters, then paginates
|
||||
// This is acceptable for rotated logs (max 10MB)
|
||||
scanner := bufio.NewScanner(file)
|
||||
|
||||
// We'll store all matching logs first, then slice for pagination
|
||||
// This is memory intensive for very large matches but ensures correct sorting/filtering
|
||||
// Since we want latest first, we'll prepend or reverse later.
|
||||
// Actually, appending and then reversing is better.
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var entry models.CaddyAccessLog
|
||||
if err := json.Unmarshal([]byte(line), &entry); err != nil {
|
||||
// Handle non-JSON logs (like cpmp.log)
|
||||
// Try to parse standard Go log format: "2006/01/02 15:04:05 msg"
|
||||
parts := strings.SplitN(line, " ", 3)
|
||||
if len(parts) >= 3 {
|
||||
// Try parsing date/time
|
||||
ts, err := time.Parse("2006/01/02 15:04:05", parts[0]+" "+parts[1])
|
||||
if err == nil {
|
||||
entry.Ts = float64(ts.Unix())
|
||||
entry.Msg = parts[2]
|
||||
} else {
|
||||
entry.Msg = line
|
||||
}
|
||||
} else {
|
||||
entry.Msg = line
|
||||
}
|
||||
entry.Level = "INFO" // Default level for plain logs
|
||||
}
|
||||
|
||||
if s.matchesFilter(entry, filter) {
|
||||
logs = append(logs, entry)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Reverse logs to show newest first (default) unless sort is asc
|
||||
if filter.Sort != "asc" {
|
||||
for i, j := 0, len(logs)-1; i < j; i, j = i+1, j-1 {
|
||||
logs[i], logs[j] = logs[j], logs[i]
|
||||
}
|
||||
}
|
||||
|
||||
totalMatches = int64(len(logs))
|
||||
|
||||
// Apply pagination
|
||||
start := filter.Offset
|
||||
end := start + filter.Limit
|
||||
|
||||
if start >= len(logs) {
|
||||
return []models.CaddyAccessLog{}, totalMatches, nil
|
||||
}
|
||||
if end > len(logs) {
|
||||
end = len(logs)
|
||||
}
|
||||
|
||||
return logs[start:end], totalMatches, nil
|
||||
}
|
||||
|
||||
func (s *LogService) matchesFilter(entry models.CaddyAccessLog, filter models.LogFilter) bool {
|
||||
// Status Filter
|
||||
if filter.Status != "" {
|
||||
statusStr := strconv.Itoa(entry.Status)
|
||||
if strings.HasSuffix(filter.Status, "xx") {
|
||||
// Handle 2xx, 4xx, 5xx
|
||||
prefix := filter.Status[:1]
|
||||
if !strings.HasPrefix(statusStr, prefix) {
|
||||
return false
|
||||
}
|
||||
} else if statusStr != filter.Status {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Level Filter
|
||||
if filter.Level != "" {
|
||||
if !strings.EqualFold(entry.Level, filter.Level) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Host Filter
|
||||
if filter.Host != "" {
|
||||
if !strings.Contains(strings.ToLower(entry.Request.Host), strings.ToLower(filter.Host)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Search Filter (generic text search)
|
||||
if filter.Search != "" {
|
||||
term := strings.ToLower(filter.Search)
|
||||
// Search in common fields
|
||||
if !strings.Contains(strings.ToLower(entry.Request.URI), term) &&
|
||||
!strings.Contains(strings.ToLower(entry.Request.Method), term) &&
|
||||
!strings.Contains(strings.ToLower(entry.Request.RemoteIP), term) &&
|
||||
!strings.Contains(strings.ToLower(entry.Msg), term) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -1,168 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/config"
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLogService(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "cpm-log-service-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dataDir := filepath.Join(tmpDir, "data")
|
||||
logsDir := filepath.Join(dataDir, "logs")
|
||||
err = os.MkdirAll(logsDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create sample JSON logs
|
||||
logEntry1 := models.CaddyAccessLog{
|
||||
Level: "info",
|
||||
Ts: 1600000000,
|
||||
Msg: "request handled",
|
||||
Status: 200,
|
||||
}
|
||||
logEntry1.Request.Method = "GET"
|
||||
logEntry1.Request.Host = "example.com"
|
||||
logEntry1.Request.URI = "/"
|
||||
logEntry1.Request.RemoteIP = "1.2.3.4"
|
||||
|
||||
logEntry2 := models.CaddyAccessLog{
|
||||
Level: "error",
|
||||
Ts: 1600000060,
|
||||
Msg: "error handled",
|
||||
Status: 500,
|
||||
}
|
||||
logEntry2.Request.Method = "POST"
|
||||
logEntry2.Request.Host = "api.example.com"
|
||||
logEntry2.Request.URI = "/submit"
|
||||
logEntry2.Request.RemoteIP = "5.6.7.8"
|
||||
|
||||
line1, _ := json.Marshal(logEntry1)
|
||||
line2, _ := json.Marshal(logEntry2)
|
||||
|
||||
content := string(line1) + "\n" + string(line2) + "\n"
|
||||
|
||||
err = os.WriteFile(filepath.Join(logsDir, "access.log"), []byte(content), 0644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(logsDir, "other.txt"), []byte("ignore me"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{DatabasePath: filepath.Join(dataDir, "cpm.db")}
|
||||
service := NewLogService(cfg)
|
||||
|
||||
// Test List
|
||||
logs, err := service.ListLogs()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, logs, 1)
|
||||
assert.Equal(t, "access.log", logs[0].Name)
|
||||
|
||||
// Test QueryLogs - All
|
||||
results, total, err := service.QueryLogs("access.log", models.LogFilter{Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), total)
|
||||
assert.Len(t, results, 2)
|
||||
// Should be reversed (newest first)
|
||||
assert.Equal(t, 500, results[0].Status)
|
||||
assert.Equal(t, 200, results[1].Status)
|
||||
|
||||
// Test QueryLogs - Filter Status
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Status: "5xx", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, 500, results[0].Status)
|
||||
|
||||
// Test QueryLogs - Filter Host
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Host: "api.example.com", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "api.example.com", results[0].Request.Host)
|
||||
|
||||
// Test QueryLogs - Search
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Search: "submit", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "/submit", results[0].Request.URI)
|
||||
|
||||
// Test GetLogPath
|
||||
path, err := service.GetLogPath("access.log")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(logsDir, "access.log"), path)
|
||||
|
||||
// Test GetLogPath non-existent
|
||||
_, err = service.GetLogPath("missing.log")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test GetLogPath - Invalid
|
||||
_, err = service.GetLogPath("nonexistent.log")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test GetLogPath - Traversal
|
||||
_, err = service.GetLogPath("../../etc/passwd")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid filename")
|
||||
|
||||
// Test ListLogs - Directory Not Exist
|
||||
nonExistService := NewLogService(&config.Config{DatabasePath: filepath.Join(t.TempDir(), "missing", "cpm.db")})
|
||||
logs, err = nonExistService.ListLogs()
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, logs)
|
||||
|
||||
// Test QueryLogs - Non-JSON Logs
|
||||
plainContent := "2023/10/27 10:00:00 Application started\nJust a plain line\n"
|
||||
err = os.WriteFile(filepath.Join(logsDir, "app.log"), []byte(plainContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
results, total, err = service.QueryLogs("app.log", models.LogFilter{Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), total)
|
||||
// Reverse order check
|
||||
assert.Equal(t, "Just a plain line", results[0].Msg)
|
||||
assert.Equal(t, "Application started", results[1].Msg)
|
||||
assert.Equal(t, "INFO", results[1].Level)
|
||||
|
||||
// Test QueryLogs - Pagination
|
||||
// We have 2 logs in access.log
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Limit: 1, Offset: 0})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, 500, results[0].Status) // Newest first
|
||||
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Limit: 1, Offset: 1})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, 200, results[0].Status) // Second newest
|
||||
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Limit: 10, Offset: 5})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, results)
|
||||
|
||||
// Test QueryLogs - Exact Status Match
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Status: "200", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Equal(t, 200, results[0].Status)
|
||||
|
||||
// Test QueryLogs - Search Fields
|
||||
// Search Method
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Search: "POST", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Equal(t, "POST", results[0].Request.Method)
|
||||
|
||||
// Search RemoteIP
|
||||
results, total, err = service.QueryLogs("access.log", models.LogFilter{Search: "5.6.7.8", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Equal(t, "5.6.7.8", results[0].Request.RemoteIP)
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/containrrr/shoutrrr"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type NotificationService struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func NewNotificationService(db *gorm.DB) *NotificationService {
|
||||
return &NotificationService{DB: db}
|
||||
}
|
||||
|
||||
var discordWebhookRegex = regexp.MustCompile(`^https://discord(?:app)?\.com/api/webhooks/(\d+)/([a-zA-Z0-9_-]+)`)
|
||||
|
||||
func normalizeURL(serviceType, rawURL string) string {
|
||||
if serviceType == "discord" {
|
||||
matches := discordWebhookRegex.FindStringSubmatch(rawURL)
|
||||
if len(matches) == 3 {
|
||||
id := matches[1]
|
||||
token := matches[2]
|
||||
return fmt.Sprintf("discord://%s@%s", token, id)
|
||||
}
|
||||
}
|
||||
return rawURL
|
||||
}
|
||||
|
||||
// Internal Notifications (DB)
|
||||
|
||||
func (s *NotificationService) Create(nType models.NotificationType, title, message string) (*models.Notification, error) {
|
||||
notification := &models.Notification{
|
||||
Type: nType,
|
||||
Title: title,
|
||||
Message: message,
|
||||
Read: false,
|
||||
}
|
||||
result := s.DB.Create(notification)
|
||||
return notification, result.Error
|
||||
}
|
||||
|
||||
func (s *NotificationService) List(unreadOnly bool) ([]models.Notification, error) {
|
||||
var notifications []models.Notification
|
||||
query := s.DB.Order("created_at desc")
|
||||
if unreadOnly {
|
||||
query = query.Where("read = ?", false)
|
||||
}
|
||||
result := query.Find(¬ifications)
|
||||
return notifications, result.Error
|
||||
}
|
||||
|
||||
func (s *NotificationService) MarkAsRead(id string) error {
|
||||
return s.DB.Model(&models.Notification{}).Where("id = ?", id).Update("read", true).Error
|
||||
}
|
||||
|
||||
func (s *NotificationService) MarkAllAsRead() error {
|
||||
return s.DB.Model(&models.Notification{}).Where("read = ?", false).Update("read", true).Error
|
||||
}
|
||||
|
||||
// External Notifications (Shoutrrr & Custom Webhooks)
|
||||
|
||||
func (s *NotificationService) SendExternal(eventType, title, message string, data map[string]interface{}) {
|
||||
var providers []models.NotificationProvider
|
||||
if err := s.DB.Where("enabled = ?", true).Find(&providers).Error; err != nil {
|
||||
log.Printf("Failed to fetch notification providers: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare data for templates
|
||||
if data == nil {
|
||||
data = make(map[string]interface{})
|
||||
}
|
||||
data["Title"] = title
|
||||
data["Message"] = message
|
||||
data["Time"] = time.Now().Format(time.RFC3339)
|
||||
data["EventType"] = eventType
|
||||
|
||||
for _, provider := range providers {
|
||||
// Filter based on preferences
|
||||
shouldSend := false
|
||||
switch eventType {
|
||||
case "proxy_host":
|
||||
shouldSend = provider.NotifyProxyHosts
|
||||
case "remote_server":
|
||||
shouldSend = provider.NotifyRemoteServers
|
||||
case "domain":
|
||||
shouldSend = provider.NotifyDomains
|
||||
case "cert":
|
||||
shouldSend = provider.NotifyCerts
|
||||
case "uptime":
|
||||
shouldSend = provider.NotifyUptime
|
||||
case "test":
|
||||
shouldSend = true
|
||||
default:
|
||||
// Default to true for unknown types or generic messages?
|
||||
// Or false to be safe? Let's say true for now to avoid missing things,
|
||||
// or maybe we should enforce types.
|
||||
shouldSend = true
|
||||
}
|
||||
|
||||
if !shouldSend {
|
||||
continue
|
||||
}
|
||||
|
||||
go func(p models.NotificationProvider) {
|
||||
if p.Type == "webhook" {
|
||||
s.sendCustomWebhook(p, data)
|
||||
} else {
|
||||
url := normalizeURL(p.Type, p.URL)
|
||||
if err := shoutrrr.Send(url, fmt.Sprintf("%s: %s", title, message)); err != nil {
|
||||
log.Printf("Failed to send notification to %s: %v", p.Name, err)
|
||||
}
|
||||
}
|
||||
}(provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NotificationService) sendCustomWebhook(p models.NotificationProvider, data map[string]interface{}) {
|
||||
// Default template if empty
|
||||
tmplStr := p.Config
|
||||
if tmplStr == "" {
|
||||
tmplStr = `{"content": "{{.Title}}: {{.Message}}"}`
|
||||
}
|
||||
|
||||
// Parse template
|
||||
tmpl, err := template.New("webhook").Parse(tmplStr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse webhook template for %s: %v", p.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
if err := tmpl.Execute(&body, data); err != nil {
|
||||
log.Printf("Failed to execute webhook template for %s: %v", p.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Send Request
|
||||
resp, err := http.Post(p.URL, "application/json", &body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to send webhook to %s: %v", p.Name, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("Webhook %s returned status: %d", p.Name, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NotificationService) TestProvider(provider models.NotificationProvider) error {
|
||||
if provider.Type == "webhook" {
|
||||
data := map[string]interface{}{
|
||||
"Title": "Test Notification",
|
||||
"Message": "This is a test notification from CaddyProxyManager+",
|
||||
"Status": "TEST",
|
||||
"Name": "Test Monitor",
|
||||
"Latency": 123,
|
||||
"Time": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
s.sendCustomWebhook(provider, data)
|
||||
return nil
|
||||
}
|
||||
url := normalizeURL(provider.Type, provider.URL)
|
||||
return shoutrrr.Send(url, "Test notification from CaddyProxyManager+")
|
||||
}
|
||||
|
||||
// Provider Management
|
||||
|
||||
func (s *NotificationService) ListProviders() ([]models.NotificationProvider, error) {
|
||||
var providers []models.NotificationProvider
|
||||
result := s.DB.Find(&providers)
|
||||
return providers, result.Error
|
||||
}
|
||||
|
||||
func (s *NotificationService) CreateProvider(provider *models.NotificationProvider) error {
|
||||
return s.DB.Create(provider).Error
|
||||
}
|
||||
|
||||
func (s *NotificationService) UpdateProvider(provider *models.NotificationProvider) error {
|
||||
return s.DB.Save(provider).Error
|
||||
}
|
||||
|
||||
func (s *NotificationService) DeleteProvider(id string) error {
|
||||
return s.DB.Delete(&models.NotificationProvider{}, "id = ?", id).Error
|
||||
}
|
||||
@@ -1,266 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupNotificationTestDB(t *testing.T) *gorm.DB {
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
db.AutoMigrate(&models.Notification{}, &models.NotificationProvider{})
|
||||
return db
|
||||
}
|
||||
|
||||
func TestNotificationService_Create(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
notif, err := svc.Create(models.NotificationTypeInfo, "Test", "Message")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test", notif.Title)
|
||||
assert.Equal(t, "Message", notif.Message)
|
||||
assert.False(t, notif.Read)
|
||||
}
|
||||
|
||||
func TestNotificationService_List(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
svc.Create(models.NotificationTypeInfo, "N1", "M1")
|
||||
svc.Create(models.NotificationTypeInfo, "N2", "M2")
|
||||
|
||||
list, err := svc.List(false)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, list, 2)
|
||||
|
||||
// Mark one as read
|
||||
db.Model(&models.Notification{}).Where("title = ?", "N1").Update("read", true)
|
||||
|
||||
listUnread, err := svc.List(true)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, listUnread, 1)
|
||||
assert.Equal(t, "N2", listUnread[0].Title)
|
||||
}
|
||||
|
||||
func TestNotificationService_MarkAsRead(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
notif, _ := svc.Create(models.NotificationTypeInfo, "N1", "M1")
|
||||
|
||||
err := svc.MarkAsRead(fmt.Sprintf("%s", notif.ID))
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.Notification
|
||||
db.First(&updated, "id = ?", notif.ID)
|
||||
assert.True(t, updated.Read)
|
||||
}
|
||||
|
||||
func TestNotificationService_MarkAllAsRead(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
svc.Create(models.NotificationTypeInfo, "N1", "M1")
|
||||
svc.Create(models.NotificationTypeInfo, "N2", "M2")
|
||||
|
||||
err := svc.MarkAllAsRead()
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int64
|
||||
db.Model(&models.Notification{}).Where("read = ?", false).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestNotificationService_Providers(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
// Create
|
||||
provider := models.NotificationProvider{
|
||||
Name: "Discord",
|
||||
Type: "discord",
|
||||
URL: "http://example.com",
|
||||
}
|
||||
err := svc.CreateProvider(&provider)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, provider.ID)
|
||||
assert.Equal(t, "Discord", provider.Name)
|
||||
|
||||
// List
|
||||
list, err := svc.ListProviders()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, list, 1)
|
||||
|
||||
// Update
|
||||
provider.Name = "Discord Updated"
|
||||
err = svc.UpdateProvider(&provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Discord Updated", provider.Name)
|
||||
|
||||
// Delete
|
||||
err = svc.DeleteProvider(provider.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
list, err = svc.ListProviders()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, list, 0)
|
||||
}
|
||||
|
||||
func TestNotificationService_TestProvider_Webhook(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
// Start a test server
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
assert.Equal(t, "Test Notification", body["Title"])
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := models.NotificationProvider{
|
||||
Name: "Test Webhook",
|
||||
Type: "webhook",
|
||||
URL: ts.URL,
|
||||
Config: `{"Title": "{{.Title}}"}`,
|
||||
}
|
||||
|
||||
err := svc.TestProvider(provider)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNotificationService_SendExternal(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
received := make(chan struct{})
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(received)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := models.NotificationProvider{
|
||||
Name: "Test Webhook",
|
||||
Type: "webhook",
|
||||
URL: ts.URL,
|
||||
Enabled: true,
|
||||
NotifyProxyHosts: true,
|
||||
}
|
||||
svc.CreateProvider(&provider)
|
||||
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
|
||||
select {
|
||||
case <-received:
|
||||
// Success
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Timed out waiting for webhook")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationService_SendExternal_Filtered(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
received := make(chan struct{})
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(received)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := models.NotificationProvider{
|
||||
Name: "Test Webhook",
|
||||
Type: "webhook",
|
||||
URL: ts.URL,
|
||||
Enabled: true,
|
||||
NotifyProxyHosts: false, // Disabled
|
||||
}
|
||||
svc.CreateProvider(&provider)
|
||||
// Force update to false because GORM default tag might override zero value (false) on Create
|
||||
db.Model(&provider).Update("notify_proxy_hosts", false)
|
||||
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
|
||||
select {
|
||||
case <-received:
|
||||
t.Fatal("Should not have received webhook")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Success (timeout expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationService_SendExternal_Shoutrrr(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
provider := models.NotificationProvider{
|
||||
Name: "Test Discord",
|
||||
Type: "discord",
|
||||
URL: "discord://token@id",
|
||||
Enabled: true,
|
||||
NotifyProxyHosts: true,
|
||||
}
|
||||
svc.CreateProvider(&provider)
|
||||
|
||||
// This will log an error but should cover the code path
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
|
||||
// Give it a moment to run goroutine
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestNormalizeURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serviceType string
|
||||
rawURL string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Discord HTTPS",
|
||||
serviceType: "discord",
|
||||
rawURL: "https://discord.com/api/webhooks/123456789/abcdefg",
|
||||
expected: "discord://abcdefg@123456789",
|
||||
},
|
||||
{
|
||||
name: "Discord HTTPS with app",
|
||||
serviceType: "discord",
|
||||
rawURL: "https://discordapp.com/api/webhooks/123456789/abcdefg",
|
||||
expected: "discord://abcdefg@123456789",
|
||||
},
|
||||
{
|
||||
name: "Discord Shoutrrr",
|
||||
serviceType: "discord",
|
||||
rawURL: "discord://token@id",
|
||||
expected: "discord://token@id",
|
||||
},
|
||||
{
|
||||
name: "Other Service",
|
||||
serviceType: "slack",
|
||||
rawURL: "https://hooks.slack.com/services/...",
|
||||
expected: "https://hooks.slack.com/services/...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeURL(tt.serviceType, tt.rawURL)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
)
|
||||
|
||||
// ProxyHostService encapsulates business logic for proxy host management.
|
||||
type ProxyHostService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewProxyHostService creates a new proxy host service.
|
||||
func NewProxyHostService(db *gorm.DB) *ProxyHostService {
|
||||
return &ProxyHostService{db: db}
|
||||
}
|
||||
|
||||
// ValidateUniqueDomain ensures no duplicate domains exist before creation/update.
|
||||
func (s *ProxyHostService) ValidateUniqueDomain(domainNames string, excludeID uint) error {
|
||||
var count int64
|
||||
query := s.db.Model(&models.ProxyHost{}).Where("domain_names = ?", domainNames)
|
||||
|
||||
if excludeID > 0 {
|
||||
query = query.Where("id != ?", excludeID)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
return fmt.Errorf("checking domain uniqueness: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return errors.New("domain already exists")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create validates and creates a new proxy host.
|
||||
func (s *ProxyHostService) Create(host *models.ProxyHost) error {
|
||||
if err := s.ValidateUniqueDomain(host.DomainNames, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Create(host).Error
|
||||
}
|
||||
|
||||
// Update validates and updates an existing proxy host.
|
||||
func (s *ProxyHostService) Update(host *models.ProxyHost) error {
|
||||
if err := s.ValidateUniqueDomain(host.DomainNames, host.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Save(host).Error
|
||||
}
|
||||
|
||||
// Delete removes a proxy host.
|
||||
func (s *ProxyHostService) Delete(id uint) error {
|
||||
return s.db.Delete(&models.ProxyHost{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID retrieves a proxy host by ID.
|
||||
func (s *ProxyHostService) GetByID(id uint) (*models.ProxyHost, error) {
|
||||
var host models.ProxyHost
|
||||
if err := s.db.First(&host, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
// GetByUUID finds a proxy host by UUID.
|
||||
func (s *ProxyHostService) GetByUUID(uuid string) (*models.ProxyHost, error) {
|
||||
var host models.ProxyHost
|
||||
if err := s.db.Preload("Locations").Preload("Certificate").Where("uuid = ?", uuid).First(&host).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
// List returns all proxy hosts.
|
||||
func (s *ProxyHostService) List() ([]models.ProxyHost, error) {
|
||||
var hosts []models.ProxyHost
|
||||
if err := s.db.Preload("Locations").Preload("Certificate").Order("updated_at desc").Find(&hosts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
// TestConnection attempts to connect to the target host and port.
|
||||
func (s *ProxyHostService) TestConnection(host string, port int) error {
|
||||
if host == "" || port <= 0 {
|
||||
return errors.New("invalid host or port")
|
||||
}
|
||||
|
||||
target := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
conn, err := net.DialTimeout("tcp", target, 3*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connection failed: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupProxyHostTestDB(t *testing.T) *gorm.DB {
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&models.ProxyHost{}, &models.Location{}))
|
||||
return db
|
||||
}
|
||||
|
||||
func TestProxyHostService_ValidateUniqueDomain(t *testing.T) {
|
||||
db := setupProxyHostTestDB(t)
|
||||
service := NewProxyHostService(db)
|
||||
|
||||
// Create existing host
|
||||
existing := &models.ProxyHost{
|
||||
DomainNames: "example.com",
|
||||
ForwardHost: "127.0.0.1",
|
||||
ForwardPort: 8080,
|
||||
}
|
||||
require.NoError(t, db.Create(existing).Error)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domainNames string
|
||||
excludeID uint
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "New unique domain",
|
||||
domainNames: "new.example.com",
|
||||
excludeID: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Duplicate domain",
|
||||
domainNames: "example.com",
|
||||
excludeID: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Same domain but excluded ID (update self)",
|
||||
domainNames: "example.com",
|
||||
excludeID: existing.ID,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := service.ValidateUniqueDomain(tt.domainNames, tt.excludeID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHostService_CRUD(t *testing.T) {
|
||||
db := setupProxyHostTestDB(t)
|
||||
service := NewProxyHostService(db)
|
||||
|
||||
// Create
|
||||
host := &models.ProxyHost{
|
||||
UUID: "uuid-1",
|
||||
DomainNames: "test.example.com",
|
||||
ForwardHost: "127.0.0.1",
|
||||
ForwardPort: 8080,
|
||||
}
|
||||
err := service.Create(host)
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, host.ID)
|
||||
|
||||
// Create Duplicate
|
||||
dup := &models.ProxyHost{
|
||||
UUID: "uuid-2",
|
||||
DomainNames: "test.example.com",
|
||||
ForwardHost: "127.0.0.1",
|
||||
ForwardPort: 8081,
|
||||
}
|
||||
err = service.Create(dup)
|
||||
assert.Error(t, err)
|
||||
|
||||
// GetByID
|
||||
fetched, err := service.GetByID(host.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, host.DomainNames, fetched.DomainNames)
|
||||
|
||||
// GetByUUID
|
||||
fetchedUUID, err := service.GetByUUID(host.UUID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, host.ID, fetchedUUID.ID)
|
||||
|
||||
// Update
|
||||
host.ForwardPort = 9090
|
||||
err = service.Update(host)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fetched, err = service.GetByID(host.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 9090, fetched.ForwardPort)
|
||||
|
||||
// Update Duplicate
|
||||
host2 := &models.ProxyHost{
|
||||
UUID: "uuid-3",
|
||||
DomainNames: "other.example.com",
|
||||
ForwardHost: "127.0.0.1",
|
||||
ForwardPort: 8080,
|
||||
}
|
||||
service.Create(host2)
|
||||
|
||||
host.DomainNames = "other.example.com" // Conflict with host2
|
||||
err = service.Update(host)
|
||||
assert.Error(t, err)
|
||||
|
||||
// List
|
||||
hosts, err := service.List()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 2)
|
||||
|
||||
// Delete
|
||||
err = service.Delete(host.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = service.GetByID(host.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyHostService_TestConnection(t *testing.T) {
|
||||
db := setupProxyHostTestDB(t)
|
||||
service := NewProxyHostService(db)
|
||||
|
||||
// 1. Invalid Input
|
||||
err := service.TestConnection("", 80)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid host or port")
|
||||
|
||||
err = service.TestConnection("example.com", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid host or port")
|
||||
|
||||
// 2. Connection Failure (Unreachable)
|
||||
err = service.TestConnection("localhost", 54321)
|
||||
assert.Error(t, err)
|
||||
|
||||
// 3. Connection Success
|
||||
// Start a local listener
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
addr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
err = service.TestConnection(addr.IP.String(), addr.Port)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
)
|
||||
|
||||
// RemoteServerService encapsulates business logic for remote server management.
|
||||
type RemoteServerService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewRemoteServerService creates a new remote server service.
|
||||
func NewRemoteServerService(db *gorm.DB) *RemoteServerService {
|
||||
return &RemoteServerService{db: db}
|
||||
}
|
||||
|
||||
// ValidateUniqueServer ensures no duplicate name+host+port combinations.
|
||||
func (s *RemoteServerService) ValidateUniqueServer(name, host string, port int, excludeID uint) error {
|
||||
var count int64
|
||||
query := s.db.Model(&models.RemoteServer{}).Where("name = ? OR (host = ? AND port = ?)", name, host, port)
|
||||
|
||||
if excludeID > 0 {
|
||||
query = query.Where("id != ?", excludeID)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
return fmt.Errorf("checking server uniqueness: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return errors.New("server with same name or host:port already exists")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create validates and creates a new remote server.
|
||||
func (s *RemoteServerService) Create(server *models.RemoteServer) error {
|
||||
if err := s.ValidateUniqueServer(server.Name, server.Host, server.Port, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Create(server).Error
|
||||
}
|
||||
|
||||
// Update validates and updates an existing remote server.
|
||||
func (s *RemoteServerService) Update(server *models.RemoteServer) error {
|
||||
if err := s.ValidateUniqueServer(server.Name, server.Host, server.Port, server.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Save(server).Error
|
||||
}
|
||||
|
||||
// Delete removes a remote server.
|
||||
func (s *RemoteServerService) Delete(id uint) error {
|
||||
return s.db.Delete(&models.RemoteServer{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID retrieves a remote server by ID.
|
||||
func (s *RemoteServerService) GetByID(id uint) (*models.RemoteServer, error) {
|
||||
var server models.RemoteServer
|
||||
if err := s.db.First(&server, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
// GetByUUID retrieves a remote server by UUID.
|
||||
func (s *RemoteServerService) GetByUUID(uuid string) (*models.RemoteServer, error) {
|
||||
var server models.RemoteServer
|
||||
if err := s.db.Where("uuid = ?", uuid).First(&server).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
// List retrieves all remote servers, optionally filtering by enabled status.
|
||||
func (s *RemoteServerService) List(enabledOnly bool) ([]models.RemoteServer, error) {
|
||||
var servers []models.RemoteServer
|
||||
query := s.db
|
||||
|
||||
if enabledOnly {
|
||||
query = query.Where("enabled = ?", true)
|
||||
}
|
||||
|
||||
if err := query.Order("name ASC").Find(&servers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupRemoteServerTestDB(t *testing.T) *gorm.DB {
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared&mode=memory"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&models.RemoteServer{}))
|
||||
// Clear table
|
||||
db.Exec("DELETE FROM remote_servers")
|
||||
return db
|
||||
}
|
||||
|
||||
func TestRemoteServerService_ValidateUniqueServer(t *testing.T) {
|
||||
db := setupRemoteServerTestDB(t)
|
||||
service := NewRemoteServerService(db)
|
||||
|
||||
// Create existing server
|
||||
existing := &models.RemoteServer{
|
||||
Name: "Existing Server",
|
||||
Host: "192.168.1.100",
|
||||
Port: 8080,
|
||||
}
|
||||
require.NoError(t, db.Create(existing).Error)
|
||||
|
||||
// Test 1: Duplicate Name
|
||||
err := service.ValidateUniqueServer("Existing Server", "192.168.1.101", 9090, 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
|
||||
// Test 2: Duplicate Host:Port
|
||||
err = service.ValidateUniqueServer("New Name", "192.168.1.100", 8080, 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
|
||||
// Test 3: New Server
|
||||
err = service.ValidateUniqueServer("New Server", "192.168.1.101", 8080, 0)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test 4: Update existing (exclude self)
|
||||
err = service.ValidateUniqueServer("Existing Server", "192.168.1.100", 8080, existing.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRemoteServerService_CRUD(t *testing.T) {
|
||||
db := setupRemoteServerTestDB(t)
|
||||
service := NewRemoteServerService(db)
|
||||
|
||||
// Create
|
||||
rs := &models.RemoteServer{
|
||||
UUID: uuid.NewString(),
|
||||
Name: "Test Server",
|
||||
Host: "192.168.1.100",
|
||||
Port: 22,
|
||||
Provider: "manual",
|
||||
}
|
||||
err := service.Create(rs)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, rs.ID)
|
||||
assert.NotEmpty(t, rs.UUID)
|
||||
|
||||
// GetByID
|
||||
fetched, err := service.GetByID(rs.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rs.Name, fetched.Name)
|
||||
|
||||
// GetByUUID
|
||||
fetchedUUID, err := service.GetByUUID(rs.UUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rs.ID, fetchedUUID.ID)
|
||||
|
||||
// Update
|
||||
rs.Name = "Updated Server"
|
||||
err = service.Update(rs)
|
||||
require.NoError(t, err)
|
||||
|
||||
fetchedUpdated, err := service.GetByID(rs.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Server", fetchedUpdated.Name)
|
||||
|
||||
// List
|
||||
list, err := service.List(false)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, list, 1)
|
||||
|
||||
// Delete
|
||||
err = service.Delete(rs.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify Delete
|
||||
_, err = service.GetByID(rs.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/version"
|
||||
)
|
||||
|
||||
type UpdateService struct {
|
||||
currentVersion string
|
||||
repoOwner string
|
||||
repoName string
|
||||
lastCheck time.Time
|
||||
cachedResult *UpdateInfo
|
||||
apiURL string // For testing
|
||||
}
|
||||
|
||||
type UpdateInfo struct {
|
||||
Available bool `json:"available"`
|
||||
LatestVersion string `json:"latest_version"`
|
||||
ChangelogURL string `json:"changelog_url"`
|
||||
}
|
||||
|
||||
type githubRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
}
|
||||
|
||||
func NewUpdateService() *UpdateService {
|
||||
return &UpdateService{
|
||||
currentVersion: version.Version,
|
||||
repoOwner: "Wikid82",
|
||||
repoName: "CaddyProxyManagerPlus",
|
||||
apiURL: "https://api.github.com/repos/Wikid82/CaddyProxyManagerPlus/releases/latest",
|
||||
}
|
||||
}
|
||||
|
||||
// SetAPIURL sets the GitHub API URL for testing.
|
||||
func (s *UpdateService) SetAPIURL(url string) {
|
||||
s.apiURL = url
|
||||
}
|
||||
|
||||
// SetCurrentVersion sets the current version for testing.
|
||||
func (s *UpdateService) SetCurrentVersion(v string) {
|
||||
s.currentVersion = v
|
||||
}
|
||||
|
||||
// ClearCache clears the update cache for testing.
|
||||
func (s *UpdateService) ClearCache() {
|
||||
s.cachedResult = nil
|
||||
s.lastCheck = time.Time{}
|
||||
}
|
||||
|
||||
func (s *UpdateService) CheckForUpdates() (*UpdateInfo, error) {
|
||||
// Cache for 1 hour
|
||||
if s.cachedResult != nil && time.Since(s.lastCheck) < 1*time.Hour {
|
||||
return s.cachedResult, nil
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
req, err := http.NewRequest("GET", s.apiURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "CPMP-Update-Checker")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// If rate limited or not found, just return no update available
|
||||
return &UpdateInfo{Available: false}, nil
|
||||
}
|
||||
|
||||
var release githubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Simple string comparison for now.
|
||||
// In production, use a semver library.
|
||||
// Assuming tags are "v0.1.0" and version is "0.1.0"
|
||||
latest := release.TagName
|
||||
if len(latest) > 0 && latest[0] == 'v' {
|
||||
latest = latest[1:]
|
||||
}
|
||||
|
||||
info := &UpdateInfo{
|
||||
Available: latest != s.currentVersion && latest != "",
|
||||
LatestVersion: release.TagName,
|
||||
ChangelogURL: release.HTMLURL,
|
||||
}
|
||||
|
||||
s.cachedResult = info
|
||||
s.lastCheck = time.Now()
|
||||
|
||||
return info, nil
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUpdateService_CheckForUpdates(t *testing.T) {
|
||||
// Mock GitHub API
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/releases/latest" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
release := githubRelease{
|
||||
TagName: "v1.0.0",
|
||||
HTMLURL: "https://github.com/Wikid82/CaddyProxyManagerPlus/releases/tag/v1.0.0",
|
||||
}
|
||||
json.NewEncoder(w).Encode(release)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
us := NewUpdateService()
|
||||
us.SetAPIURL(server.URL + "/releases/latest")
|
||||
// us.currentVersion is private, so we can't set it directly in test unless we export it or add a setter.
|
||||
// However, NewUpdateService sets it from version.Version.
|
||||
// We can temporarily change version.Version if it's a var, but it's likely a const or var in another package.
|
||||
// Let's check version package.
|
||||
// Assuming version.Version is a var we can change, or we add a SetCurrentVersion method for testing.
|
||||
// For now, let's assume we can't change it easily without a setter.
|
||||
// Let's add SetCurrentVersion to UpdateService for testing purposes.
|
||||
us.SetCurrentVersion("0.9.0")
|
||||
|
||||
// Test Update Available
|
||||
info, err := us.CheckForUpdates()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, info.Available)
|
||||
assert.Equal(t, "v1.0.0", info.LatestVersion)
|
||||
assert.Equal(t, "https://github.com/Wikid82/CaddyProxyManagerPlus/releases/tag/v1.0.0", info.ChangelogURL)
|
||||
|
||||
// Test No Update Available
|
||||
us.SetCurrentVersion("1.0.0")
|
||||
// us.cachedResult = nil // cachedResult is private
|
||||
// us.lastCheck = time.Time{} // lastCheck is private
|
||||
us.ClearCache() // Add this method
|
||||
|
||||
info, err = us.CheckForUpdates()
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, info.Available)
|
||||
assert.Equal(t, "v1.0.0", info.LatestVersion)
|
||||
|
||||
// Test Cache
|
||||
// If we call again immediately, it should use cache.
|
||||
// We can verify this by closing the server or changing the response, but cache logic is simple.
|
||||
// Let's change the server handler? No, httptest server handler is fixed.
|
||||
// But we can check if it returns the same object.
|
||||
info2, err := us.CheckForUpdates()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, info, info2)
|
||||
|
||||
// Test Error (Server Down)
|
||||
server.Close()
|
||||
us.cachedResult = nil
|
||||
us.lastCheck = time.Time{}
|
||||
|
||||
// Depending on implementation, it might return error or just available=false
|
||||
// Implementation:
|
||||
// resp, err := client.Do(req) -> returns error if connection refused
|
||||
// if err != nil { return nil, err }
|
||||
_, err = us.CheckForUpdates()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -1,200 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UptimeService struct {
|
||||
DB *gorm.DB
|
||||
NotificationService *NotificationService
|
||||
}
|
||||
|
||||
func NewUptimeService(db *gorm.DB, ns *NotificationService) *UptimeService {
|
||||
return &UptimeService{
|
||||
DB: db,
|
||||
NotificationService: ns,
|
||||
}
|
||||
}
|
||||
|
||||
// SyncMonitors ensures every ProxyHost has a corresponding UptimeMonitor
|
||||
func (s *UptimeService) SyncMonitors() error {
|
||||
var hosts []models.ProxyHost
|
||||
if err := s.DB.Find(&hosts).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
var monitor models.UptimeMonitor
|
||||
err := s.DB.Where("proxy_host_id = ?", host.ID).First(&monitor).Error
|
||||
|
||||
domains := strings.Split(host.DomainNames, ",")
|
||||
firstDomain := ""
|
||||
if len(domains) > 0 {
|
||||
firstDomain = strings.TrimSpace(domains[0])
|
||||
}
|
||||
|
||||
// Construct the public URL
|
||||
scheme := "http"
|
||||
if host.SSLForced {
|
||||
scheme = "https"
|
||||
}
|
||||
publicURL := fmt.Sprintf("%s://%s", scheme, firstDomain)
|
||||
internalURL := fmt.Sprintf("%s:%d", host.ForwardHost, host.ForwardPort)
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// Create new monitor
|
||||
name := host.Name
|
||||
if name == "" {
|
||||
name = firstDomain
|
||||
}
|
||||
|
||||
monitor = models.UptimeMonitor{
|
||||
ProxyHostID: &host.ID,
|
||||
Name: name,
|
||||
Type: "http", // Check public access
|
||||
URL: publicURL,
|
||||
Interval: 60,
|
||||
Enabled: true,
|
||||
Status: "pending",
|
||||
}
|
||||
if err := s.DB.Create(&monitor).Error; err != nil {
|
||||
log.Printf("Failed to create monitor for host %d: %v", host.ID, err)
|
||||
}
|
||||
} else if err == nil {
|
||||
// Update existing monitor if it looks like it's using the old default (TCP to internal upstream)
|
||||
// We check if it matches the internal upstream URL to avoid overwriting custom user settings
|
||||
if monitor.Type == "tcp" && monitor.URL == internalURL {
|
||||
monitor.Type = "http"
|
||||
monitor.URL = publicURL
|
||||
s.DB.Save(&monitor)
|
||||
log.Printf("Migrated monitor for host %d to check public URL: %s", host.ID, publicURL)
|
||||
}
|
||||
|
||||
// Upgrade to HTTPS if SSL is forced and we are currently checking HTTP
|
||||
if host.SSLForced && strings.HasPrefix(monitor.URL, "http://") {
|
||||
monitor.URL = strings.Replace(monitor.URL, "http://", "https://", 1)
|
||||
s.DB.Save(&monitor)
|
||||
log.Printf("Upgraded monitor for host %d to HTTPS: %s", host.ID, monitor.URL)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckAll runs checks for all enabled monitors
|
||||
func (s *UptimeService) CheckAll() {
|
||||
var monitors []models.UptimeMonitor
|
||||
if err := s.DB.Where("enabled = ?", true).Find(&monitors).Error; err != nil {
|
||||
log.Printf("Failed to fetch monitors: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, monitor := range monitors {
|
||||
go s.checkMonitor(monitor)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UptimeService) checkMonitor(monitor models.UptimeMonitor) {
|
||||
start := time.Now()
|
||||
success := false
|
||||
var msg string
|
||||
|
||||
switch monitor.Type {
|
||||
case "http", "https":
|
||||
client := http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Get(monitor.URL)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
// Accept 2xx, 3xx, and 401/403 (Unauthorized/Forbidden often means the service is up but protected)
|
||||
if (resp.StatusCode >= 200 && resp.StatusCode < 400) || resp.StatusCode == 401 || resp.StatusCode == 403 {
|
||||
success = true
|
||||
msg = fmt.Sprintf("HTTP %d", resp.StatusCode)
|
||||
} else {
|
||||
msg = fmt.Sprintf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
msg = err.Error()
|
||||
}
|
||||
case "tcp":
|
||||
conn, err := net.DialTimeout("tcp", monitor.URL, 10*time.Second)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
success = true
|
||||
msg = "Connection successful"
|
||||
} else {
|
||||
msg = err.Error()
|
||||
}
|
||||
default:
|
||||
msg = "Unknown monitor type"
|
||||
}
|
||||
|
||||
latency := time.Since(start).Milliseconds()
|
||||
status := "down"
|
||||
if success {
|
||||
status = "up"
|
||||
}
|
||||
|
||||
// Record Heartbeat
|
||||
heartbeat := models.UptimeHeartbeat{
|
||||
MonitorID: monitor.ID,
|
||||
Status: status,
|
||||
Latency: latency,
|
||||
Message: msg,
|
||||
}
|
||||
s.DB.Create(&heartbeat)
|
||||
|
||||
// Update Monitor Status
|
||||
oldStatus := monitor.Status
|
||||
monitor.Status = status
|
||||
monitor.LastCheck = time.Now()
|
||||
monitor.Latency = latency
|
||||
s.DB.Save(&monitor)
|
||||
|
||||
// Send Notification if status changed
|
||||
if oldStatus != "pending" && oldStatus != status {
|
||||
title := fmt.Sprintf("Monitor %s is %s", monitor.Name, status)
|
||||
|
||||
nType := models.NotificationTypeInfo
|
||||
if status == "down" {
|
||||
nType = models.NotificationTypeError
|
||||
} else if status == "up" {
|
||||
nType = models.NotificationTypeSuccess
|
||||
}
|
||||
|
||||
s.NotificationService.Create(
|
||||
nType,
|
||||
title,
|
||||
fmt.Sprintf("Monitor %s changed status from %s to %s. Latency: %dms. Message: %s", monitor.Name, oldStatus, status, latency, msg),
|
||||
)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Name": monitor.Name,
|
||||
"Status": status,
|
||||
"Latency": latency,
|
||||
"Message": msg,
|
||||
}
|
||||
s.NotificationService.SendExternal("uptime", title, msg, data)
|
||||
}
|
||||
}
|
||||
|
||||
// CRUD for Monitors
|
||||
|
||||
func (s *UptimeService) ListMonitors() ([]models.UptimeMonitor, error) {
|
||||
var monitors []models.UptimeMonitor
|
||||
result := s.DB.Find(&monitors)
|
||||
return monitors, result.Error
|
||||
}
|
||||
|
||||
func (s *UptimeService) GetMonitorHistory(id string, limit int) ([]models.UptimeHeartbeat, error) {
|
||||
var heartbeats []models.UptimeHeartbeat
|
||||
result := s.DB.Where("monitor_id = ?", id).Order("created_at desc").Limit(limit).Find(&heartbeats)
|
||||
return heartbeats, result.Error
|
||||
}
|
||||
@@ -1,168 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/CaddyProxyManagerPlus/backend/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupUptimeTestDB(t *testing.T) *gorm.DB {
|
||||
dsn := filepath.Join(t.TempDir(), "test.db") + "?_busy_timeout=5000&_journal_mode=WAL"
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
err = db.AutoMigrate(&models.Notification{}, &models.NotificationProvider{}, &models.Setting{}, &models.ProxyHost{}, &models.UptimeMonitor{}, &models.UptimeHeartbeat{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to migrate database: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func TestUptimeService_CheckAll(t *testing.T) {
|
||||
db := setupUptimeTestDB(t)
|
||||
ns := NewNotificationService(db)
|
||||
us := NewUptimeService(db, ns)
|
||||
|
||||
// Create a dummy HTTP server for a "UP" host
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start listener: %v", err)
|
||||
}
|
||||
addr := listener.Addr().(*net.TCPAddr)
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
}
|
||||
go server.Serve(listener)
|
||||
defer server.Close()
|
||||
|
||||
// Seed ProxyHosts
|
||||
// We use the listener address as the "DomainName" so the monitor checks this HTTP server
|
||||
upHost := models.ProxyHost{
|
||||
UUID: "uuid-1",
|
||||
DomainNames: fmt.Sprintf("127.0.0.1:%d", addr.Port),
|
||||
ForwardHost: "127.0.0.1",
|
||||
ForwardPort: addr.Port,
|
||||
Enabled: true,
|
||||
}
|
||||
db.Create(&upHost)
|
||||
|
||||
downHost := models.ProxyHost{
|
||||
UUID: "uuid-2",
|
||||
DomainNames: "down.example.com", // This won't resolve or connect
|
||||
ForwardHost: "127.0.0.1",
|
||||
ForwardPort: 54321,
|
||||
Enabled: true,
|
||||
}
|
||||
db.Create(&downHost)
|
||||
|
||||
// Sync Monitors (this creates UptimeMonitor records)
|
||||
err = us.SyncMonitors()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify Monitors created
|
||||
var monitors []models.UptimeMonitor
|
||||
db.Find(&monitors)
|
||||
assert.Equal(t, 2, len(monitors))
|
||||
|
||||
// Run CheckAll
|
||||
us.CheckAll()
|
||||
time.Sleep(200 * time.Millisecond) // Increased wait time for HTTP check
|
||||
|
||||
// Verify Heartbeats
|
||||
var heartbeats []models.UptimeHeartbeat
|
||||
db.Find(&heartbeats)
|
||||
assert.GreaterOrEqual(t, len(heartbeats), 2)
|
||||
|
||||
// Verify Status
|
||||
var upMonitor models.UptimeMonitor
|
||||
db.Where("proxy_host_id = ?", upHost.ID).First(&upMonitor)
|
||||
assert.Equal(t, "up", upMonitor.Status)
|
||||
|
||||
var downMonitor models.UptimeMonitor
|
||||
db.Where("proxy_host_id = ?", downHost.ID).First(&downMonitor)
|
||||
assert.Equal(t, "down", downMonitor.Status)
|
||||
|
||||
// Verify Notifications
|
||||
// We expect 0 notifications because initial state transition from "pending" is ignored
|
||||
var notifications []models.Notification
|
||||
db.Find(¬ifications)
|
||||
assert.Equal(t, 0, len(notifications), "Should have 0 notifications on first run")
|
||||
|
||||
// Now let's flip the status to trigger notification
|
||||
// Make upHost go DOWN by closing the listener
|
||||
server.Close()
|
||||
listener.Close()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
us.CheckAll()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
db.Where("proxy_host_id = ?", upHost.ID).First(&upMonitor)
|
||||
assert.Equal(t, "down", upMonitor.Status)
|
||||
|
||||
db.Find(¬ifications)
|
||||
assert.Equal(t, 1, len(notifications), "Should have 1 notification now")
|
||||
if len(notifications) > 0 {
|
||||
assert.Contains(t, notifications[0].Message, upHost.DomainNames, "Notification should mention the host")
|
||||
assert.Equal(t, models.NotificationTypeError, notifications[0].Type, "Notification type should be error for DOWN event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUptimeService_ListMonitors(t *testing.T) {
|
||||
db := setupUptimeTestDB(t)
|
||||
ns := NewNotificationService(db)
|
||||
us := NewUptimeService(db, ns)
|
||||
|
||||
db.Create(&models.UptimeMonitor{
|
||||
Name: "Test Monitor",
|
||||
Type: "http",
|
||||
URL: "http://example.com",
|
||||
})
|
||||
|
||||
monitors, err := us.ListMonitors()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, monitors, 1)
|
||||
assert.Equal(t, "Test Monitor", monitors[0].Name)
|
||||
}
|
||||
|
||||
func TestUptimeService_GetMonitorHistory(t *testing.T) {
|
||||
db := setupUptimeTestDB(t)
|
||||
ns := NewNotificationService(db)
|
||||
us := NewUptimeService(db, ns)
|
||||
|
||||
monitor := models.UptimeMonitor{
|
||||
ID: "monitor-1",
|
||||
Name: "Test Monitor",
|
||||
}
|
||||
db.Create(&monitor)
|
||||
|
||||
db.Create(&models.UptimeHeartbeat{
|
||||
MonitorID: monitor.ID,
|
||||
Status: "up",
|
||||
Latency: 10,
|
||||
CreatedAt: time.Now().Add(-1 * time.Minute),
|
||||
})
|
||||
db.Create(&models.UptimeHeartbeat{
|
||||
MonitorID: monitor.ID,
|
||||
Status: "down",
|
||||
Latency: 0,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
history, err := us.GetMonitorHistory(monitor.ID, 100)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, history, 2)
|
||||
assert.Equal(t, "down", history[0].Status)
|
||||
}
|
||||
Reference in New Issue
Block a user