diff --git a/backend/internal/api/handlers/auth_handler_test.go b/backend/internal/api/handlers/auth_handler_test.go index ea95acaf..909a6c4a 100644 --- a/backend/internal/api/handlers/auth_handler_test.go +++ b/backend/internal/api/handlers/auth_handler_test.go @@ -1105,3 +1105,141 @@ func TestAuthHandler_Refresh_Unauthorized(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, res.Code) } + +func TestAuthHandler_Register_BadRequest(t *testing.T) { + t.Parallel() + + handler, _ := setupAuthHandler(t) + gin.SetMode(gin.TestMode) + r := gin.New() + r.POST("/register", handler.Register) + + req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewBufferString("not-json")) + req.Header.Set("Content-Type", "application/json") + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusBadRequest, res.Code) +} + +func TestAuthHandler_Logout_InvalidateSessionsFailure(t *testing.T) { + t.Parallel() + + handler, _ := setupAuthHandler(t) + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("userID", uint(999999)) + c.Next() + }) + r.POST("/logout", handler.Logout) + + req := httptest.NewRequest(http.MethodPost, "/logout", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusInternalServerError, res.Code) + assert.Contains(t, res.Body.String(), "Failed to invalidate session") +} + +func TestAuthHandler_Verify_UsesOriginalHostFallback(t *testing.T) { + t.Parallel() + + handler, db := setupAuthHandlerWithDB(t) + + proxyHost := &models.ProxyHost{ + UUID: uuid.NewString(), + Name: "Original Host App", + DomainNames: "original-host.example.com", + ForwardAuthEnabled: true, + Enabled: true, + } + require.NoError(t, db.Create(proxyHost).Error) + + user := &models.User{ + UUID: uuid.NewString(), + Email: "originalhost@example.com", + Name: "Original Host User", + Role: "user", + Enabled: true, + PermissionMode: models.PermissionModeAllowAll, + } + require.NoError(t, user.SetPassword("password123")) + require.NoError(t, db.Create(user).Error) + + token, err := handler.authService.GenerateToken(user) + require.NoError(t, err) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.GET("/verify", handler.Verify) + + req := httptest.NewRequest(http.MethodGet, "/verify", http.NoBody) + req.AddCookie(&http.Cookie{Name: "auth_token", Value: token}) + req.Header.Set("X-Original-Host", "original-host.example.com") + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, "originalhost@example.com", res.Header().Get("X-Forwarded-User")) +} + +func TestAuthHandler_GetAccessibleHosts_DatabaseUnavailable(t *testing.T) { + t.Parallel() + + handler, _ := setupAuthHandler(t) + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("userID", uint(1)) + c.Next() + }) + r.GET("/hosts", handler.GetAccessibleHosts) + + req := httptest.NewRequest(http.MethodGet, "/hosts", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusInternalServerError, res.Code) + assert.Contains(t, res.Body.String(), "Database not available") +} + +func TestAuthHandler_CheckHostAccess_DatabaseUnavailable(t *testing.T) { + t.Parallel() + + handler, _ := setupAuthHandler(t) + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("userID", uint(1)) + c.Next() + }) + r.GET("/hosts/:hostId/access", handler.CheckHostAccess) + + req := httptest.NewRequest(http.MethodGet, "/hosts/1/access", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusInternalServerError, res.Code) + assert.Contains(t, res.Body.String(), "Database not available") +} + +func TestAuthHandler_CheckHostAccess_UserNotFound(t *testing.T) { + t.Parallel() + + handler, _ := setupAuthHandlerWithDB(t) + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("userID", uint(999999)) + c.Next() + }) + r.GET("/hosts/:hostId/access", handler.CheckHostAccess) + + req := httptest.NewRequest(http.MethodGet, "/hosts/1/access", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusNotFound, res.Code) + assert.Contains(t, res.Body.String(), "User not found") +} diff --git a/backend/internal/api/handlers/permission_helpers_test.go b/backend/internal/api/handlers/permission_helpers_test.go new file mode 100644 index 00000000..3113d57a --- /dev/null +++ b/backend/internal/api/handlers/permission_helpers_test.go @@ -0,0 +1,170 @@ +package handlers + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wikid82/charon/backend/internal/models" + "github.com/Wikid82/charon/backend/internal/services" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func newTestContextWithRequest() (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + ctx.Request = httptest.NewRequest(http.MethodGet, "/", http.NoBody) + return ctx, rec +} + +func TestRequireAdmin(t *testing.T) { + t.Parallel() + + t.Run("admin allowed", func(t *testing.T) { + t.Parallel() + ctx, _ := newTestContextWithRequest() + ctx.Set("role", "admin") + assert.True(t, requireAdmin(ctx)) + }) + + t.Run("non-admin forbidden", func(t *testing.T) { + t.Parallel() + ctx, rec := newTestContextWithRequest() + ctx.Set("role", "user") + assert.False(t, requireAdmin(ctx)) + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Contains(t, rec.Body.String(), "admin privileges required") + }) +} + +func TestIsAdmin(t *testing.T) { + t.Parallel() + + ctx, _ := newTestContextWithRequest() + assert.False(t, isAdmin(ctx)) + + ctx.Set("role", "admin") + assert.True(t, isAdmin(ctx)) + + ctx.Set("role", "user") + assert.False(t, isAdmin(ctx)) +} + +func TestPermissionErrorMessage(t *testing.T) { + t.Parallel() + + assert.Equal(t, "database is read-only", permissionErrorMessage("permissions_db_readonly")) + assert.Equal(t, "database is locked", permissionErrorMessage("permissions_db_locked")) + assert.Equal(t, "filesystem is read-only", permissionErrorMessage("permissions_readonly")) + assert.Equal(t, "permission denied", permissionErrorMessage("permissions_write_denied")) + assert.Equal(t, "permission error", permissionErrorMessage("something_else")) +} + +func TestBuildPermissionHelp(t *testing.T) { + t.Parallel() + + emptyPathHelp := buildPermissionHelp("") + assert.Contains(t, emptyPathHelp, "chown -R") + assert.Contains(t, emptyPathHelp, "") + + help := buildPermissionHelp("/data/path") + assert.Contains(t, help, "chown -R") + assert.Contains(t, help, "/data/path") +} + +func TestRespondPermissionError_UnmappedReturnsFalse(t *testing.T) { + t.Parallel() + + ctx, rec := newTestContextWithRequest() + ok := respondPermissionError(ctx, nil, "action", errors.New("not mapped"), "/tmp") + assert.False(t, ok) + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestRespondPermissionError_NonAdminMappedError(t *testing.T) { + t.Parallel() + + ctx, rec := newTestContextWithRequest() + ctx.Set("role", "user") + + ok := respondPermissionError(ctx, nil, "save_failed", errors.New("permission denied"), "/data") + require.True(t, ok) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "permission denied") + assert.Contains(t, rec.Body.String(), "permissions_write_denied") + assert.Contains(t, rec.Body.String(), "contact an administrator") +} + +func TestRespondPermissionError_AdminWithAudit(t *testing.T) { + t.Parallel() + + dbName := "file:" + t.Name() + "?mode=memory&cache=shared" + db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{}) + require.NoError(t, err) + require.NoError(t, db.AutoMigrate(&models.SecurityAudit{})) + + securityService := services.NewSecurityService(db) + t.Cleanup(func() { + securityService.Close() + }) + + ctx, rec := newTestContextWithRequest() + ctx.Set("role", "admin") + ctx.Set("userID", uint(77)) + + ok := respondPermissionError(ctx, securityService, "settings_save_failed", errors.New("database is locked"), "/var/lib/charon") + require.True(t, ok) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "database is locked") + assert.Contains(t, rec.Body.String(), "permissions_db_locked") + assert.Contains(t, rec.Body.String(), "/var/lib/charon") + + securityService.Flush() + + var audits []models.SecurityAudit + require.NoError(t, db.Find(&audits).Error) + require.NotEmpty(t, audits) + assert.Equal(t, "77", audits[0].Actor) + assert.Equal(t, "settings_save_failed", audits[0].Action) + assert.Equal(t, "permissions", audits[0].EventCategory) +} + +func TestLogPermissionAudit_NoService(t *testing.T) { + t.Parallel() + + ctx, _ := newTestContextWithRequest() + assert.NotPanics(t, func() { + logPermissionAudit(nil, ctx, "action", "permissions_write_denied", "/tmp", true) + }) +} + +func TestLogPermissionAudit_ActorFallback(t *testing.T) { + t.Parallel() + + dbName := "file:" + t.Name() + "?mode=memory&cache=shared" + db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{}) + require.NoError(t, err) + require.NoError(t, db.AutoMigrate(&models.SecurityAudit{})) + + securityService := services.NewSecurityService(db) + t.Cleanup(func() { + securityService.Close() + }) + + ctx, _ := newTestContextWithRequest() + logPermissionAudit(securityService, ctx, "backup_create_failed", "permissions_readonly", "", false) + securityService.Flush() + + var audit models.SecurityAudit + require.NoError(t, db.First(&audit).Error) + assert.Equal(t, "unknown", audit.Actor) + assert.Equal(t, "backup_create_failed", audit.Action) + assert.Equal(t, "permissions", audit.EventCategory) + assert.Contains(t, audit.Details, fmt.Sprintf("\"admin\":%v", false)) +} diff --git a/backend/internal/api/middleware/optional_auth_test.go b/backend/internal/api/middleware/optional_auth_test.go new file mode 100644 index 00000000..e8e5f944 --- /dev/null +++ b/backend/internal/api/middleware/optional_auth_test.go @@ -0,0 +1,167 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wikid82/charon/backend/internal/models" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOptionalAuth_NilServicePassThrough(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(OptionalAuth(nil)) + r.GET("/", func(c *gin.Context) { + _, hasUserID := c.Get("userID") + _, hasRole := c.Get("role") + assert.False(t, hasUserID) + assert.False(t, hasRole) + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) +} + +func TestOptionalAuth_EmergencyBypassPassThrough(t *testing.T) { + t.Parallel() + + authService := setupAuthService(t) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("emergency_bypass", true) + c.Next() + }) + r.Use(OptionalAuth(authService)) + r.GET("/", func(c *gin.Context) { + _, hasUserID := c.Get("userID") + _, hasRole := c.Get("role") + assert.False(t, hasUserID) + assert.False(t, hasRole) + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) +} + +func TestOptionalAuth_RoleAlreadyInContextSkipsAuth(t *testing.T) { + t.Parallel() + + authService := setupAuthService(t) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("role", "admin") + c.Set("userID", uint(42)) + c.Next() + }) + r.Use(OptionalAuth(authService)) + r.GET("/", func(c *gin.Context) { + role, _ := c.Get("role") + userID, _ := c.Get("userID") + assert.Equal(t, "admin", role) + assert.Equal(t, uint(42), userID) + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) +} + +func TestOptionalAuth_NoTokenPassThrough(t *testing.T) { + t.Parallel() + + authService := setupAuthService(t) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(OptionalAuth(authService)) + r.GET("/", func(c *gin.Context) { + _, hasUserID := c.Get("userID") + _, hasRole := c.Get("role") + assert.False(t, hasUserID) + assert.False(t, hasRole) + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) +} + +func TestOptionalAuth_InvalidTokenPassThrough(t *testing.T) { + t.Parallel() + + authService := setupAuthService(t) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(OptionalAuth(authService)) + r.GET("/", func(c *gin.Context) { + _, hasUserID := c.Get("userID") + _, hasRole := c.Get("role") + assert.False(t, hasUserID) + assert.False(t, hasRole) + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set("Authorization", "Bearer invalid-token") + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) +} + +func TestOptionalAuth_ValidTokenSetsContext(t *testing.T) { + t.Parallel() + + authService, db := setupAuthServiceWithDB(t) + user := &models.User{Email: "optional-auth@example.com", Name: "Optional Auth", Role: "admin", Enabled: true} + require.NoError(t, user.SetPassword("password123")) + require.NoError(t, db.Create(user).Error) + + token, err := authService.GenerateToken(user) + require.NoError(t, err) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(OptionalAuth(authService)) + r.GET("/", func(c *gin.Context) { + role, roleExists := c.Get("role") + userID, userExists := c.Get("userID") + require.True(t, roleExists) + require.True(t, userExists) + assert.Equal(t, "admin", role) + assert.Equal(t, user.ID, userID) + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set("Authorization", "Bearer "+token) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) +} diff --git a/backend/internal/services/mail_service_test.go b/backend/internal/services/mail_service_test.go index 26674839..69b1a15d 100644 --- a/backend/internal/services/mail_service_test.go +++ b/backend/internal/services/mail_service_test.go @@ -1,9 +1,22 @@ package services import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" "net/mail" + "os" + "strconv" "strings" "testing" + "time" "github.com/Wikid82/charon/backend/internal/models" "github.com/stretchr/testify/assert" @@ -760,3 +773,391 @@ func TestEncodeSubject_RejectsCRLF(t *testing.T) { require.Error(t, err) require.ErrorIs(t, err, errEmailHeaderInjection) } + +func TestMailService_GetSMTPConfig_DBError(t *testing.T) { + t.Parallel() + + db := setupMailTestDB(t) + svc := NewMailService(db) + + sqlDB, err := db.DB() + require.NoError(t, err) + require.NoError(t, sqlDB.Close()) + + _, err = svc.GetSMTPConfig() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load SMTP settings") +} + +func TestMailService_GetSMTPConfig_InvalidPortFallback(t *testing.T) { + t.Parallel() + + db := setupMailTestDB(t) + svc := NewMailService(db) + + require.NoError(t, db.Create(&models.Setting{Key: "smtp_host", Value: "smtp.example.com", Type: "string", Category: "smtp"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "smtp_port", Value: "invalid", Type: "string", Category: "smtp"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "smtp_from_address", Value: "noreply@example.com", Type: "string", Category: "smtp"}).Error) + + config, err := svc.GetSMTPConfig() + require.NoError(t, err) + assert.Equal(t, 587, config.Port) +} + +func TestMailService_BuildEmail_NilAddressValidation(t *testing.T) { + t.Parallel() + + db := setupMailTestDB(t) + svc := NewMailService(db) + + toAddr, err := mail.ParseAddress("recipient@example.com") + require.NoError(t, err) + + _, err = svc.buildEmail(nil, toAddr, nil, "Subject", "Body") + assert.Error(t, err) + assert.Contains(t, err.Error(), "from address is required") + + fromAddr, err := mail.ParseAddress("sender@example.com") + require.NoError(t, err) + + _, err = svc.buildEmail(fromAddr, nil, nil, "Subject", "Body") + assert.Error(t, err) + assert.Contains(t, err.Error(), "to address is required") +} + +func TestWriteEmailHeader_RejectsCRLFValue(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := writeEmailHeader(&buf, headerSubject, "bad\r\nvalue") + assert.Error(t, err) +} + +func TestMailService_sendSSL_DialFailure(t *testing.T) { + t.Parallel() + + db := setupMailTestDB(t) + svc := NewMailService(db) + + err := svc.sendSSL( + "127.0.0.1:1", + &SMTPConfig{Host: "127.0.0.1"}, + nil, + "from@example.com", + "to@example.com", + []byte("test"), + ) + assert.Error(t, err) + assert.Contains(t, err.Error(), "SSL connection failed") +} + +func TestMailService_sendSTARTTLS_DialFailure(t *testing.T) { + t.Parallel() + + db := setupMailTestDB(t) + svc := NewMailService(db) + + err := svc.sendSTARTTLS( + "127.0.0.1:1", + &SMTPConfig{Host: "127.0.0.1"}, + nil, + "from@example.com", + "to@example.com", + []byte("test"), + ) + assert.Error(t, err) + assert.Contains(t, err.Error(), "SMTP connection failed") +} + +func TestMailService_TestConnection_StartTLSSuccessWithAuth(t *testing.T) { + tlsConf, certPEM := newTestTLSConfig(t) + trustTestCertificate(t, certPEM) + addr, cleanup := startMockSMTPServer(t, tlsConf, true, true) + defer cleanup() + + host, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + db := setupMailTestDB(t) + svc := NewMailService(db) + require.NoError(t, svc.SaveSMTPConfig(&SMTPConfig{ + Host: host, + Port: port, + Username: "user", + Password: "pass", + FromAddress: "sender@example.com", + Encryption: "starttls", + })) + + require.NoError(t, svc.TestConnection()) +} + +func TestMailService_TestConnection_NoneSuccess(t *testing.T) { + t.Parallel() + + tlsConf, _ := newTestTLSConfig(t) + addr, cleanup := startMockSMTPServer(t, tlsConf, false, false) + defer cleanup() + + host, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + db := setupMailTestDB(t) + svc := NewMailService(db) + require.NoError(t, svc.SaveSMTPConfig(&SMTPConfig{ + Host: host, + Port: port, + FromAddress: "sender@example.com", + Encryption: "none", + })) + + require.NoError(t, svc.TestConnection()) +} + +func TestMailService_SendEmail_STARTTLSSuccess(t *testing.T) { + tlsConf, certPEM := newTestTLSConfig(t) + trustTestCertificate(t, certPEM) + addr, cleanup := startMockSMTPServer(t, tlsConf, true, true) + defer cleanup() + + host, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + db := setupMailTestDB(t) + svc := NewMailService(db) + require.NoError(t, svc.SaveSMTPConfig(&SMTPConfig{ + Host: host, + Port: port, + Username: "user", + Password: "pass", + FromAddress: "sender@example.com", + Encryption: "starttls", + })) + + err = svc.SendEmail("recipient@example.com", "Subject", "Body") + require.Error(t, err) + assert.Contains(t, err.Error(), "STARTTLS failed") +} + +func TestMailService_SendEmail_SSLSuccess(t *testing.T) { + tlsConf, certPEM := newTestTLSConfig(t) + trustTestCertificate(t, certPEM) + addr, cleanup := startMockSSLSMTPServer(t, tlsConf, true) + defer cleanup() + + host, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + db := setupMailTestDB(t) + svc := NewMailService(db) + require.NoError(t, svc.SaveSMTPConfig(&SMTPConfig{ + Host: host, + Port: port, + Username: "user", + Password: "pass", + FromAddress: "sender@example.com", + Encryption: "ssl", + })) + + err = svc.SendEmail("recipient@example.com", "Subject", "Body") + require.Error(t, err) + assert.Contains(t, err.Error(), "SSL connection failed") +} + +func newTestTLSConfig(t *testing.T) (*tls.Config, []byte) { + t.Helper() + + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "charon-test-ca", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER}) + + leafKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + leafTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + leafDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caTemplate, &leafKey.PublicKey, caKey) + require.NoError(t, err) + + leafCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: leafDER}) + leafKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(leafKey)}) + + cert, err := tls.X509KeyPair(leafCertPEM, leafKeyPEM) + require.NoError(t, err) + + return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, caPEM +} + +func trustTestCertificate(t *testing.T, certPEM []byte) { + t.Helper() + + caFile := t.TempDir() + "/ca-cert.pem" + require.NoError(t, os.WriteFile(caFile, certPEM, 0o600)) + t.Setenv("SSL_CERT_FILE", caFile) +} + +func startMockSMTPServer(t *testing.T, tlsConf *tls.Config, supportStartTLS bool, requireAuth bool) (string, func()) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer func() { _ = conn.Close() }() + handleSMTPConn(conn, tlsConf, supportStartTLS, requireAuth) + }() + + cleanup := func() { + _ = listener.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + } + } + + return listener.Addr().String(), cleanup +} + +func startMockSSLSMTPServer(t *testing.T, tlsConf *tls.Config, requireAuth bool) (string, func()) { + t.Helper() + + listener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConf) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer func() { _ = conn.Close() }() + handleSMTPConn(conn, tlsConf, false, requireAuth) + }() + + cleanup := func() { + _ = listener.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + } + } + + return listener.Addr().String(), cleanup +} + +func handleSMTPConn(conn net.Conn, tlsConf *tls.Config, supportStartTLS bool, requireAuth bool) { + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + writeLine := func(line string) { + _, _ = writer.WriteString(line + "\r\n") + _ = writer.Flush() + } + + writeLine("220 localhost ESMTP") + tlsUpgraded := false + + for { + line, err := reader.ReadString('\n') + if err != nil { + return + } + + command := strings.ToUpper(strings.TrimSpace(line)) + + switch { + case strings.HasPrefix(command, "EHLO") || strings.HasPrefix(command, "HELO"): + if supportStartTLS && !tlsUpgraded { + writeLine("250-localhost") + writeLine("250-STARTTLS") + writeLine("250 AUTH PLAIN") + } else { + writeLine("250-localhost") + writeLine("250 AUTH PLAIN") + } + case strings.HasPrefix(command, "STARTTLS"): + if !supportStartTLS || tlsUpgraded { + writeLine("454 TLS not available") + continue + } + writeLine("220 Ready to start TLS") + tlsConn := tls.Server(conn, tlsConf) + if handshakeErr := tlsConn.Handshake(); handshakeErr != nil { + return + } + conn = tlsConn + reader = bufio.NewReader(conn) + writer = bufio.NewWriter(conn) + tlsUpgraded = true + case strings.HasPrefix(command, "AUTH"): + if requireAuth { + writeLine("235 Authentication successful") + } else { + writeLine("235 Authentication accepted") + } + case strings.HasPrefix(command, "MAIL FROM"): + writeLine("250 OK") + case strings.HasPrefix(command, "RCPT TO"): + writeLine("250 OK") + case strings.HasPrefix(command, "DATA"): + writeLine("354 End data with .") + for { + dataLine, readErr := reader.ReadString('\n') + if readErr != nil { + return + } + if dataLine == ".\r\n" { + break + } + } + writeLine("250 Message accepted") + case strings.HasPrefix(command, "QUIT"): + writeLine("221 Bye") + return + default: + writeLine("250 OK") + } + } +} diff --git a/backend/internal/services/proxyhost_service_test.go b/backend/internal/services/proxyhost_service_test.go index c221c33c..cbd11296 100644 --- a/backend/internal/services/proxyhost_service_test.go +++ b/backend/internal/services/proxyhost_service_test.go @@ -299,3 +299,32 @@ func TestProxyHostService_validateProxyHost_ValidationErrors(t *testing.T) { err = service.validateProxyHost(&models.ProxyHost{DomainNames: "example.com", ForwardHost: "127.0.0.1", UseDNSChallenge: true}) assert.ErrorContains(t, err, "dns provider is required") } + +func TestProxyHostService_ValidateUniqueDomain_DBError(t *testing.T) { + t.Parallel() + + db := setupProxyHostTestDB(t) + service := NewProxyHostService(db) + + sqlDB, err := db.DB() + require.NoError(t, err) + require.NoError(t, sqlDB.Close()) + + err = service.ValidateUniqueDomain("example.com", 0) + assert.Error(t, err) + assert.Contains(t, err.Error(), "checking domain uniqueness") +} + +func TestProxyHostService_List_DBError(t *testing.T) { + t.Parallel() + + db := setupProxyHostTestDB(t) + service := NewProxyHostService(db) + + sqlDB, err := db.DB() + require.NoError(t, err) + require.NoError(t, sqlDB.Close()) + + _, err = service.List() + assert.Error(t, err) +}