diff --git a/backend/cmd/api/main_test.go b/backend/cmd/api/main_test.go new file mode 100644 index 00000000..4ba043b4 --- /dev/null +++ b/backend/cmd/api/main_test.go @@ -0,0 +1,59 @@ +package main + +import ( + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/Wikid82/charon/backend/internal/database" + "github.com/Wikid82/charon/backend/internal/models" +) + +func TestResetPasswordCommand_Succeeds(t *testing.T) { + if os.Getenv("CHARON_TEST_RUN_MAIN") == "1" { + // Child process: emulate CLI args and run main(). + email := os.Getenv("CHARON_TEST_EMAIL") + newPassword := os.Getenv("CHARON_TEST_NEW_PASSWORD") + os.Args = []string{"charon", "reset-password", email, newPassword} + main() + return + } + + tmp := t.TempDir() + dbPath := filepath.Join(tmp, "data", "test.db") + if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { + t.Fatalf("mkdir db dir: %v", err) + } + + db, err := database.Connect(dbPath) + if err != nil { + t.Fatalf("connect db: %v", err) + } + if err := db.AutoMigrate(&models.User{}); err != nil { + t.Fatalf("automigrate: %v", err) + } + + email := "user@example.com" + user := models.User{UUID: "u-1", Email: email, Name: "User", Role: "admin", Enabled: true} + user.PasswordHash = "$2a$10$example_hashed_password" + if err := db.Create(&user).Error; err != nil { + t.Fatalf("seed user: %v", err) + } + + cmd := exec.Command(os.Args[0], "-test.run=TestResetPasswordCommand_Succeeds") + cmd.Dir = tmp + cmd.Env = append(os.Environ(), + "CHARON_TEST_RUN_MAIN=1", + "CHARON_TEST_EMAIL="+email, + "CHARON_TEST_NEW_PASSWORD=new-password", + "CHARON_DB_PATH="+dbPath, + "CHARON_CADDY_CONFIG_DIR="+filepath.Join(tmp, "caddy"), + "CHARON_IMPORT_DIR="+filepath.Join(tmp, "imports"), + ) + + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("expected exit 0; err=%v; output=%s", err, string(out)) + } +} diff --git a/backend/cmd/seed/main_test.go b/backend/cmd/seed/main_test.go new file mode 100644 index 00000000..ff6c8db7 --- /dev/null +++ b/backend/cmd/seed/main_test.go @@ -0,0 +1,85 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "os" + "path/filepath" + "testing" +) + +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSeedMain_CreatesDatabaseFile(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + + tmp := t.TempDir() + if err := os.Chdir(tmp); err != nil { + t.Fatalf("chdir: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(wd) }) + + if err := os.MkdirAll("data", 0o755); err != nil { + t.Fatalf("mkdir data: %v", err) + } + + main() + + dbPath := filepath.Join("data", "charon.db") + info, err := os.Stat(dbPath) + if err != nil { + t.Fatalf("expected db file to exist at %s: %v", dbPath, err) + } + if info.Size() == 0 { + t.Fatalf("expected db file to be non-empty") + } +} +package main +package main + +import ( + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +} } t.Fatalf("expected db file to be non-empty") if info.Size() == 0 { } t.Fatalf("expected db file to exist at %s: %v", dbPath, err) if err != nil { info, err := os.Stat(dbPath) dbPath := filepath.Join("data", "charon.db") main() } t.Fatalf("mkdir data: %v", err) if err := os.MkdirAll("data", 0o755); err != nil { t.Cleanup(func() { _ = os.Chdir(wd) }) } t.Fatalf("chdir: %v", err) if err := os.Chdir(tmp); err != nil { tmp := t.TempDir() } t.Fatalf("getwd: %v", err) if err != nil { wd, err := os.Getwd() t.Parallel()func TestSeedMain_CreatesDatabaseFile(t *testing.T) {) "testing" "path/filepath" "os" diff --git a/backend/cmd/seed/seed_smoke_test.go b/backend/cmd/seed/seed_smoke_test.go new file mode 100644 index 00000000..5a2b5fbc --- /dev/null +++ b/backend/cmd/seed/seed_smoke_test.go @@ -0,0 +1,31 @@ +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSeedMain_Smoke(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + + tmp := t.TempDir() + if err := os.Chdir(tmp); err != nil { + t.Fatalf("chdir: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(wd) }) + + if err := os.MkdirAll("data", 0o755); err != nil { + t.Fatalf("mkdir data: %v", err) + } + + main() + + p := filepath.Join("data", "charon.db") + if _, err := os.Stat(p); err != nil { + t.Fatalf("expected db file to exist: %v", err) + } +} diff --git a/backend/internal/api/handlers/access_list_handler.go b/backend/internal/api/handlers/access_list_handler.go index c97d5612..62f230ec 100644 --- a/backend/internal/api/handlers/access_list_handler.go +++ b/backend/internal/api/handlers/access_list_handler.go @@ -10,16 +10,23 @@ import ( "gorm.io/gorm" ) +// AccessListHandler handles access list API requests. type AccessListHandler struct { service *services.AccessListService } +// NewAccessListHandler creates a new AccessListHandler. func NewAccessListHandler(db *gorm.DB) *AccessListHandler { return &AccessListHandler{ service: services.NewAccessListService(db), } } +// SetGeoIPService sets the GeoIP service for geo-based ACL lookups. +func (h *AccessListHandler) SetGeoIPService(geoipSvc *services.GeoIPService) { + h.service.SetGeoIPService(geoipSvc) +} + // Create handles POST /api/v1/access-lists func (h *AccessListHandler) Create(c *gin.Context) { var acl models.AccessList diff --git a/backend/internal/api/handlers/access_list_handler_coverage_test.go b/backend/internal/api/handlers/access_list_handler_coverage_test.go index ad50fd9f..afc98556 100644 --- a/backend/internal/api/handlers/access_list_handler_coverage_test.go +++ b/backend/internal/api/handlers/access_list_handler_coverage_test.go @@ -7,12 +7,37 @@ import ( "testing" "github.com/Wikid82/charon/backend/internal/models" + "github.com/Wikid82/charon/backend/internal/services" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "gorm.io/driver/sqlite" "gorm.io/gorm" ) +func TestAccessListHandler_SetGeoIPService(t *testing.T) { + db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + db.AutoMigrate(&models.AccessList{}) + + handler := NewAccessListHandler(db) + + // Test setting GeoIP service + geoipSvc := &services.GeoIPService{} + handler.SetGeoIPService(geoipSvc) + + // No error or panic means success - the function is a simple setter + // We can't easily verify the internal state, but we can verify it doesn't panic +} + +func TestAccessListHandler_SetGeoIPService_Nil(t *testing.T) { + db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + db.AutoMigrate(&models.AccessList{}) + + handler := NewAccessListHandler(db) + + // Test setting nil GeoIP service (should not panic) + handler.SetGeoIPService(nil) +} + func TestAccessListHandler_Get_InvalidID(t *testing.T) { router, _ := setupAccessListTestRouter(t) @@ -250,3 +275,24 @@ func TestAccessListHandler_TestIP_LocalNetworkOnly(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) } + +func TestAccessListHandler_TestIP_InternalError(t *testing.T) { + // Create DB without migrating AccessList to cause internal error + db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + // Don't migrate - this causes a "no such table" error which is an internal error + + gin.SetMode(gin.TestMode) + router := gin.New() + + handler := NewAccessListHandler(db) + router.POST("/access-lists/:id/test", handler.TestIP) + + body := []byte(`{"ip_address":"192.168.1.1"}`) + req := httptest.NewRequest(http.MethodPost, "/access-lists/1/test", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // Should return 500 since table doesn't exist (internal error, not ErrAccessListNotFound) + assert.Equal(t, http.StatusInternalServerError, w.Code) +} diff --git a/backend/internal/api/handlers/crowdsec_handler.go b/backend/internal/api/handlers/crowdsec_handler.go index fae4dd52..4f7f13a9 100644 --- a/backend/internal/api/handlers/crowdsec_handler.go +++ b/backend/internal/api/handlers/crowdsec_handler.go @@ -878,6 +878,214 @@ type cscliDecision struct { Until string `json:"until"` } +// lapiDecision represents the JSON structure from CrowdSec LAPI /v1/decisions +type lapiDecision struct { + ID int64 `json:"id"` + Origin string `json:"origin"` + Type string `json:"type"` + Scope string `json:"scope"` + Value string `json:"value"` + Duration string `json:"duration"` + Scenario string `json:"scenario"` + CreatedAt string `json:"created_at,omitempty"` + Until string `json:"until,omitempty"` +} + +// GetLAPIDecisions queries CrowdSec LAPI directly for current decisions. +// This is an alternative to ListDecisions which uses cscli. +// Query params: +// - ip: filter by specific IP address +// - scope: filter by scope (e.g., "ip", "range") +// - type: filter by decision type (e.g., "ban", "captcha") +func (h *CrowdsecHandler) GetLAPIDecisions(c *gin.Context) { + // Get LAPI URL from security config or use default + lapiURL := "http://localhost:8080" + if h.Security != nil { + cfg, err := h.Security.Get() + if err == nil && cfg != nil && cfg.CrowdSecAPIURL != "" { + lapiURL = cfg.CrowdSecAPIURL + } + } + + // Build query string + queryParams := make([]string, 0) + if ip := c.Query("ip"); ip != "" { + queryParams = append(queryParams, "ip="+ip) + } + if scope := c.Query("scope"); scope != "" { + queryParams = append(queryParams, "scope="+scope) + } + if decisionType := c.Query("type"); decisionType != "" { + queryParams = append(queryParams, "type="+decisionType) + } + + // Build request URL + reqURL := strings.TrimRight(lapiURL, "/") + "/v1/decisions" + if len(queryParams) > 0 { + reqURL += "?" + strings.Join(queryParams, "&") + } + + // Get API key + apiKey := getLAPIKey() + + // Create HTTP request with timeout + ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, http.NoBody) + if err != nil { + logger.Log().WithError(err).Warn("Failed to create LAPI decisions request") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"}) + return + } + + // Add authentication header if API key is available + if apiKey != "" { + req.Header.Set("X-Api-Key", apiKey) + } + req.Header.Set("Accept", "application/json") + + // Execute request + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + logger.Log().WithError(err).WithField("lapi_url", lapiURL).Warn("Failed to query LAPI decisions") + // Fallback to cscli-based method + h.ListDecisions(c) + return + } + defer resp.Body.Close() + + // Handle non-200 responses + if resp.StatusCode == http.StatusUnauthorized { + c.JSON(http.StatusUnauthorized, gin.H{"error": "LAPI authentication failed - check API key configuration"}) + return + } + if resp.StatusCode != http.StatusOK { + logger.Log().WithField("status", resp.StatusCode).WithField("lapi_url", lapiURL).Warn("LAPI returned non-OK status") + // Fallback to cscli-based method + h.ListDecisions(c) + return + } + + // Check content-type to ensure we're getting JSON (not HTML from a proxy/frontend) + contentType := resp.Header.Get("Content-Type") + if contentType != "" && !strings.Contains(contentType, "application/json") { + logger.Log().WithField("content_type", contentType).WithField("lapi_url", lapiURL).Warn("LAPI returned non-JSON content-type, falling back to cscli") + // Fallback to cscli-based method + h.ListDecisions(c) + return + } + + // Parse response body + body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) // 10MB limit + if err != nil { + logger.Log().WithError(err).Warn("Failed to read LAPI response") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"}) + return + } + + // Handle null/empty responses + if len(body) == 0 || string(body) == "null" || string(body) == "null\n" { + c.JSON(http.StatusOK, gin.H{"decisions": []CrowdSecDecision{}, "total": 0, "source": "lapi"}) + return + } + + // Parse JSON + var lapiDecisions []lapiDecision + if err := json.Unmarshal(body, &lapiDecisions); err != nil { + logger.Log().WithError(err).WithField("body", string(body)).Warn("Failed to parse LAPI decisions") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse LAPI response"}) + return + } + + // Convert to our format + decisions := make([]CrowdSecDecision, 0, len(lapiDecisions)) + for _, d := range lapiDecisions { + var createdAt time.Time + if d.CreatedAt != "" { + createdAt, _ = time.Parse(time.RFC3339, d.CreatedAt) + } + decisions = append(decisions, CrowdSecDecision{ + ID: d.ID, + Origin: d.Origin, + Type: d.Type, + Scope: d.Scope, + Value: d.Value, + Duration: d.Duration, + Scenario: d.Scenario, + CreatedAt: createdAt, + Until: d.Until, + }) + } + + c.JSON(http.StatusOK, gin.H{"decisions": decisions, "total": len(decisions), "source": "lapi"}) +} + +// getLAPIKey retrieves the LAPI API key from environment variables. +func getLAPIKey() string { + envVars := []string{ + "CROWDSEC_API_KEY", + "CROWDSEC_BOUNCER_API_KEY", + "CERBERUS_SECURITY_CROWDSEC_API_KEY", + "CHARON_SECURITY_CROWDSEC_API_KEY", + "CPM_SECURITY_CROWDSEC_API_KEY", + } + for _, key := range envVars { + if val := os.Getenv(key); val != "" { + return val + } + } + return "" +} + +// CheckLAPIHealth verifies that CrowdSec LAPI is responding. +func (h *CrowdsecHandler) CheckLAPIHealth(c *gin.Context) { + // Get LAPI URL from security config or use default + lapiURL := "http://localhost:8080" + if h.Security != nil { + cfg, err := h.Security.Get() + if err == nil && cfg != nil && cfg.CrowdSecAPIURL != "" { + lapiURL = cfg.CrowdSecAPIURL + } + } + + // Create health check request + ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) + defer cancel() + + healthURL := strings.TrimRight(lapiURL, "/") + "/health" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, http.NoBody) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"healthy": false, "error": "failed to create request"}) + return + } + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + // Try decisions endpoint as fallback health check + decisionsURL := strings.TrimRight(lapiURL, "/") + "/v1/decisions" + req2, _ := http.NewRequestWithContext(ctx, http.MethodHead, decisionsURL, http.NoBody) + resp2, err2 := client.Do(req2) + if err2 != nil { + c.JSON(http.StatusOK, gin.H{"healthy": false, "error": "LAPI unreachable", "lapi_url": lapiURL}) + return + } + defer resp2.Body.Close() + // 401 is expected without auth but indicates LAPI is running + if resp2.StatusCode == http.StatusOK || resp2.StatusCode == http.StatusUnauthorized { + c.JSON(http.StatusOK, gin.H{"healthy": true, "lapi_url": lapiURL, "note": "health endpoint unavailable, verified via decisions endpoint"}) + return + } + c.JSON(http.StatusOK, gin.H{"healthy": false, "error": "unexpected status", "status": resp2.StatusCode, "lapi_url": lapiURL}) + return + } + defer resp.Body.Close() + + c.JSON(http.StatusOK, gin.H{"healthy": resp.StatusCode == http.StatusOK, "lapi_url": lapiURL, "status": resp.StatusCode}) +} + // ListDecisions calls cscli to get current decisions (banned IPs) func (h *CrowdsecHandler) ListDecisions(c *gin.Context) { ctx := c.Request.Context() @@ -1023,6 +1231,8 @@ func (h *CrowdsecHandler) RegisterRoutes(rg *gin.RouterGroup) { rg.GET("/admin/crowdsec/console/status", h.ConsoleStatus) // Decision management endpoints (Banned IP Dashboard) rg.GET("/admin/crowdsec/decisions", h.ListDecisions) + rg.GET("/admin/crowdsec/decisions/lapi", h.GetLAPIDecisions) + rg.GET("/admin/crowdsec/lapi/health", h.CheckLAPIHealth) rg.POST("/admin/crowdsec/ban", h.BanIP) rg.DELETE("/admin/crowdsec/ban/:ip", h.UnbanIP) } diff --git a/backend/internal/api/handlers/crowdsec_handler_test.go b/backend/internal/api/handlers/crowdsec_handler_test.go index fdbc617b..03cc2362 100644 --- a/backend/internal/api/handlers/crowdsec_handler_test.go +++ b/backend/internal/api/handlers/crowdsec_handler_test.go @@ -15,7 +15,9 @@ import ( "path/filepath" "strings" "testing" + "time" + "github.com/Wikid82/charon/backend/internal/crowdsec" "github.com/Wikid82/charon/backend/internal/models" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -519,3 +521,457 @@ func TestIsCerberusEnabledLegacyEnv(t *testing.T) { t.Fatalf("expected cerberus to be disabled for legacy env flag") } } + +// ============================================ +// Console Enrollment Tests +// ============================================ + +type mockEnvExecutor struct { + responses []struct { + out []byte + err error + } + defaultResponse struct { + out []byte + err error + } + calls []struct { + name string + args []string + } +} + +func (m *mockEnvExecutor) ExecuteWithEnv(ctx context.Context, name string, args []string, env map[string]string) ([]byte, error) { + m.calls = append(m.calls, struct { + name string + args []string + }{name, args}) + + if len(m.calls) <= len(m.responses) { + resp := m.responses[len(m.calls)-1] + return resp.out, resp.err + } + return m.defaultResponse.out, m.defaultResponse.err +} + +func setupTestConsoleEnrollment(t *testing.T) (*CrowdsecHandler, *mockEnvExecutor) { + t.Helper() + gin.SetMode(gin.TestMode) + db := OpenTestDB(t) + require.NoError(t, db.AutoMigrate(&models.CrowdsecConsoleEnrollment{})) + + exec := &mockEnvExecutor{} + dataDir := t.TempDir() + + h := NewCrowdsecHandler(db, &fakeExec{}, "/bin/false", dataDir) + // Replace the Console service with one that uses our mock executor + h.Console = crowdsec.NewConsoleEnrollmentService(db, exec, dataDir, "test-secret") + + return h, exec +} + +func TestConsoleEnrollDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "false") + + h := NewCrowdsecHandler(OpenTestDB(t), &fakeExec{}, "/bin/false", t.TempDir()) + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + body := `{"enrollment_key": "abc123456789", "agent_name": "test-agent"}` + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/crowdsec/console/enroll", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusNotFound, w.Code) + require.Contains(t, w.Body.String(), "disabled") +} + +func TestConsoleEnrollServiceUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h := NewCrowdsecHandler(OpenTestDB(t), &fakeExec{}, "/bin/false", t.TempDir()) + // Set Console to nil to simulate unavailable + h.Console = nil + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + body := `{"enrollment_key": "abc123456789", "agent_name": "test-agent"}` + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/crowdsec/console/enroll", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusServiceUnavailable, w.Code) + require.Contains(t, w.Body.String(), "unavailable") +} + +func TestConsoleEnrollInvalidPayload(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h, _ := setupTestConsoleEnrollment(t) + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/crowdsec/console/enroll", strings.NewReader("not-json")) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "invalid payload") +} + +func TestConsoleEnrollSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h, _ := setupTestConsoleEnrollment(t) + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + body := `{"enrollment_key": "abc123456789", "agent_name": "test-agent", "tenant": "my-tenant"}` + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/crowdsec/console/enroll", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Equal(t, "enrolled", resp["status"]) +} + +func TestConsoleEnrollMissingAgentName(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h, _ := setupTestConsoleEnrollment(t) + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + body := `{"enrollment_key": "abc123456789", "agent_name": ""}` + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/crowdsec/console/enroll", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "required") +} + +func TestConsoleStatusDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "false") + + h := NewCrowdsecHandler(OpenTestDB(t), &fakeExec{}, "/bin/false", t.TempDir()) + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/crowdsec/console/status", http.NoBody) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusNotFound, w.Code) + require.Contains(t, w.Body.String(), "disabled") +} + +func TestConsoleStatusServiceUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h := NewCrowdsecHandler(OpenTestDB(t), &fakeExec{}, "/bin/false", t.TempDir()) + // Set Console to nil to simulate unavailable + h.Console = nil + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/crowdsec/console/status", http.NoBody) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusServiceUnavailable, w.Code) + require.Contains(t, w.Body.String(), "unavailable") +} + +func TestConsoleStatusSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h, _ := setupTestConsoleEnrollment(t) + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + // Get status when not enrolled yet + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/crowdsec/console/status", http.NoBody) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Equal(t, "not_enrolled", resp["status"]) +} + +func TestConsoleStatusAfterEnroll(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h, _ := setupTestConsoleEnrollment(t) + r := gin.New() + g := r.Group("/api/v1") + h.RegisterRoutes(g) + + // First enroll + body := `{"enrollment_key": "abc123456789", "agent_name": "test-agent"}` + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/crowdsec/console/enroll", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + // Then check status + w2 := httptest.NewRecorder() + req2 := httptest.NewRequest(http.MethodGet, "/api/v1/admin/crowdsec/console/status", http.NoBody) + r.ServeHTTP(w2, req2) + + require.Equal(t, http.StatusOK, w2.Code) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w2.Body.Bytes(), &resp)) + require.Equal(t, "enrolled", resp["status"]) + require.Equal(t, "test-agent", resp["agent_name"]) +} + +// ============================================ +// isConsoleEnrollmentEnabled Tests +// ============================================ + +func TestIsConsoleEnrollmentEnabledFromDB(t *testing.T) { + gin.SetMode(gin.TestMode) + db := OpenTestDB(t) + require.NoError(t, db.AutoMigrate(&models.Setting{})) + require.NoError(t, db.Create(&models.Setting{Key: "feature.crowdsec.console_enrollment", Value: "true"}).Error) + + h := NewCrowdsecHandler(db, &fakeExec{}, "/bin/false", t.TempDir()) + require.True(t, h.isConsoleEnrollmentEnabled()) +} + +func TestIsConsoleEnrollmentDisabledFromDB(t *testing.T) { + gin.SetMode(gin.TestMode) + db := OpenTestDB(t) + require.NoError(t, db.AutoMigrate(&models.Setting{})) + require.NoError(t, db.Create(&models.Setting{Key: "feature.crowdsec.console_enrollment", Value: "false"}).Error) + + h := NewCrowdsecHandler(db, &fakeExec{}, "/bin/false", t.TempDir()) + require.False(t, h.isConsoleEnrollmentEnabled()) +} + +func TestIsConsoleEnrollmentEnabledFromEnv(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "true") + + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + require.True(t, h.isConsoleEnrollmentEnabled()) +} + +func TestIsConsoleEnrollmentDisabledFromEnv(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "0") + + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + require.False(t, h.isConsoleEnrollmentEnabled()) +} + +func TestIsConsoleEnrollmentInvalidEnv(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("FEATURE_CROWDSEC_CONSOLE_ENROLLMENT", "invalid") + + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + require.False(t, h.isConsoleEnrollmentEnabled()) +} + +func TestIsConsoleEnrollmentDefaultDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + require.False(t, h.isConsoleEnrollmentEnabled()) +} + +func TestIsConsoleEnrollmentDBTrueVariants(t *testing.T) { + tests := []struct { + value string + expected bool + }{ + {"true", true}, + {"TRUE", true}, + {"True", true}, + {"1", true}, + {"yes", true}, + {"YES", true}, + {"false", false}, + {"FALSE", false}, + {"0", false}, + {"no", false}, + } + + for _, tc := range tests { + t.Run(tc.value, func(t *testing.T) { + gin.SetMode(gin.TestMode) + db := OpenTestDB(t) + require.NoError(t, db.AutoMigrate(&models.Setting{})) + require.NoError(t, db.Create(&models.Setting{Key: "feature.crowdsec.console_enrollment", Value: tc.value}).Error) + + h := NewCrowdsecHandler(db, &fakeExec{}, "/bin/false", t.TempDir()) + require.Equal(t, tc.expected, h.isConsoleEnrollmentEnabled(), "value %q", tc.value) + }) + } +} + +// ============================================ +// actorFromContext Tests +// ============================================ + +func TestActorFromContextWithUserID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Set("userID", "user-123") + + actor := actorFromContext(c) + require.Equal(t, "user:user-123", actor) +} + +func TestActorFromContextWithNumericUserID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Set("userID", 456) + + actor := actorFromContext(c) + require.Equal(t, "user:456", actor) +} + +func TestActorFromContextNoUser(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + actor := actorFromContext(c) + require.Equal(t, "unknown", actor) +} + +// ============================================ +// ttlRemainingSeconds Tests +// ============================================ + +func TestTTLRemainingSeconds(t *testing.T) { + now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + retrieved := time.Date(2024, 1, 1, 11, 0, 0, 0, time.UTC) // 1 hour ago + cacheTTL := 2 * time.Hour + + // Should have 1 hour remaining + remaining := ttlRemainingSeconds(now, retrieved, cacheTTL) + require.NotNil(t, remaining) + require.Equal(t, int64(3600), *remaining) // 1 hour in seconds +} + +func TestTTLRemainingSecondsExpired(t *testing.T) { + now := time.Date(2024, 1, 1, 14, 0, 0, 0, time.UTC) + retrieved := time.Date(2024, 1, 1, 11, 0, 0, 0, time.UTC) // 3 hours ago + cacheTTL := 2 * time.Hour + + // Should be expired (negative or zero) + remaining := ttlRemainingSeconds(now, retrieved, cacheTTL) + require.NotNil(t, remaining) + require.Equal(t, int64(0), *remaining) +} + +func TestTTLRemainingSecondsZeroTime(t *testing.T) { + now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + var retrieved time.Time // zero time + cacheTTL := 2 * time.Hour + + // With zero time, should return nil + remaining := ttlRemainingSeconds(now, retrieved, cacheTTL) + require.Nil(t, remaining) +} + +func TestTTLRemainingSecondsZeroTTL(t *testing.T) { + now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + retrieved := time.Date(2024, 1, 1, 11, 0, 0, 0, time.UTC) + cacheTTL := time.Duration(0) + + remaining := ttlRemainingSeconds(now, retrieved, cacheTTL) + require.Nil(t, remaining) +} + +// ============================================ +// hubEndpoints Tests +// ============================================ + +func TestHubEndpointsNil(t *testing.T) { + gin.SetMode(gin.TestMode) + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + h.Hub = nil + + endpoints := h.hubEndpoints() + require.Nil(t, endpoints) +} + +func TestHubEndpointsDeduplicates(t *testing.T) { + gin.SetMode(gin.TestMode) + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + // Hub is created by NewCrowdsecHandler, modify its fields + if h.Hub != nil { + h.Hub.HubBaseURL = "https://hub.crowdsec.net" + h.Hub.MirrorBaseURL = "https://hub.crowdsec.net" // Same URL + } + + endpoints := h.hubEndpoints() + require.Len(t, endpoints, 1) + require.Equal(t, "https://hub.crowdsec.net", endpoints[0]) +} + +func TestHubEndpointsMultiple(t *testing.T) { + gin.SetMode(gin.TestMode) + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + if h.Hub != nil { + h.Hub.HubBaseURL = "https://hub.crowdsec.net" + h.Hub.MirrorBaseURL = "https://mirror.example.com" + } + + endpoints := h.hubEndpoints() + require.Len(t, endpoints, 2) + require.Contains(t, endpoints, "https://hub.crowdsec.net") + require.Contains(t, endpoints, "https://mirror.example.com") +} + +func TestHubEndpointsSkipsEmpty(t *testing.T) { + gin.SetMode(gin.TestMode) + h := NewCrowdsecHandler(nil, &fakeExec{}, "/bin/false", t.TempDir()) + if h.Hub != nil { + h.Hub.HubBaseURL = "https://hub.crowdsec.net" + h.Hub.MirrorBaseURL = "" // Empty + } + + endpoints := h.hubEndpoints() + require.Len(t, endpoints, 1) + require.Equal(t, "https://hub.crowdsec.net", endpoints[0]) +} diff --git a/backend/internal/api/handlers/crowdsec_lapi_test.go b/backend/internal/api/handlers/crowdsec_lapi_test.go new file mode 100644 index 00000000..b120a8a8 --- /dev/null +++ b/backend/internal/api/handlers/crowdsec_lapi_test.go @@ -0,0 +1,142 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestGetLAPIDecisions_FallbackToCscli(t *testing.T) { + gin.SetMode(gin.TestMode) + router := gin.New() + + // Create handler with mock executor + handler := &CrowdsecHandler{ + CmdExec: &mockCommandExecutor{output: []byte(`[]`), err: nil}, + DataDir: t.TempDir(), + } + + router.GET("/admin/crowdsec/decisions/lapi", handler.GetLAPIDecisions) + + // This test will fallback to cscli since localhost:8080 LAPI is not running + req := httptest.NewRequest(http.MethodGet, "/admin/crowdsec/decisions/lapi", http.NoBody) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // Should return success (from cscli fallback) + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + // Should have decisions array (empty from mock) + _, hasDecisions := response["decisions"] + assert.True(t, hasDecisions) +} + +func TestGetLAPIDecisions_EmptyResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + router := gin.New() + + // Create handler with mock executor that returns empty array + handler := &CrowdsecHandler{ + CmdExec: &mockCommandExecutor{output: []byte(`[]`), err: nil}, + DataDir: t.TempDir(), + } + + router.GET("/admin/crowdsec/decisions/lapi", handler.GetLAPIDecisions) + + req := httptest.NewRequest(http.MethodGet, "/admin/crowdsec/decisions/lapi", http.NoBody) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + // Will fallback to cscli which returns empty + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + // Should have decisions array (may be empty) + _, hasDecisions := response["decisions"] + assert.True(t, hasDecisions) +} + +func TestCheckLAPIHealth_Handler(t *testing.T) { + gin.SetMode(gin.TestMode) + router := gin.New() + + handler := &CrowdsecHandler{ + CmdExec: &mockCommandExecutor{output: []byte(`[]`), err: nil}, + DataDir: t.TempDir(), + } + + router.GET("/admin/crowdsec/lapi/health", handler.CheckLAPIHealth) + + req := httptest.NewRequest(http.MethodGet, "/admin/crowdsec/lapi/health", http.NoBody) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + // Should have healthy field + _, hasHealthy := response["healthy"] + assert.True(t, hasHealthy) + + // Should have lapi_url field + _, hasURL := response["lapi_url"] + assert.True(t, hasURL) +} + +func TestGetLAPIKey_FromEnv(t *testing.T) { + // Save and restore original env + original := os.Getenv("CROWDSEC_API_KEY") + defer func() { + if original != "" { + _ = os.Setenv("CROWDSEC_API_KEY", original) + } else { + _ = os.Unsetenv("CROWDSEC_API_KEY") + } + }() + + // Set test value + _ = os.Setenv("CROWDSEC_API_KEY", "test-key-123") + + key := getLAPIKey() + assert.Equal(t, "test-key-123", key) +} + +func TestGetLAPIKey_Empty(t *testing.T) { + // Save and restore original env vars + envVars := []string{ + "CROWDSEC_API_KEY", + "CROWDSEC_BOUNCER_API_KEY", + "CERBERUS_SECURITY_CROWDSEC_API_KEY", + "CHARON_SECURITY_CROWDSEC_API_KEY", + "CPM_SECURITY_CROWDSEC_API_KEY", + } + + originals := make(map[string]string) + for _, key := range envVars { + originals[key] = os.Getenv(key) + _ = os.Unsetenv(key) + } + defer func() { + for key, val := range originals { + if val != "" { + _ = os.Setenv(key, val) + } + } + }() + + key := getLAPIKey() + assert.Empty(t, key) +} diff --git a/backend/internal/api/handlers/security_geoip_endpoints_test.go b/backend/internal/api/handlers/security_geoip_endpoints_test.go new file mode 100644 index 00000000..086fc5bb --- /dev/null +++ b/backend/internal/api/handlers/security_geoip_endpoints_test.go @@ -0,0 +1,122 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wikid82/charon/backend/internal/config" + "github.com/Wikid82/charon/backend/internal/services" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSecurityHandler_GetGeoIPStatus_NotInitialized(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewSecurityHandler(config.SecurityConfig{}, nil, nil) + r := gin.New() + r.GET("/security/geoip/status", h.GetGeoIPStatus) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/security/geoip/status", http.NoBody) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + assert.Equal(t, false, body["loaded"]) + assert.Equal(t, "GeoIP service not initialized", body["message"]) +} + +func TestSecurityHandler_GetGeoIPStatus_Initialized_NotLoaded(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewSecurityHandler(config.SecurityConfig{}, nil, nil) + h.SetGeoIPService(&services.GeoIPService{}) + + r := gin.New() + r.GET("/security/geoip/status", h.GetGeoIPStatus) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/security/geoip/status", http.NoBody) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + assert.Equal(t, false, body["loaded"]) + assert.Equal(t, "GeoIP service available", body["message"]) +} + +func TestSecurityHandler_ReloadGeoIP_NotInitialized(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewSecurityHandler(config.SecurityConfig{}, nil, nil) + r := gin.New() + r.POST("/security/geoip/reload", h.ReloadGeoIP) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/security/geoip/reload", http.NoBody) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusServiceUnavailable, w.Code) +} + +func TestSecurityHandler_ReloadGeoIP_LoadError(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewSecurityHandler(config.SecurityConfig{}, nil, nil) + h.SetGeoIPService(&services.GeoIPService{}) // dbPath empty => Load() will error + + r := gin.New() + r.POST("/security/geoip/reload", h.ReloadGeoIP) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/security/geoip/reload", http.NoBody) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Contains(t, w.Body.String(), "Failed to reload GeoIP database") +} + +func TestSecurityHandler_LookupGeoIP_MissingIPAddress(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewSecurityHandler(config.SecurityConfig{}, nil, nil) + r := gin.New() + r.POST("/security/geoip/lookup", h.LookupGeoIP) + + payload := []byte(`{}`) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/security/geoip/lookup", bytes.NewReader(payload)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "ip_address is required") +} + +func TestSecurityHandler_LookupGeoIP_ServiceUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewSecurityHandler(config.SecurityConfig{}, nil, nil) + h.SetGeoIPService(&services.GeoIPService{}) // present but not loaded + + r := gin.New() + r.POST("/security/geoip/lookup", h.LookupGeoIP) + + payload, _ := json.Marshal(map[string]string{"ip_address": "8.8.8.8"}) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/security/geoip/lookup", bytes.NewReader(payload)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + assert.Contains(t, w.Body.String(), "GeoIP service not available") +} diff --git a/backend/internal/api/handlers/security_handler.go b/backend/internal/api/handlers/security_handler.go index 3bc574e6..d81e575f 100644 --- a/backend/internal/api/handlers/security_handler.go +++ b/backend/internal/api/handlers/security_handler.go @@ -1,6 +1,7 @@ package handlers import ( + "encoding/json" "errors" "net" "net/http" @@ -17,12 +18,27 @@ import ( "github.com/Wikid82/charon/backend/internal/services" ) +// WAFExclusionRequest represents a rule exclusion for false positives +type WAFExclusionRequest struct { + RuleID int `json:"rule_id" binding:"required"` + Target string `json:"target,omitempty"` // e.g., "ARGS:password" + Description string `json:"description,omitempty"` // Human-readable reason +} + +// WAFExclusion represents a stored rule exclusion +type WAFExclusion struct { + RuleID int `json:"rule_id"` + Target string `json:"target,omitempty"` + Description string `json:"description,omitempty"` +} + // SecurityHandler handles security-related API requests. type SecurityHandler struct { cfg config.SecurityConfig db *gorm.DB svc *services.SecurityService caddyManager *caddy.Manager + geoipSvc *services.GeoIPService } // NewSecurityHandler creates a new SecurityHandler. @@ -31,6 +47,11 @@ func NewSecurityHandler(cfg config.SecurityConfig, db *gorm.DB, caddyManager *ca return &SecurityHandler{cfg: cfg, db: db, svc: svc, caddyManager: caddyManager} } +// SetGeoIPService sets the GeoIP service for the handler. +func (h *SecurityHandler) SetGeoIPService(geoipSvc *services.GeoIPService) { + h.geoipSvc = geoipSvc +} + // GetStatus returns the current status of all security services. func (h *SecurityHandler) GetStatus(c *gin.Context) { enabled := h.cfg.CerberusEnabled @@ -443,3 +464,323 @@ func (h *SecurityHandler) Disable(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{"enabled": false}) } + +// GetRateLimitPresets returns predefined rate limit configurations +func (h *SecurityHandler) GetRateLimitPresets(c *gin.Context) { + presets := []map[string]interface{}{ + { + "id": "standard", + "name": "Standard Web", + "description": "Balanced protection for general web applications", + "requests": 100, + "window_sec": 60, + "burst": 20, + }, + { + "id": "api", + "name": "API Protection", + "description": "Stricter limits for API endpoints", + "requests": 30, + "window_sec": 60, + "burst": 10, + }, + { + "id": "login", + "name": "Login Protection", + "description": "Aggressive protection against brute-force", + "requests": 5, + "window_sec": 300, + "burst": 2, + }, + { + "id": "relaxed", + "name": "High Traffic", + "description": "Higher limits for trusted, high-traffic apps", + "requests": 500, + "window_sec": 60, + "burst": 100, + }, + } + c.JSON(http.StatusOK, gin.H{"presets": presets}) +} + +// GetGeoIPStatus returns the current status of the GeoIP service. +func (h *SecurityHandler) GetGeoIPStatus(c *gin.Context) { + if h.geoipSvc == nil { + c.JSON(http.StatusOK, gin.H{ + "loaded": false, + "message": "GeoIP service not initialized", + "db_path": "", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "loaded": h.geoipSvc.IsLoaded(), + "db_path": h.geoipSvc.GetDatabasePath(), + "message": "GeoIP service available", + }) +} + +// ReloadGeoIP reloads the GeoIP database from disk. +func (h *SecurityHandler) ReloadGeoIP(c *gin.Context) { + if h.geoipSvc == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "GeoIP service not initialized", + }) + return + } + + if err := h.geoipSvc.Load(); err != nil { + log.WithError(err).Error("Failed to reload GeoIP database") + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to reload GeoIP database: " + err.Error(), + }) + return + } + + // Log audit event + actor := c.GetString("user_id") + if actor == "" { + actor = c.ClientIP() + } + _ = h.svc.LogAudit(&models.SecurityAudit{Actor: actor, Action: "reload_geoip", Details: "GeoIP database reloaded successfully"}) + + c.JSON(http.StatusOK, gin.H{ + "message": "GeoIP database reloaded successfully", + "loaded": h.geoipSvc.IsLoaded(), + "db_path": h.geoipSvc.GetDatabasePath(), + }) +} + +// LookupGeoIP performs a GeoIP lookup for a given IP address. +func (h *SecurityHandler) LookupGeoIP(c *gin.Context) { + var req struct { + IPAddress string `json:"ip_address" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "ip_address is required"}) + return + } + + if h.geoipSvc == nil || !h.geoipSvc.IsLoaded() { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "GeoIP service not available", + }) + return + } + + country, err := h.geoipSvc.LookupCountry(req.IPAddress) + if err != nil { + if errors.Is(err, services.ErrInvalidGeoIP) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid IP address"}) + return + } + if errors.Is(err, services.ErrCountryNotFound) { + c.JSON(http.StatusOK, gin.H{ + "ip_address": req.IPAddress, + "country_code": "", + "found": false, + "message": "No country found for this IP address", + }) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "GeoIP lookup failed: " + err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "ip_address": req.IPAddress, + "country_code": country, + "found": true, + }) +} + +// GetWAFExclusions returns current WAF rule exclusions from SecurityConfig +func (h *SecurityHandler) GetWAFExclusions(c *gin.Context) { + cfg, err := h.svc.Get() + if err != nil { + if err == services.ErrSecurityConfigNotFound { + c.JSON(http.StatusOK, gin.H{"exclusions": []WAFExclusion{}}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read security config"}) + return + } + + var exclusions []WAFExclusion + if cfg.WAFExclusions != "" { + if err := json.Unmarshal([]byte(cfg.WAFExclusions), &exclusions); err != nil { + log.WithError(err).Warn("Failed to parse WAF exclusions") + exclusions = []WAFExclusion{} + } + } + + c.JSON(http.StatusOK, gin.H{"exclusions": exclusions}) +} + +// AddWAFExclusion adds a rule exclusion to the WAF configuration +func (h *SecurityHandler) AddWAFExclusion(c *gin.Context) { + var req WAFExclusionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "rule_id is required"}) + return + } + + if req.RuleID <= 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "rule_id must be a positive integer"}) + return + } + + cfg, err := h.svc.Get() + if err != nil { + if err == services.ErrSecurityConfigNotFound { + // Create default config with the exclusion + cfg = &models.SecurityConfig{Name: "default"} + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read security config"}) + return + } + } + + // Parse existing exclusions + var exclusions []WAFExclusion + if cfg.WAFExclusions != "" { + if err := json.Unmarshal([]byte(cfg.WAFExclusions), &exclusions); err != nil { + log.WithError(err).Warn("Failed to parse existing WAF exclusions") + exclusions = []WAFExclusion{} + } + } + + // Check for duplicate rule_id with same target + for _, e := range exclusions { + if e.RuleID == req.RuleID && e.Target == req.Target { + c.JSON(http.StatusConflict, gin.H{"error": "exclusion for this rule_id and target already exists"}) + return + } + } + + // Add the new exclusion - convert request to WAFExclusion type + newExclusion := WAFExclusion(req) + exclusions = append(exclusions, newExclusion) + + // Marshal back to JSON + exclusionsJSON, err := json.Marshal(exclusions) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to serialize exclusions"}) + return + } + + cfg.WAFExclusions = string(exclusionsJSON) + if err := h.svc.Upsert(cfg); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save exclusion"}) + return + } + + // Apply updated config to Caddy + if h.caddyManager != nil { + if err := h.caddyManager.ApplyConfig(c.Request.Context()); err != nil { + log.WithError(err).Warn("failed to apply WAF exclusion changes to Caddy") + } + } + + // Log audit event + actor := c.GetString("user_id") + if actor == "" { + actor = c.ClientIP() + } + _ = h.svc.LogAudit(&models.SecurityAudit{ + Actor: actor, + Action: "add_waf_exclusion", + Details: strconv.Itoa(req.RuleID), + }) + + c.JSON(http.StatusOK, gin.H{"exclusion": newExclusion}) +} + +// DeleteWAFExclusion removes a rule exclusion by rule_id +func (h *SecurityHandler) DeleteWAFExclusion(c *gin.Context) { + ruleIDParam := c.Param("rule_id") + if ruleIDParam == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "rule_id is required"}) + return + } + + ruleID, err := strconv.Atoi(ruleIDParam) + if err != nil || ruleID <= 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid rule_id"}) + return + } + + // Get optional target query parameter (for exclusions with specific targets) + target := c.Query("target") + + cfg, err := h.svc.Get() + if err != nil { + if err == services.ErrSecurityConfigNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "exclusion not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read security config"}) + return + } + + // Parse existing exclusions + var exclusions []WAFExclusion + if cfg.WAFExclusions != "" { + if err := json.Unmarshal([]byte(cfg.WAFExclusions), &exclusions); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse exclusions"}) + return + } + } + + // Find and remove the exclusion + found := false + newExclusions := make([]WAFExclusion, 0, len(exclusions)) + for _, e := range exclusions { + // Match by rule_id and target (empty target matches exclusions without target) + if e.RuleID == ruleID && e.Target == target { + found = true + continue // Skip this one (delete it) + } + newExclusions = append(newExclusions, e) + } + + if !found { + c.JSON(http.StatusNotFound, gin.H{"error": "exclusion not found"}) + return + } + + // Marshal back to JSON + exclusionsJSON, err := json.Marshal(newExclusions) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to serialize exclusions"}) + return + } + + cfg.WAFExclusions = string(exclusionsJSON) + if err := h.svc.Upsert(cfg); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save exclusions"}) + return + } + + // Apply updated config to Caddy + if h.caddyManager != nil { + if err := h.caddyManager.ApplyConfig(c.Request.Context()); err != nil { + log.WithError(err).Warn("failed to apply WAF exclusion changes to Caddy") + } + } + + // Log audit event + actor := c.GetString("user_id") + if actor == "" { + actor = c.ClientIP() + } + _ = h.svc.LogAudit(&models.SecurityAudit{ + Actor: actor, + Action: "delete_waf_exclusion", + Details: ruleIDParam, + }) + + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} diff --git a/backend/internal/api/handlers/security_handler_fixed_test.go b/backend/internal/api/handlers/security_handler_fixed_test.go new file mode 100644 index 00000000..768c1952 --- /dev/null +++ b/backend/internal/api/handlers/security_handler_fixed_test.go @@ -0,0 +1,112 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + + "github.com/Wikid82/charon/backend/internal/config" +) + +func TestSecurityHandler_GetStatus_Fixed(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + cfg config.SecurityConfig + expectedStatus int + expectedBody map[string]interface{} + }{ + { + name: "All Disabled", + cfg: config.SecurityConfig{ + CrowdSecMode: "disabled", + WAFMode: "disabled", + RateLimitMode: "disabled", + ACLMode: "disabled", + }, + expectedStatus: http.StatusOK, + expectedBody: map[string]interface{}{ + "cerberus": map[string]interface{}{"enabled": false}, + "crowdsec": map[string]interface{}{ + "mode": "disabled", + "api_url": "", + "enabled": false, + }, + "waf": map[string]interface{}{ + "mode": "disabled", + "enabled": false, + }, + "rate_limit": map[string]interface{}{ + "mode": "disabled", + "enabled": false, + }, + "acl": map[string]interface{}{ + "mode": "disabled", + "enabled": false, + }, + }, + }, + { + name: "All Enabled", + cfg: config.SecurityConfig{ + CerberusEnabled: true, // Required for ACL to be effective + CrowdSecMode: "local", + WAFMode: "enabled", + RateLimitMode: "enabled", + ACLMode: "enabled", + }, + expectedStatus: http.StatusOK, + expectedBody: map[string]interface{}{ + "cerberus": map[string]interface{}{"enabled": true}, + "crowdsec": map[string]interface{}{ + "mode": "local", + "api_url": "", + "enabled": true, + }, + "waf": map[string]interface{}{ + "mode": "enabled", + "enabled": true, + }, + "rate_limit": map[string]interface{}{ + "mode": "enabled", + "enabled": true, + }, + "acl": map[string]interface{}{ + "mode": "enabled", + "enabled": true, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := NewSecurityHandler(tt.cfg, nil, nil) + router := gin.New() + router.GET("/security/status", handler.GetStatus) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/status", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + + expectedJSON, _ := json.Marshal(tt.expectedBody) + var expectedNormalized map[string]interface{} + if err := json.Unmarshal(expectedJSON, &expectedNormalized); err != nil { + t.Fatalf("failed to unmarshal expected JSON: %v", err) + } + + assert.Equal(t, expectedNormalized, response) + }) + } +} diff --git a/backend/internal/api/handlers/security_handler_test_fixed.go b/backend/internal/api/handlers/security_handler_test_fixed.go index 23bf1efb..779b1b88 100644 --- a/backend/internal/api/handlers/security_handler_test_fixed.go +++ b/backend/internal/api/handlers/security_handler_test_fixed.go @@ -1,3 +1,6 @@ +//go:build ignore +// +build ignore + package handlers import ( diff --git a/backend/internal/api/handlers/security_handler_waf_test.go b/backend/internal/api/handlers/security_handler_waf_test.go new file mode 100644 index 00000000..12fbc3e5 --- /dev/null +++ b/backend/internal/api/handlers/security_handler_waf_test.go @@ -0,0 +1,691 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/Wikid82/charon/backend/internal/config" + "github.com/Wikid82/charon/backend/internal/models" +) + +// Tests for GetWAFExclusions handler +func TestSecurityHandler_GetWAFExclusions_Empty(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.GET("/security/waf/exclusions", handler.GetWAFExclusions) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + exclusions := resp["exclusions"].([]interface{}) + assert.Len(t, exclusions, 0) +} + +func TestSecurityHandler_GetWAFExclusions_WithExclusions(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + // Create config with exclusions + exclusionsJSON := `[{"rule_id":942100,"description":"SQL Injection rule"},{"rule_id":941100,"target":"ARGS:password"}]` + cfg := models.SecurityConfig{Name: "default", WAFExclusions: exclusionsJSON} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.GET("/security/waf/exclusions", handler.GetWAFExclusions) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + exclusions := resp["exclusions"].([]interface{}) + assert.Len(t, exclusions, 2) + + // Verify first exclusion + first := exclusions[0].(map[string]interface{}) + assert.Equal(t, float64(942100), first["rule_id"]) + assert.Equal(t, "SQL Injection rule", first["description"]) +} + +func TestSecurityHandler_GetWAFExclusions_InvalidJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + // Create config with invalid JSON + cfg := models.SecurityConfig{Name: "default", WAFExclusions: "invalid json"} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.GET("/security/waf/exclusions", handler.GetWAFExclusions) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + + // Should return empty array on parse failure + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + exclusions := resp["exclusions"].([]interface{}) + assert.Len(t, exclusions, 0) +} + +// Tests for AddWAFExclusion handler +func TestSecurityHandler_AddWAFExclusion_Success(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + payload := map[string]interface{}{ + "rule_id": 942100, + "description": "SQL Injection false positive", + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + + exclusion := resp["exclusion"].(map[string]interface{}) + assert.Equal(t, float64(942100), exclusion["rule_id"]) + assert.Equal(t, "SQL Injection false positive", exclusion["description"]) +} + +func TestSecurityHandler_AddWAFExclusion_WithTarget(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + payload := map[string]interface{}{ + "rule_id": 942100, + "target": "ARGS:password", + "description": "Skip password field for SQL injection", + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + + exclusion := resp["exclusion"].(map[string]interface{}) + assert.Equal(t, "ARGS:password", exclusion["target"]) +} + +func TestSecurityHandler_AddWAFExclusion_ToExistingConfig(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + // Create config with existing exclusion + existingExclusions := `[{"rule_id":941100,"description":"XSS rule"}]` + cfg := models.SecurityConfig{Name: "default", WAFExclusions: existingExclusions} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + router.GET("/security/waf/exclusions", handler.GetWAFExclusions) + + // Add new exclusion + payload := map[string]interface{}{ + "rule_id": 942100, + "description": "SQL Injection rule", + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + // Verify both exclusions exist + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + exclusions := resp["exclusions"].([]interface{}) + assert.Len(t, exclusions, 2) +} + +func TestSecurityHandler_AddWAFExclusion_Duplicate(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + // Create config with existing exclusion + existingExclusions := `[{"rule_id":942100,"description":"SQL Injection rule"}]` + cfg := models.SecurityConfig{Name: "default", WAFExclusions: existingExclusions} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + // Try to add duplicate + payload := map[string]interface{}{ + "rule_id": 942100, + "description": "Another description", + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusConflict, w.Code) +} + +func TestSecurityHandler_AddWAFExclusion_DuplicateWithDifferentTarget(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + // Create config with existing exclusion (no target) + existingExclusions := `[{"rule_id":942100}]` + cfg := models.SecurityConfig{Name: "default", WAFExclusions: existingExclusions} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + // Add same rule_id with different target - should succeed + payload := map[string]interface{}{ + "rule_id": 942100, + "target": "ARGS:password", + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestSecurityHandler_AddWAFExclusion_MissingRuleID(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + payload := map[string]interface{}{ + "description": "Missing rule_id", + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestSecurityHandler_AddWAFExclusion_InvalidRuleID(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + // Zero rule_id + payload := map[string]interface{}{ + "rule_id": 0, + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestSecurityHandler_AddWAFExclusion_NegativeRuleID(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + payload := map[string]interface{}{ + "rule_id": -1, + } + body, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestSecurityHandler_AddWAFExclusion_InvalidPayload(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/security/waf/exclusions", strings.NewReader("invalid json")) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +// Tests for DeleteWAFExclusion handler +func TestSecurityHandler_DeleteWAFExclusion_Success(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + // Create config with exclusions + exclusionsJSON := `[{"rule_id":942100},{"rule_id":941100}]` + cfg := models.SecurityConfig{Name: "default", WAFExclusions: exclusionsJSON} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + router.GET("/security/waf/exclusions", handler.GetWAFExclusions) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/security/waf/exclusions/942100", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + assert.True(t, resp["deleted"].(bool)) + + // Verify only one exclusion remains + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + + json.Unmarshal(w.Body.Bytes(), &resp) + exclusions := resp["exclusions"].([]interface{}) + assert.Len(t, exclusions, 1) + first := exclusions[0].(map[string]interface{}) + assert.Equal(t, float64(941100), first["rule_id"]) +} + +func TestSecurityHandler_DeleteWAFExclusion_WithTarget(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + // Create config with targeted exclusion + exclusionsJSON := `[{"rule_id":942100,"target":"ARGS:password"},{"rule_id":942100}]` + cfg := models.SecurityConfig{Name: "default", WAFExclusions: exclusionsJSON} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + router.GET("/security/waf/exclusions", handler.GetWAFExclusions) + + // Delete exclusion with target + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/security/waf/exclusions/942100?target=ARGS:password", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + // Verify only the non-targeted exclusion remains + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + exclusions := resp["exclusions"].([]interface{}) + assert.Len(t, exclusions, 1) + first := exclusions[0].(map[string]interface{}) + assert.Equal(t, float64(942100), first["rule_id"]) + assert.Empty(t, first["target"]) +} + +func TestSecurityHandler_DeleteWAFExclusion_NotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + // Create config with exclusions + exclusionsJSON := `[{"rule_id":942100}]` + cfg := models.SecurityConfig{Name: "default", WAFExclusions: exclusionsJSON} + db.Create(&cfg) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/security/waf/exclusions/999999", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestSecurityHandler_DeleteWAFExclusion_NoConfig(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/security/waf/exclusions/942100", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestSecurityHandler_DeleteWAFExclusion_InvalidRuleID(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/security/waf/exclusions/invalid", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestSecurityHandler_DeleteWAFExclusion_ZeroRuleID(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/security/waf/exclusions/0", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestSecurityHandler_DeleteWAFExclusion_NegativeRuleID(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/security/waf/exclusions/-1", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +// Integration test: Full WAF exclusion workflow +func TestSecurityHandler_WAFExclusion_FullWorkflow(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{}, &models.SecurityAudit{})) + + handler := NewSecurityHandler(config.SecurityConfig{}, db, nil) + router := gin.New() + router.GET("/security/waf/exclusions", handler.GetWAFExclusions) + router.POST("/security/waf/exclusions", handler.AddWAFExclusion) + router.DELETE("/security/waf/exclusions/:rule_id", handler.DeleteWAFExclusion) + + // Step 1: Start with empty exclusions + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + assert.Len(t, resp["exclusions"].([]interface{}), 0) + + // Step 2: Add first exclusion (full rule removal) + payload := map[string]interface{}{ + "rule_id": 942100, + "description": "SQL Injection false positive", + } + body, _ := json.Marshal(payload) + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Step 3: Add second exclusion (targeted) + payload = map[string]interface{}{ + "rule_id": 941100, + "target": "ARGS:content", + "description": "XSS false positive in content field", + } + body, _ = json.Marshal(payload) + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/security/waf/exclusions", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Step 4: Verify both exclusions exist + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + json.Unmarshal(w.Body.Bytes(), &resp) + assert.Len(t, resp["exclusions"].([]interface{}), 2) + + // Step 5: Delete first exclusion + w = httptest.NewRecorder() + req, _ = http.NewRequest("DELETE", "/security/waf/exclusions/942100", http.NoBody) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Step 6: Verify only second exclusion remains + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/security/waf/exclusions", http.NoBody) + router.ServeHTTP(w, req) + json.Unmarshal(w.Body.Bytes(), &resp) + exclusions := resp["exclusions"].([]interface{}) + assert.Len(t, exclusions, 1) + first := exclusions[0].(map[string]interface{}) + assert.Equal(t, float64(941100), first["rule_id"]) + assert.Equal(t, "ARGS:content", first["target"]) +} + +// Test WAFDisabled field on ProxyHost +func TestProxyHost_WAFDisabled_DefaultFalse(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.ProxyHost{})) + + host := models.ProxyHost{ + UUID: "test-uuid", + DomainNames: "example.com", + ForwardHost: "backend", + ForwardPort: 8080, + Enabled: true, + } + db.Create(&host) + + var retrieved models.ProxyHost + db.First(&retrieved, host.ID) + + assert.False(t, retrieved.WAFDisabled, "WAFDisabled should default to false") +} + +func TestProxyHost_WAFDisabled_SetTrue(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.ProxyHost{})) + + host := models.ProxyHost{ + UUID: "test-uuid", + DomainNames: "example.com", + ForwardHost: "backend", + ForwardPort: 8080, + Enabled: true, + WAFDisabled: true, + } + db.Create(&host) + + var retrieved models.ProxyHost + db.First(&retrieved, host.ID) + + assert.True(t, retrieved.WAFDisabled, "WAFDisabled should be true when set") +} + +// Test WAFParanoiaLevel field on SecurityConfig +func TestSecurityConfig_WAFParanoiaLevel_Default(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + cfg := models.SecurityConfig{ + Name: "default", + WAFMode: "block", + } + db.Create(&cfg) + + var retrieved models.SecurityConfig + db.First(&retrieved, cfg.ID) + + // GORM default is 1 + assert.Equal(t, 1, retrieved.WAFParanoiaLevel, "WAFParanoiaLevel should default to 1") +} + +func TestSecurityConfig_WAFParanoiaLevel_CustomValue(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + cfg := models.SecurityConfig{ + Name: "default", + WAFMode: "block", + WAFParanoiaLevel: 3, + } + db.Create(&cfg) + + var retrieved models.SecurityConfig + db.First(&retrieved, cfg.ID) + + assert.Equal(t, 3, retrieved.WAFParanoiaLevel, "WAFParanoiaLevel should be 3") +} + +// Test WAFExclusions field on SecurityConfig +func TestSecurityConfig_WAFExclusions_Empty(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + cfg := models.SecurityConfig{ + Name: "default", + WAFMode: "block", + } + db.Create(&cfg) + + var retrieved models.SecurityConfig + db.First(&retrieved, cfg.ID) + + assert.Empty(t, retrieved.WAFExclusions, "WAFExclusions should be empty by default") +} + +func TestSecurityConfig_WAFExclusions_JSONArray(t *testing.T) { + gin.SetMode(gin.TestMode) + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate(&models.SecurityConfig{})) + + exclusions := `[{"rule_id":942100,"target":"ARGS:password","description":"Skip password field"}]` + cfg := models.SecurityConfig{ + Name: "default", + WAFMode: "block", + WAFExclusions: exclusions, + } + db.Create(&cfg) + + var retrieved models.SecurityConfig + db.First(&retrieved, cfg.ID) + + assert.Equal(t, exclusions, retrieved.WAFExclusions) + + // Verify it can be parsed + var parsed []map[string]interface{} + err := json.Unmarshal([]byte(retrieved.WAFExclusions), &parsed) + require.NoError(t, err) + assert.Len(t, parsed, 1) + assert.Equal(t, float64(942100), parsed[0]["rule_id"]) +} diff --git a/backend/internal/api/handlers/security_ratelimit_test.go b/backend/internal/api/handlers/security_ratelimit_test.go new file mode 100644 index 00000000..3017f42a --- /dev/null +++ b/backend/internal/api/handlers/security_ratelimit_test.go @@ -0,0 +1,101 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/Wikid82/charon/backend/internal/config" +) + +func TestSecurityHandler_GetRateLimitPresets(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := config.SecurityConfig{} + handler := NewSecurityHandler(cfg, nil, nil) + router := gin.New() + router.GET("/security/rate-limit/presets", handler.GetRateLimitPresets) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/rate-limit/presets", http.NoBody) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + presets, ok := response["presets"].([]interface{}) + require.True(t, ok, "presets should be an array") + require.Len(t, presets, 4, "should have 4 presets") + + // Verify preset structure + expectedIDs := []string{"standard", "api", "login", "relaxed"} + for i, p := range presets { + preset := p.(map[string]interface{}) + assert.Equal(t, expectedIDs[i], preset["id"]) + assert.NotEmpty(t, preset["name"]) + assert.NotEmpty(t, preset["description"]) + assert.NotNil(t, preset["requests"]) + assert.NotNil(t, preset["window_sec"]) + assert.NotNil(t, preset["burst"]) + } +} + +func TestSecurityHandler_GetRateLimitPresets_StandardPreset(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := config.SecurityConfig{} + handler := NewSecurityHandler(cfg, nil, nil) + router := gin.New() + router.GET("/security/rate-limit/presets", handler.GetRateLimitPresets) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/rate-limit/presets", http.NoBody) + router.ServeHTTP(w, req) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + presets := response["presets"].([]interface{}) + standardPreset := presets[0].(map[string]interface{}) + + assert.Equal(t, "standard", standardPreset["id"]) + assert.Equal(t, "Standard Web", standardPreset["name"]) + assert.Equal(t, float64(100), standardPreset["requests"]) + assert.Equal(t, float64(60), standardPreset["window_sec"]) + assert.Equal(t, float64(20), standardPreset["burst"]) +} + +func TestSecurityHandler_GetRateLimitPresets_LoginPreset(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := config.SecurityConfig{} + handler := NewSecurityHandler(cfg, nil, nil) + router := gin.New() + router.GET("/security/rate-limit/presets", handler.GetRateLimitPresets) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/security/rate-limit/presets", http.NoBody) + router.ServeHTTP(w, req) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + presets := response["presets"].([]interface{}) + loginPreset := presets[2].(map[string]interface{}) + + assert.Equal(t, "login", loginPreset["id"]) + assert.Equal(t, "Login Protection", loginPreset["name"]) + assert.Equal(t, float64(5), loginPreset["requests"]) + assert.Equal(t, float64(300), loginPreset["window_sec"]) + assert.Equal(t, float64(2), loginPreset["burst"]) +} diff --git a/backend/internal/api/routes/routes.go b/backend/internal/api/routes/routes.go index 9943d1e7..30db4a40 100644 --- a/backend/internal/api/routes/routes.go +++ b/backend/internal/api/routes/routes.go @@ -4,6 +4,7 @@ package routes import ( "context" "fmt" + "os" "time" "github.com/gin-contrib/gzip" @@ -301,8 +302,30 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error { caddyClient := caddy.NewClient(cfg.CaddyAdminAPI) caddyManager = caddy.NewManager(caddyClient, db, cfg.CaddyConfigDir, cfg.FrontendDir, cfg.ACMEStaging, cfg.Security) + // Initialize GeoIP service if database exists + geoipPath := os.Getenv("CHARON_GEOIP_DB_PATH") + if geoipPath == "" { + geoipPath = "/app/data/geoip/GeoLite2-Country.mmdb" + } + + var geoipSvc *services.GeoIPService + if _, err := os.Stat(geoipPath); err == nil { + var geoErr error + geoipSvc, geoErr = services.NewGeoIPService(geoipPath) + if geoErr != nil { + logger.Log().WithError(geoErr).WithField("path", geoipPath).Warn("Failed to load GeoIP database - geo-blocking features will be unavailable") + } else { + logger.Log().WithField("path", geoipPath).Info("GeoIP database loaded successfully") + } + } else { + logger.Log().WithField("path", geoipPath).Info("GeoIP database not found - geo-blocking features will be unavailable") + } + // Security Status securityHandler := handlers.NewSecurityHandler(cfg.Security, db, caddyManager) + if geoipSvc != nil { + securityHandler.SetGeoIPService(geoipSvc) + } protected.GET("/security/status", securityHandler.GetStatus) // Security Config management protected.GET("/security/config", securityHandler.GetConfig) @@ -315,6 +338,15 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error { protected.GET("/security/rulesets", securityHandler.ListRuleSets) protected.POST("/security/rulesets", securityHandler.UpsertRuleSet) protected.DELETE("/security/rulesets/:id", securityHandler.DeleteRuleSet) + protected.GET("/security/rate-limit/presets", securityHandler.GetRateLimitPresets) + // GeoIP endpoints + protected.GET("/security/geoip/status", securityHandler.GetGeoIPStatus) + protected.POST("/security/geoip/reload", securityHandler.ReloadGeoIP) + protected.POST("/security/geoip/lookup", securityHandler.LookupGeoIP) + // WAF exclusion endpoints + protected.GET("/security/waf/exclusions", securityHandler.GetWAFExclusions) + protected.POST("/security/waf/exclusions", securityHandler.AddWAFExclusion) + protected.DELETE("/security/waf/exclusions/:rule_id", securityHandler.DeleteWAFExclusion) // CrowdSec process management and import // Data dir for crowdsec (persisted on host via volumes) @@ -325,6 +357,9 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error { // Access Lists accessListHandler := handlers.NewAccessListHandler(db) + if geoipSvc != nil { + accessListHandler.SetGeoIPService(geoipSvc) + } protected.GET("/access-lists/templates", accessListHandler.GetTemplates) protected.GET("/access-lists", accessListHandler.List) protected.POST("/access-lists", accessListHandler.Create) diff --git a/backend/internal/api/tests/integration_test.go b/backend/internal/api/tests/integration_test.go index d11e40af..6cc21b9a 100644 --- a/backend/internal/api/tests/integration_test.go +++ b/backend/internal/api/tests/integration_test.go @@ -16,6 +16,8 @@ import ( ) // TestIntegration_WAF_BlockAndMonitor exercises middleware behavior and metrics exposure. +// Note: Actual WAF blocking is handled by Coraza at the Caddy layer, not by the API middleware. +// The cerberus middleware only tracks metrics and handles ACL enforcement. func TestIntegration_WAF_BlockAndMonitor(t *testing.T) { gin.SetMode(gin.TestMode) @@ -37,13 +39,17 @@ func TestIntegration_WAF_BlockAndMonitor(t *testing.T) { return r, db } - // Block mode should reject suspicious payload on an API route covered by middleware + // Block mode: cerberus middleware doesn't block requests - that's Coraza's job at the Caddy layer + // The API middleware only tracks metrics when WAF is enabled rBlock, _ := newServer("block") req := httptest.NewRequest(http.MethodGet, "/api/v1/remote-servers?test=