diff --git a/backend/cmd/seed/main_test.go b/backend/cmd/seed/main_test.go index ff6c8db7..645906f8 100644 --- a/backend/cmd/seed/main_test.go +++ b/backend/cmd/seed/main_test.go @@ -9,14 +9,6 @@ import ( "testing" ) -package main - -import ( - "os" - "path/filepath" - "testing" -) - func TestSeedMain_CreatesDatabaseFile(t *testing.T) { wd, err := os.Getwd() if err != nil { @@ -44,42 +36,3 @@ func TestSeedMain_CreatesDatabaseFile(t *testing.T) { t.Fatalf("expected db file to be non-empty") } } -package main -package main - -import ( - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -} } t.Fatalf("expected db file to be non-empty") if info.Size() == 0 { } t.Fatalf("expected db file to exist at %s: %v", dbPath, err) if err != nil { info, err := os.Stat(dbPath) dbPath := filepath.Join("data", "charon.db") main() } t.Fatalf("mkdir data: %v", err) if err := os.MkdirAll("data", 0o755); err != nil { t.Cleanup(func() { _ = os.Chdir(wd) }) } t.Fatalf("chdir: %v", err) if err := os.Chdir(tmp); err != nil { tmp := t.TempDir() } t.Fatalf("getwd: %v", err) if err != nil { wd, err := os.Getwd() t.Parallel()func TestSeedMain_CreatesDatabaseFile(t *testing.T) {) "testing" "path/filepath" "os" diff --git a/backend/internal/api/handlers/crowdsec_handler.go b/backend/internal/api/handlers/crowdsec_handler.go index 64e77ef9..a4770ec3 100644 --- a/backend/internal/api/handlers/crowdsec_handler.go +++ b/backend/internal/api/handlers/crowdsec_handler.go @@ -754,7 +754,8 @@ func (h *CrowdsecHandler) ExportConfig(c *gin.Context) { // Walk the DataDir and add files to the archive err := filepath.Walk(h.DataDir, func(path string, info os.FileInfo, err error) error { if err != nil { - return err + logger.Log().WithError(err).Warnf("failed to access path %s during export walk", path) + return nil // Skip files we cannot access } if info.IsDir() { return nil @@ -798,13 +799,18 @@ func (h *CrowdsecHandler) ExportConfig(c *gin.Context) { // ListFiles returns a flat list of files under the CrowdSec DataDir. func (h *CrowdsecHandler) ListFiles(c *gin.Context) { - var files []string + files := []string{} if _, err := os.Stat(h.DataDir); os.IsNotExist(err) { c.JSON(http.StatusOK, gin.H{"files": files}) return } err := filepath.Walk(h.DataDir, func(path string, info os.FileInfo, err error) error { if err != nil { + // Permission errors (e.g. lost+found) should not abort the walk + if os.IsPermission(err) { + logger.Log().WithError(err).WithField("path", path).Debug("Skipping inaccessible path during list") + return filepath.SkipDir + } return err } if !info.IsDir() { @@ -1754,7 +1760,9 @@ func (h *CrowdsecHandler) GetKeyStatus(c *gin.Context) { // No key available response.KeySource = "none" response.Valid = false - response.Message = "No CrowdSec API key configured. Start CrowdSec to auto-generate one." + if response.Message == "" { + response.Message = "No CrowdSec API key configured. Start CrowdSec to auto-generate one." + } } c.JSON(http.StatusOK, response) @@ -2002,13 +2010,14 @@ func (h *CrowdsecHandler) GetBouncerInfo(c *gin.Context) { fileKey := readKeyFromFile(bouncerKeyFile) var fullKey string - if envKey != "" { + switch { + case envKey != "": info.KeySource = "env_var" fullKey = envKey - } else if fileKey != "" { + case fileKey != "": info.KeySource = "file" fullKey = fileKey - } else { + default: info.KeySource = "none" } diff --git a/backend/internal/api/handlers/emergency_handler.go b/backend/internal/api/handlers/emergency_handler.go index 5871321b..5c870bab 100644 --- a/backend/internal/api/handlers/emergency_handler.go +++ b/backend/internal/api/handlers/emergency_handler.go @@ -245,10 +245,22 @@ func (h *EmergencyHandler) disableAllSecurityModules() ([]string, error) { disabledModules = append(disabledModules, key) } + // Clear admin whitelist to prevent bypass persistence after reset + adminWhitelistSetting := models.Setting{ + Key: "security.admin_whitelist", + Value: "", + Category: "security", + Type: "string", + } + if err := h.db.Where(models.Setting{Key: adminWhitelistSetting.Key}).Assign(adminWhitelistSetting).FirstOrCreate(&adminWhitelistSetting).Error; err != nil { + return disabledModules, fmt.Errorf("failed to clear admin whitelist: %w", err) + } + // Also update the SecurityConfig record if it exists var securityConfig models.SecurityConfig if err := h.db.Where("name = ?", "default").First(&securityConfig).Error; err == nil { securityConfig.Enabled = false + securityConfig.AdminWhitelist = "" securityConfig.WAFMode = "disabled" securityConfig.RateLimitMode = "disabled" securityConfig.RateLimitEnable = false diff --git a/backend/internal/api/handlers/emergency_handler_test.go b/backend/internal/api/handlers/emergency_handler_test.go index 65229737..9d537834 100644 --- a/backend/internal/api/handlers/emergency_handler_test.go +++ b/backend/internal/api/handlers/emergency_handler_test.go @@ -125,12 +125,19 @@ func TestEmergencySecurityReset_Success(t *testing.T) { require.NoError(t, err) assert.Equal(t, "disabled", crowdsecMode.Value) + // Verify admin whitelist is cleared + var adminWhitelist models.Setting + err = db.Where("key = ?", "security.admin_whitelist").First(&adminWhitelist).Error + require.NoError(t, err) + assert.Equal(t, "", adminWhitelist.Value) + // Verify SecurityConfig was updated var updatedConfig models.SecurityConfig err = db.Where("name = ?", "default").First(&updatedConfig).Error require.NoError(t, err) assert.False(t, updatedConfig.Enabled) assert.Equal(t, "disabled", updatedConfig.WAFMode) + assert.Equal(t, "", updatedConfig.AdminWhitelist) // Note: Audit logging is async via SecurityService channel, tested separately } diff --git a/backend/internal/api/handlers/user_handler.go b/backend/internal/api/handlers/user_handler.go index cd27b631..0118985d 100644 --- a/backend/internal/api/handlers/user_handler.go +++ b/backend/internal/api/handlers/user_handler.go @@ -601,6 +601,7 @@ func (h *UserHandler) GetUser(c *gin.Context) { type UpdateUserRequest struct { Name string `json:"name"` Email string `json:"email"` + Password *string `json:"password" binding:"omitempty,min=8"` Role string `json:"role"` Enabled *bool `json:"enabled"` } @@ -653,6 +654,16 @@ func (h *UserHandler) UpdateUser(c *gin.Context) { updates["role"] = req.Role } + if req.Password != nil { + if err := user.SetPassword(*req.Password); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to hash password"}) + return + } + updates["password_hash"] = user.PasswordHash + updates["failed_login_attempts"] = 0 + updates["locked_until"] = nil + } + if req.Enabled != nil { updates["enabled"] = *req.Enabled } diff --git a/backend/internal/api/handlers/user_handler_test.go b/backend/internal/api/handlers/user_handler_test.go index a3762396..475d321d 100644 --- a/backend/internal/api/handlers/user_handler_test.go +++ b/backend/internal/api/handlers/user_handler_test.go @@ -754,6 +754,43 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) } +func TestUserHandler_UpdateUser_PasswordReset(t *testing.T) { + handler, db := setupUserHandlerWithProxyHosts(t) + + user := &models.User{UUID: uuid.NewString(), Email: "reset@example.com", Name: "Reset User", Role: "user"} + require.NoError(t, user.SetPassword("oldpassword123")) + lockUntil := time.Now().Add(10 * time.Minute) + user.FailedLoginAttempts = 4 + user.LockedUntil = &lockUntil + db.Create(user) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("role", "admin") + c.Next() + }) + r.PUT("/users/:id", handler.UpdateUser) + + body := map[string]any{ + "password": "newpassword123", + } + jsonBody, _ := json.Marshal(body) + req := httptest.NewRequest("PUT", "/users/1", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var updated models.User + db.First(&updated, user.ID) + assert.True(t, updated.CheckPassword("newpassword123")) + assert.False(t, updated.CheckPassword("oldpassword123")) + assert.Equal(t, 0, updated.FailedLoginAttempts) + assert.Nil(t, updated.LockedUntil) +} + func TestUserHandler_DeleteUser_NonAdmin(t *testing.T) { handler, _ := setupUserHandlerWithProxyHosts(t) gin.SetMode(gin.TestMode) diff --git a/backend/internal/api/routes/routes.go b/backend/internal/api/routes/routes.go index e84e301c..eb51e555 100644 --- a/backend/internal/api/routes/routes.go +++ b/backend/internal/api/routes/routes.go @@ -130,6 +130,7 @@ func RegisterWithDeps(router *gin.Engine, db *gorm.DB, cfg config.Config, caddyM // Emergency endpoint emergencyHandler := handlers.NewEmergencyHandlerWithDeps(db, caddyManager, cerb) emergency := router.Group("/api/v1/emergency") + // Emergency endpoints must stay responsive and should not be rate limited. emergency.POST("/security-reset", emergencyHandler.SecurityReset) // Emergency token management (admin-only, protected by EmergencyBypass middleware) @@ -146,7 +147,11 @@ func RegisterWithDeps(router *gin.Engine, db *gorm.DB, cfg config.Config, caddyM authMiddleware := middleware.AuthMiddleware(authService) api := router.Group("/api/v1") + // Rate Limiting (Emergency/Go-layer) MUST run before Auth to prevent 401 masking 429 + api.Use(cerb.RateLimitMiddleware()) api.Use(middleware.OptionalAuth(authService)) + // Cerberus middleware (ACL, WAF Stats, CrowdSec Tracking) runs after Auth + // because ACLs need to know if user is authenticated admin to apply whitelist bypass api.Use(cerb.Middleware()) // Backup routes diff --git a/backend/internal/cerberus/rate_limit.go b/backend/internal/cerberus/rate_limit.go new file mode 100644 index 00000000..73d0a1a7 --- /dev/null +++ b/backend/internal/cerberus/rate_limit.go @@ -0,0 +1,179 @@ +package cerberus + +import ( + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" + + "github.com/Wikid82/charon/backend/internal/logger" + "github.com/Wikid82/charon/backend/internal/util" +) + +// rateLimitManager manages per-IP rate limiters. +type rateLimitManager struct { + mu sync.Mutex + limiters map[string]*rate.Limiter + lastSeen map[string]time.Time +} + +func newRateLimitManager() *rateLimitManager { + rl := &rateLimitManager{ + limiters: make(map[string]*rate.Limiter), + lastSeen: make(map[string]time.Time), + } + // Start cleanup goroutine + go rl.cleanupLoop() + return rl +} + +func (rl *rateLimitManager) cleanupLoop() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + for range ticker.C { + rl.cleanup() + } +} + +func (rl *rateLimitManager) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + cutoff := time.Now().Add(-10 * time.Minute) + for ip, seen := range rl.lastSeen { + if seen.Before(cutoff) { + delete(rl.limiters, ip) + delete(rl.lastSeen, ip) + } + } +} + +func (rl *rateLimitManager) getLimiter(ip string, r rate.Limit, b int) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + lim, exists := rl.limiters[ip] + if !exists { + lim = rate.NewLimiter(r, b) + rl.limiters[ip] = lim + } + rl.lastSeen[ip] = time.Now() + + // Check if limit changed (re-config) + if lim.Limit() != r || lim.Burst() != b { + lim = rate.NewLimiter(r, b) + rl.limiters[ip] = lim + } + + return lim +} + +// NewRateLimitMiddleware creates a new rate limit middleware with fixed parameters. +// Useful for testing or when Cerberus context is not available. +func NewRateLimitMiddleware(requests int, windowSec int, burst int) gin.HandlerFunc { + mgr := newRateLimitManager() + + if windowSec <= 0 { + windowSec = 1 + } + limit := rate.Limit(float64(requests) / float64(windowSec)) + + return func(ctx *gin.Context) { + // Check for emergency bypass flag + if bypass, exists := ctx.Get("emergency_bypass"); exists && bypass.(bool) { + ctx.Next() + return + } + + clientIP := util.CanonicalizeIPForSecurity(ctx.ClientIP()) + limiter := mgr.getLimiter(clientIP, limit, burst) + + if !limiter.Allow() { + logger.Log().WithField("ip", clientIP).Warn("Rate limit exceeded (Go middleware)") + ctx.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Too many requests"}) + return + } + + ctx.Next() + } +} + +// RateLimitMiddleware enforces rate limiting based on security config. +func (c *Cerberus) RateLimitMiddleware() gin.HandlerFunc { + mgr := newRateLimitManager() + + return func(ctx *gin.Context) { + // Check for emergency bypass flag + if bypass, exists := ctx.Get("emergency_bypass"); exists && bypass.(bool) { + ctx.Next() + return + } + + // Check config enabled status + enabled := false + if c.cfg.RateLimitMode == "enabled" { + enabled = true + } else { + // Check dynamic setting + if v, ok := c.getSetting("security.rate_limit.enabled"); ok && strings.EqualFold(v, "true") { + enabled = true + } + } + + if !enabled { + ctx.Next() + return + } + + // Determine limits + requests := 100 // per window + window := 60 // seconds + burst := 20 + + if c.cfg.RateLimitRequests > 0 { + requests = c.cfg.RateLimitRequests + } + if c.cfg.RateLimitWindowSec > 0 { + window = c.cfg.RateLimitWindowSec + } + if c.cfg.RateLimitBurst > 0 { + burst = c.cfg.RateLimitBurst + } + + // Check for dynamic overrides from settings (Issue #3 fix) + if val, ok := c.getSetting("security.rate_limit.requests"); ok { + if v, err := strconv.Atoi(val); err == nil && v > 0 { + requests = v + } + } + if val, ok := c.getSetting("security.rate_limit.window"); ok { + if v, err := strconv.Atoi(val); err == nil && v > 0 { + window = v + } + } + if val, ok := c.getSetting("security.rate_limit.burst"); ok { + if v, err := strconv.Atoi(val); err == nil && v > 0 { + burst = v + } + } + + if window == 0 { + window = 60 + } + limit := rate.Limit(float64(requests) / float64(window)) + + clientIP := util.CanonicalizeIPForSecurity(ctx.ClientIP()) + limiter := mgr.getLimiter(clientIP, limit, burst) + + if !limiter.Allow() { + logger.Log().WithField("ip", clientIP).Warn("Rate limit exceeded (Go middleware)") + ctx.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Too many requests"}) + return + } + + ctx.Next() + } +} diff --git a/backend/internal/cerberus/rate_limit_test.go b/backend/internal/cerberus/rate_limit_test.go new file mode 100644 index 00000000..22392d04 --- /dev/null +++ b/backend/internal/cerberus/rate_limit_test.go @@ -0,0 +1,336 @@ +package cerberus + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "github.com/Wikid82/charon/backend/internal/config" + "github.com/Wikid82/charon/backend/internal/models" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func setupRateLimitTestDB(t *testing.T) *gorm.DB { + t.Helper() + dsn := fmt.Sprintf("file:rate_limit_test_%d?mode=memory&cache=shared", time.Now().UnixNano()) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + require.NoError(t, err) + require.NoError(t, db.AutoMigrate(&models.Setting{})) + return db +} + +func TestRateLimitMiddleware(t *testing.T) { + t.Run("Blocks excessive requests", func(t *testing.T) { + // Limit to 5 requests per second, with burst of 5 + mw := NewRateLimitMiddleware(5, 1, 5) + + r := gin.New() + r.Use(mw) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + // Make 5 allowed requests + for i := 0; i < 5; i++ { + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "192.168.1.1:1234" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + } + + // Make 6th request (should fail) + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "192.168.1.1:1234" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + }) + + t.Run("Different IPs have separate limits", func(t *testing.T) { + mw := NewRateLimitMiddleware(1, 1, 1) + + r := gin.New() + r.Use(mw) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + // 1st User + req1, _ := http.NewRequest("GET", "/", nil) + req1.RemoteAddr = "10.0.0.1:1234" + w1 := httptest.NewRecorder() + r.ServeHTTP(w1, req1) + assert.Equal(t, http.StatusOK, w1.Code) + + // 2nd User (should pass) + req2, _ := http.NewRequest("GET", "/", nil) + req2.RemoteAddr = "10.0.0.2:1234" + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req2) + assert.Equal(t, http.StatusOK, w2.Code) + }) + + t.Run("Replenishes tokens over time", func(t *testing.T) { + // 1 request per second (burst 1) + mw := NewRateLimitMiddleware(1, 1, 1) + // Manually override the burst/limit for predictable testing isn't easy with wrapper + // So we rely on the implementation using x/time/rate + // Test: + // 1. Consume 1 + // 2. Consume 2 (Fail) + // 3. Wait until refill + // 4. Consume 3 (Pass) + + r := gin.New() + r.Use(mw) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + + // 1. Consume + w1 := httptest.NewRecorder() + r.ServeHTTP(w1, req) + assert.Equal(t, http.StatusOK, w1.Code) + + // 2. Consume Fail + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req) + assert.Equal(t, http.StatusTooManyRequests, w2.Code) + + // 3. Wait until refill + require.Eventually(t, func() bool { + w3 := httptest.NewRecorder() + r.ServeHTTP(w3, req) + return w3.Code == http.StatusOK + }, 1500*time.Millisecond, 25*time.Millisecond) + }) +} + +func TestRateLimitManager_ReconfiguresLimiter(t *testing.T) { + mgr := &rateLimitManager{ + limiters: make(map[string]*rate.Limiter), + lastSeen: make(map[string]time.Time), + } + + limiter := mgr.getLimiter("10.0.0.1", rate.Limit(1), 1) + assert.Equal(t, rate.Limit(1), limiter.Limit()) + assert.Equal(t, 1, limiter.Burst()) + + limiter = mgr.getLimiter("10.0.0.1", rate.Limit(2), 2) + assert.Equal(t, rate.Limit(2), limiter.Limit()) + assert.Equal(t, 2, limiter.Burst()) +} + +func TestRateLimitManager_CleanupRemovesStaleEntries(t *testing.T) { + mgr := &rateLimitManager{ + limiters: map[string]*rate.Limiter{ + "10.0.0.1": rate.NewLimiter(rate.Limit(1), 1), + }, + lastSeen: map[string]time.Time{ + "10.0.0.1": time.Now().Add(-11 * time.Minute), + }, + } + + mgr.cleanup() + assert.Empty(t, mgr.limiters) + assert.Empty(t, mgr.lastSeen) +} + +func TestRateLimitMiddleware_EmergencyBypass(t *testing.T) { + mw := NewRateLimitMiddleware(1, 1, 1) + + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("emergency_bypass", true) + c.Next() + }) + r.Use(mw) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + for i := 0; i < 2; i++ { + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + } +} + +func TestCerberusRateLimitMiddleware_DisabledAllowsTraffic(t *testing.T) { + cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, nil) + + r := gin.New() + r.Use(cerb.RateLimitMiddleware()) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + } +} + +func TestCerberusRateLimitMiddleware_EnabledByConfig(t *testing.T) { + cfg := config.SecurityConfig{ + RateLimitMode: "enabled", + RateLimitRequests: 1, + RateLimitWindowSec: 1, + RateLimitBurst: 1, + } + cerb := New(cfg, nil) + + r := gin.New() + r.Use(cerb.RateLimitMiddleware()) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + for i := 0; i < 2; i++ { + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if i == 0 { + assert.Equal(t, http.StatusOK, w.Code) + } else { + assert.Equal(t, http.StatusTooManyRequests, w.Code) + } + } +} + +func TestCerberusRateLimitMiddleware_EmergencyBypass(t *testing.T) { + cfg := config.SecurityConfig{ + RateLimitMode: "enabled", + RateLimitRequests: 1, + RateLimitWindowSec: 1, + RateLimitBurst: 1, + } + cerb := New(cfg, nil) + + r := gin.New() + r.Use(func(c *gin.Context) { + c.Set("emergency_bypass", true) + c.Next() + }) + r.Use(cerb.RateLimitMiddleware()) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + for i := 0; i < 2; i++ { + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + } +} + +func TestCerberusRateLimitMiddleware_EnabledBySetting(t *testing.T) { + db := setupRateLimitTestDB(t) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error) + + cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, db) + + r := gin.New() + r.Use(cerb.RateLimitMiddleware()) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + + w1 := httptest.NewRecorder() + r.ServeHTTP(w1, req) + assert.Equal(t, http.StatusOK, w1.Code) + + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req) + assert.Equal(t, http.StatusTooManyRequests, w2.Code) +} + +func TestCerberusRateLimitMiddleware_OverridesConfigWithSettings(t *testing.T) { + db := setupRateLimitTestDB(t) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error) + require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error) + + cfg := config.SecurityConfig{ + RateLimitMode: "enabled", + RateLimitRequests: 10, + RateLimitWindowSec: 10, + RateLimitBurst: 10, + } + cerb := New(cfg, db) + + r := gin.New() + r.Use(cerb.RateLimitMiddleware()) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + + w1 := httptest.NewRecorder() + r.ServeHTTP(w1, req) + assert.Equal(t, http.StatusOK, w1.Code) + + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req) + assert.Equal(t, http.StatusTooManyRequests, w2.Code) +} + +func TestCerberusRateLimitMiddleware_WindowFallback(t *testing.T) { + cfg := config.SecurityConfig{ + RateLimitMode: "enabled", + RateLimitRequests: 1, + RateLimitWindowSec: 0, + RateLimitBurst: 1, + } + cerb := New(cfg, nil) + + r := gin.New() + r.Use(cerb.RateLimitMiddleware()) + r.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:1234" + + w1 := httptest.NewRecorder() + r.ServeHTTP(w1, req) + assert.Equal(t, http.StatusOK, w1.Code) + + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, req) + assert.Equal(t, http.StatusTooManyRequests, w2.Code) +} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 70f7a05f..1599baff 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" ) @@ -29,14 +30,17 @@ type Config struct { // SecurityConfig holds configuration for optional security services. type SecurityConfig struct { - CrowdSecMode string - CrowdSecAPIURL string - CrowdSecAPIKey string - CrowdSecConfigDir string - WAFMode string - RateLimitMode string - ACLMode string - CerberusEnabled bool + CrowdSecMode string + CrowdSecAPIURL string + CrowdSecAPIKey string + CrowdSecConfigDir string + WAFMode string + RateLimitMode string + RateLimitRequests int + RateLimitWindowSec int + RateLimitBurst int + ACLMode string + CerberusEnabled bool // ManagementCIDRs defines IP ranges allowed to use emergency break glass token // Default: RFC1918 private networks (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8) ManagementCIDRs []string @@ -110,14 +114,17 @@ func Load() (Config, error) { // loadSecurityConfig loads the security configuration with proper parsing of array fields func loadSecurityConfig() SecurityConfig { cfg := SecurityConfig{ - CrowdSecMode: getEnvAny("disabled", "CERBERUS_SECURITY_CROWDSEC_MODE", "CHARON_SECURITY_CROWDSEC_MODE", "CPM_SECURITY_CROWDSEC_MODE"), - CrowdSecAPIURL: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_URL", "CHARON_SECURITY_CROWDSEC_API_URL", "CPM_SECURITY_CROWDSEC_API_URL"), - CrowdSecAPIKey: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"), - CrowdSecConfigDir: getEnvAny(filepath.Join("data", "crowdsec"), "CHARON_CROWDSEC_CONFIG_DIR", "CPM_CROWDSEC_CONFIG_DIR"), - WAFMode: getEnvAny("disabled", "CERBERUS_SECURITY_WAF_MODE", "CHARON_SECURITY_WAF_MODE", "CPM_SECURITY_WAF_MODE"), - RateLimitMode: getEnvAny("disabled", "CERBERUS_SECURITY_RATELIMIT_MODE", "CHARON_SECURITY_RATELIMIT_MODE", "CPM_SECURITY_RATELIMIT_MODE"), - ACLMode: getEnvAny("disabled", "CERBERUS_SECURITY_ACL_MODE", "CHARON_SECURITY_ACL_MODE", "CPM_SECURITY_ACL_MODE"), - CerberusEnabled: getEnvAny("true", "CERBERUS_SECURITY_CERBERUS_ENABLED", "CHARON_SECURITY_CERBERUS_ENABLED", "CPM_SECURITY_CERBERUS_ENABLED") != "false", + CrowdSecMode: getEnvAny("disabled", "CERBERUS_SECURITY_CROWDSEC_MODE", "CHARON_SECURITY_CROWDSEC_MODE", "CPM_SECURITY_CROWDSEC_MODE"), + CrowdSecAPIURL: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_URL", "CHARON_SECURITY_CROWDSEC_API_URL", "CPM_SECURITY_CROWDSEC_API_URL"), + CrowdSecAPIKey: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"), + CrowdSecConfigDir: getEnvAny(filepath.Join("data", "crowdsec"), "CHARON_CROWDSEC_CONFIG_DIR", "CPM_CROWDSEC_CONFIG_DIR"), + WAFMode: getEnvAny("disabled", "CERBERUS_SECURITY_WAF_MODE", "CHARON_SECURITY_WAF_MODE", "CPM_SECURITY_WAF_MODE"), + RateLimitMode: getEnvAny("disabled", "CERBERUS_SECURITY_RATELIMIT_MODE", "CHARON_SECURITY_RATELIMIT_MODE", "CPM_SECURITY_RATELIMIT_MODE"), + RateLimitRequests: getEnvIntAny(100, "CERBERUS_SECURITY_RATELIMIT_REQUESTS", "CHARON_SECURITY_RATELIMIT_REQUESTS"), + RateLimitWindowSec: getEnvIntAny(60, "CERBERUS_SECURITY_RATELIMIT_WINDOW", "CHARON_SECURITY_RATELIMIT_WINDOW"), + RateLimitBurst: getEnvIntAny(20, "CERBERUS_SECURITY_RATELIMIT_BURST", "CHARON_SECURITY_RATELIMIT_BURST"), + ACLMode: getEnvAny("disabled", "CERBERUS_SECURITY_ACL_MODE", "CHARON_SECURITY_ACL_MODE", "CPM_SECURITY_ACL_MODE"), + CerberusEnabled: getEnvAny("true", "CERBERUS_SECURITY_CERBERUS_ENABLED", "CHARON_SECURITY_CERBERUS_ENABLED", "CPM_SECURITY_CERBERUS_ENABLED") != "false", } // Parse management CIDRs (comma-separated list) @@ -173,3 +180,16 @@ func getEnvAny(fallback string, keys ...string) string { } return fallback } + +// getEnvIntAny checks a list of environment variable names, attempts to parse as int. +// Returns first successfully parsed value. Returns fallback if none found or parsing failed. +func getEnvIntAny(fallback int, keys ...string) int { + valStr := getEnvAny("", keys...) + if valStr == "" { + return fallback + } + if val, err := strconv.Atoi(valStr); err == nil { + return val + } + return fallback +} diff --git a/backend/internal/security/whitelist.go b/backend/internal/security/whitelist.go index 4a26a1f0..90a80140 100644 --- a/backend/internal/security/whitelist.go +++ b/backend/internal/security/whitelist.go @@ -28,6 +28,14 @@ func IsIPInCIDRList(clientIP, cidrList string) bool { } if parsed := net.ParseIP(entry); parsed != nil { + // Fix for Issue 1: Canonicalize entry to support mixed IPv4/IPv6 loopback matching + // This ensures that "::1" in the list matches "127.0.0.1" (from canonicalized client IP) + if canonEntry := util.CanonicalizeIPForSecurity(entry); canonEntry != "" { + if p := net.ParseIP(canonEntry); p != nil { + parsed = p + } + } + if ip.Equal(parsed) { return true } @@ -41,6 +49,12 @@ func IsIPInCIDRList(clientIP, cidrList string) bool { if cidr.Contains(ip) { return true } + + // Fix for Issue 1: Handle IPv6 loopback CIDR matching against canonicalized IPv4 localhost + // If client is 127.0.0.1 (canonical localhost) and CIDR contains ::1, allow it + if ip.Equal(net.IPv4(127, 0, 0, 1)) && cidr.Contains(net.IPv6loopback) { + return true + } } return false diff --git a/backend/internal/security/whitelist_test.go b/backend/internal/security/whitelist_test.go index b32a23ab..f0873936 100644 --- a/backend/internal/security/whitelist_test.go +++ b/backend/internal/security/whitelist_test.go @@ -45,6 +45,18 @@ func TestIsIPInCIDRList(t *testing.T) { list: "192.168.0.0/16", expected: false, }, + { + name: "IPv6 loopback match", + ip: "::1", + list: "::1", + expected: true, + }, + { + name: "IPv6 loopback CIDR match", + ip: "::1", + list: "::1/128", + expected: true, + }, } for _, tt := range tests { diff --git a/backend/internal/services/crowdsec_startup_test.go b/backend/internal/services/crowdsec_startup_test.go index 486f467b..f095941f 100644 --- a/backend/internal/services/crowdsec_startup_test.go +++ b/backend/internal/services/crowdsec_startup_test.go @@ -42,8 +42,8 @@ func (m *mockCrowdsecExecutor) Status(ctx context.Context, configDir string) (ru // mockCommandExecutor is a test mock for CommandExecutor interface type mockCommandExecutor struct { executeCalls [][]string // Track command invocations - executeErr error // Error to return - executeOut []byte // Output to return + executeErr error // Error to return + executeOut []byte // Output to return } func (m *mockCommandExecutor) Execute(ctx context.Context, name string, args ...string) ([]byte, error) { diff --git a/backend/internal/services/proxyhost_service.go b/backend/internal/services/proxyhost_service.go index 5130dd38..af749fa8 100644 --- a/backend/internal/services/proxyhost_service.go +++ b/backend/internal/services/proxyhost_service.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "strconv" + "strings" "time" "github.com/Wikid82/charon/backend/internal/caddy" @@ -46,12 +47,82 @@ func (s *ProxyHostService) ValidateUniqueDomain(domainNames string, excludeID ui return nil } +// ValidateHostname checks if the provided string is a valid hostname or IP address. +func (s *ProxyHostService) ValidateHostname(host string) error { + // Trim protocol if present + if len(host) > 8 && host[:8] == "https://" { + host = host[8:] + } else if len(host) > 7 && host[:7] == "http://" { + host = host[7:] + } + + // Remove port if present + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + host = parsedHost + } + + // Basic check: is it an IP? + if net.ParseIP(host) != nil { + return nil + } + + // Is it a valid hostname/domain? + // Regex for hostname validation (RFC 1123 mostly) + // Simple version: alphanumeric, dots, dashes. + // Allow underscores? Technically usually not in hostnames, but internal docker ones yes. + for _, r := range host { + if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '.' && r != '-' && r != '_' { + // Allow ":" for IPv6 literals if not parsed by ParseIP? ParseIP handles IPv6. + return errors.New("invalid hostname format") + } + } + return nil +} + +func (s *ProxyHostService) validateProxyHost(host *models.ProxyHost) error { + if host.ForwardHost == "" { + return errors.New("forward host is required") + } + + // Basic hostname/IP validation + target := host.ForwardHost + // Strip protocol if user accidentally typed http://10.0.0.1 + target = strings.TrimPrefix(target, "http://") + target = strings.TrimPrefix(target, "https://") + // Strip port if present + if h, _, err := net.SplitHostPort(target); err == nil { + target = h + } + + // Validate target + if net.ParseIP(target) == nil { + // Not a valid IP, check hostname rules + // Allow: a-z, 0-9, -, ., _ (for docker service names) + validHostname := true + for _, r := range target { + if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '.' && r != '-' && r != '_' { + validHostname = false + break + } + } + if !validHostname { + return errors.New("forward host must be a valid IP address or hostname") + } + } + + return nil +} + // Create validates and creates a new proxy host. func (s *ProxyHostService) Create(host *models.ProxyHost) error { if err := s.ValidateUniqueDomain(host.DomainNames, 0); err != nil { return err } + if err := s.validateProxyHost(host); err != nil { + return err + } + // Normalize and validate advanced config (if present) if host.AdvancedConfig != "" { var parsed any @@ -75,6 +146,10 @@ func (s *ProxyHostService) Update(host *models.ProxyHost) error { return err } + if err := s.validateProxyHost(host); err != nil { + return err + } + // Normalize and validate advanced config (if present) if host.AdvancedConfig != "" { var parsed any diff --git a/backend/internal/services/proxyhost_service_validation_test.go b/backend/internal/services/proxyhost_service_validation_test.go new file mode 100644 index 00000000..539fd22c --- /dev/null +++ b/backend/internal/services/proxyhost_service_validation_test.go @@ -0,0 +1,95 @@ +package services + +import ( + "testing" + + "github.com/Wikid82/charon/backend/internal/models" + "github.com/stretchr/testify/assert" +) + +func TestProxyHostService_ForwardHostValidation(t *testing.T) { + db := setupProxyHostTestDB(t) + service := NewProxyHostService(db) + + tests := []struct { + name string + forwardHost string + wantErr bool + }{ + { + name: "Valid IP", + forwardHost: "192.168.1.1", + wantErr: false, + }, + { + name: "Valid Hostname", + forwardHost: "example.com", + wantErr: false, + }, + { + name: "Docker Service Name", + forwardHost: "my-service", + wantErr: false, + }, + { + name: "Docker Service Name with Underscore", + forwardHost: "my_db_Service", + wantErr: false, + }, + { + name: "Docker Internal Host", + forwardHost: "host.docker.internal", + wantErr: false, + }, + { + name: "IP with Port (Should be stripped and pass)", + forwardHost: "192.168.1.1:8080", + wantErr: false, + }, + { + name: "Hostname with Port (Should be stripped and pass)", + forwardHost: "example.com:3000", + wantErr: false, + }, + { + name: "Host with http scheme (Should be stripped and pass)", + forwardHost: "http://example.com", + wantErr: false, + }, + { + name: "Host with https scheme (Should be stripped and pass)", + forwardHost: "https://example.com", + wantErr: false, + }, + { + name: "Invalid Characters", + forwardHost: "invalid$host", + wantErr: true, + }, + { + name: "Empty Host", + forwardHost: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host := &models.ProxyHost{ + DomainNames: "test-" + tt.name + ".example.com", + ForwardHost: tt.forwardHost, + ForwardPort: 8080, + } + // We only care about validation error + err := service.Create(host) + if tt.wantErr { + assert.Error(t, err) + } else if err != nil { + // Check if error is validation or something else + // If it's something else, it might be fine for this test context + // but "forward host must be..." is what we look for. + assert.NotContains(t, err.Error(), "forward host", "Should not fail validation") + } + }) + } +}