fix: enhance security features

- Updated `crowdsec_handler.go` to log inaccessible paths during config export and handle permission errors gracefully.
- Modified `emergency_handler.go` to clear admin whitelist during security reset and ensure proper updates to security configurations.
- Enhanced user password update functionality in `user_handler.go` to reset failed login attempts and lockout status.
- Introduced rate limiting middleware in `cerberus` to manage request rates and prevent abuse, with comprehensive tests for various scenarios.
- Added validation for proxy host entries in `proxyhost_service.go` to ensure valid hostnames and IP addresses, including tests for various cases.
- Improved IP matching logic in `whitelist.go` to support both IPv4 and IPv6 loopback addresses.
- Updated configuration loading in `config.go` to include rate limiting parameters from environment variables.
- Added tests for new functionalities and validations to ensure robustness and reliability.
This commit is contained in:
GitHub Actions
2026-02-07 23:48:13 +00:00
parent 1e2d16cf13
commit 9ec23cd48b
15 changed files with 836 additions and 71 deletions

View File

@@ -9,14 +9,6 @@ import (
"testing"
)
package main
import (
"os"
"path/filepath"
"testing"
)
func TestSeedMain_CreatesDatabaseFile(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
@@ -44,42 +36,3 @@ func TestSeedMain_CreatesDatabaseFile(t *testing.T) {
t.Fatalf("expected db file to be non-empty")
}
}
package main
package main
import (
} } t.Fatalf("expected db file to be non-empty") if info.Size() == 0 { } t.Fatalf("expected db file to exist at %s: %v", dbPath, err) if err != nil { info, err := os.Stat(dbPath) dbPath := filepath.Join("data", "charon.db") main() } t.Fatalf("mkdir data: %v", err) if err := os.MkdirAll("data", 0o755); err != nil { t.Cleanup(func() { _ = os.Chdir(wd) }) } t.Fatalf("chdir: %v", err) if err := os.Chdir(tmp); err != nil { tmp := t.TempDir() } t.Fatalf("getwd: %v", err) if err != nil { wd, err := os.Getwd() t.Parallel()func TestSeedMain_CreatesDatabaseFile(t *testing.T) {) "testing" "path/filepath" "os"

View File

@@ -754,7 +754,8 @@ func (h *CrowdsecHandler) ExportConfig(c *gin.Context) {
// Walk the DataDir and add files to the archive
err := filepath.Walk(h.DataDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
logger.Log().WithError(err).Warnf("failed to access path %s during export walk", path)
return nil // Skip files we cannot access
}
if info.IsDir() {
return nil
@@ -798,13 +799,18 @@ func (h *CrowdsecHandler) ExportConfig(c *gin.Context) {
// ListFiles returns a flat list of files under the CrowdSec DataDir.
func (h *CrowdsecHandler) ListFiles(c *gin.Context) {
var files []string
files := []string{}
if _, err := os.Stat(h.DataDir); os.IsNotExist(err) {
c.JSON(http.StatusOK, gin.H{"files": files})
return
}
err := filepath.Walk(h.DataDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
// Permission errors (e.g. lost+found) should not abort the walk
if os.IsPermission(err) {
logger.Log().WithError(err).WithField("path", path).Debug("Skipping inaccessible path during list")
return filepath.SkipDir
}
return err
}
if !info.IsDir() {
@@ -1754,7 +1760,9 @@ func (h *CrowdsecHandler) GetKeyStatus(c *gin.Context) {
// No key available
response.KeySource = "none"
response.Valid = false
response.Message = "No CrowdSec API key configured. Start CrowdSec to auto-generate one."
if response.Message == "" {
response.Message = "No CrowdSec API key configured. Start CrowdSec to auto-generate one."
}
}
c.JSON(http.StatusOK, response)
@@ -2002,13 +2010,14 @@ func (h *CrowdsecHandler) GetBouncerInfo(c *gin.Context) {
fileKey := readKeyFromFile(bouncerKeyFile)
var fullKey string
if envKey != "" {
switch {
case envKey != "":
info.KeySource = "env_var"
fullKey = envKey
} else if fileKey != "" {
case fileKey != "":
info.KeySource = "file"
fullKey = fileKey
} else {
default:
info.KeySource = "none"
}

View File

@@ -245,10 +245,22 @@ func (h *EmergencyHandler) disableAllSecurityModules() ([]string, error) {
disabledModules = append(disabledModules, key)
}
// Clear admin whitelist to prevent bypass persistence after reset
adminWhitelistSetting := models.Setting{
Key: "security.admin_whitelist",
Value: "",
Category: "security",
Type: "string",
}
if err := h.db.Where(models.Setting{Key: adminWhitelistSetting.Key}).Assign(adminWhitelistSetting).FirstOrCreate(&adminWhitelistSetting).Error; err != nil {
return disabledModules, fmt.Errorf("failed to clear admin whitelist: %w", err)
}
// Also update the SecurityConfig record if it exists
var securityConfig models.SecurityConfig
if err := h.db.Where("name = ?", "default").First(&securityConfig).Error; err == nil {
securityConfig.Enabled = false
securityConfig.AdminWhitelist = ""
securityConfig.WAFMode = "disabled"
securityConfig.RateLimitMode = "disabled"
securityConfig.RateLimitEnable = false

View File

@@ -125,12 +125,19 @@ func TestEmergencySecurityReset_Success(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "disabled", crowdsecMode.Value)
// Verify admin whitelist is cleared
var adminWhitelist models.Setting
err = db.Where("key = ?", "security.admin_whitelist").First(&adminWhitelist).Error
require.NoError(t, err)
assert.Equal(t, "", adminWhitelist.Value)
// Verify SecurityConfig was updated
var updatedConfig models.SecurityConfig
err = db.Where("name = ?", "default").First(&updatedConfig).Error
require.NoError(t, err)
assert.False(t, updatedConfig.Enabled)
assert.Equal(t, "disabled", updatedConfig.WAFMode)
assert.Equal(t, "", updatedConfig.AdminWhitelist)
// Note: Audit logging is async via SecurityService channel, tested separately
}

View File

@@ -601,6 +601,7 @@ func (h *UserHandler) GetUser(c *gin.Context) {
type UpdateUserRequest struct {
Name string `json:"name"`
Email string `json:"email"`
Password *string `json:"password" binding:"omitempty,min=8"`
Role string `json:"role"`
Enabled *bool `json:"enabled"`
}
@@ -653,6 +654,16 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
updates["role"] = req.Role
}
if req.Password != nil {
if err := user.SetPassword(*req.Password); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to hash password"})
return
}
updates["password_hash"] = user.PasswordHash
updates["failed_login_attempts"] = 0
updates["locked_until"] = nil
}
if req.Enabled != nil {
updates["enabled"] = *req.Enabled
}

View File

@@ -754,6 +754,43 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
}
func TestUserHandler_UpdateUser_PasswordReset(t *testing.T) {
handler, db := setupUserHandlerWithProxyHosts(t)
user := &models.User{UUID: uuid.NewString(), Email: "reset@example.com", Name: "Reset User", Role: "user"}
require.NoError(t, user.SetPassword("oldpassword123"))
lockUntil := time.Now().Add(10 * time.Minute)
user.FailedLoginAttempts = 4
user.LockedUntil = &lockUntil
db.Create(user)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("role", "admin")
c.Next()
})
r.PUT("/users/:id", handler.UpdateUser)
body := map[string]any{
"password": "newpassword123",
}
jsonBody, _ := json.Marshal(body)
req := httptest.NewRequest("PUT", "/users/1", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var updated models.User
db.First(&updated, user.ID)
assert.True(t, updated.CheckPassword("newpassword123"))
assert.False(t, updated.CheckPassword("oldpassword123"))
assert.Equal(t, 0, updated.FailedLoginAttempts)
assert.Nil(t, updated.LockedUntil)
}
func TestUserHandler_DeleteUser_NonAdmin(t *testing.T) {
handler, _ := setupUserHandlerWithProxyHosts(t)
gin.SetMode(gin.TestMode)

View File

@@ -130,6 +130,7 @@ func RegisterWithDeps(router *gin.Engine, db *gorm.DB, cfg config.Config, caddyM
// Emergency endpoint
emergencyHandler := handlers.NewEmergencyHandlerWithDeps(db, caddyManager, cerb)
emergency := router.Group("/api/v1/emergency")
// Emergency endpoints must stay responsive and should not be rate limited.
emergency.POST("/security-reset", emergencyHandler.SecurityReset)
// Emergency token management (admin-only, protected by EmergencyBypass middleware)
@@ -146,7 +147,11 @@ func RegisterWithDeps(router *gin.Engine, db *gorm.DB, cfg config.Config, caddyM
authMiddleware := middleware.AuthMiddleware(authService)
api := router.Group("/api/v1")
// Rate Limiting (Emergency/Go-layer) MUST run before Auth to prevent 401 masking 429
api.Use(cerb.RateLimitMiddleware())
api.Use(middleware.OptionalAuth(authService))
// Cerberus middleware (ACL, WAF Stats, CrowdSec Tracking) runs after Auth
// because ACLs need to know if user is authenticated admin to apply whitelist bypass
api.Use(cerb.Middleware())
// Backup routes

View File

@@ -0,0 +1,179 @@
package cerberus
import (
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/Wikid82/charon/backend/internal/util"
)
// rateLimitManager manages per-IP rate limiters.
type rateLimitManager struct {
mu sync.Mutex
limiters map[string]*rate.Limiter
lastSeen map[string]time.Time
}
func newRateLimitManager() *rateLimitManager {
rl := &rateLimitManager{
limiters: make(map[string]*rate.Limiter),
lastSeen: make(map[string]time.Time),
}
// Start cleanup goroutine
go rl.cleanupLoop()
return rl
}
func (rl *rateLimitManager) cleanupLoop() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.cleanup()
}
}
func (rl *rateLimitManager) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
cutoff := time.Now().Add(-10 * time.Minute)
for ip, seen := range rl.lastSeen {
if seen.Before(cutoff) {
delete(rl.limiters, ip)
delete(rl.lastSeen, ip)
}
}
}
func (rl *rateLimitManager) getLimiter(ip string, r rate.Limit, b int) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
lim, exists := rl.limiters[ip]
if !exists {
lim = rate.NewLimiter(r, b)
rl.limiters[ip] = lim
}
rl.lastSeen[ip] = time.Now()
// Check if limit changed (re-config)
if lim.Limit() != r || lim.Burst() != b {
lim = rate.NewLimiter(r, b)
rl.limiters[ip] = lim
}
return lim
}
// NewRateLimitMiddleware creates a new rate limit middleware with fixed parameters.
// Useful for testing or when Cerberus context is not available.
func NewRateLimitMiddleware(requests int, windowSec int, burst int) gin.HandlerFunc {
mgr := newRateLimitManager()
if windowSec <= 0 {
windowSec = 1
}
limit := rate.Limit(float64(requests) / float64(windowSec))
return func(ctx *gin.Context) {
// Check for emergency bypass flag
if bypass, exists := ctx.Get("emergency_bypass"); exists && bypass.(bool) {
ctx.Next()
return
}
clientIP := util.CanonicalizeIPForSecurity(ctx.ClientIP())
limiter := mgr.getLimiter(clientIP, limit, burst)
if !limiter.Allow() {
logger.Log().WithField("ip", clientIP).Warn("Rate limit exceeded (Go middleware)")
ctx.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Too many requests"})
return
}
ctx.Next()
}
}
// RateLimitMiddleware enforces rate limiting based on security config.
func (c *Cerberus) RateLimitMiddleware() gin.HandlerFunc {
mgr := newRateLimitManager()
return func(ctx *gin.Context) {
// Check for emergency bypass flag
if bypass, exists := ctx.Get("emergency_bypass"); exists && bypass.(bool) {
ctx.Next()
return
}
// Check config enabled status
enabled := false
if c.cfg.RateLimitMode == "enabled" {
enabled = true
} else {
// Check dynamic setting
if v, ok := c.getSetting("security.rate_limit.enabled"); ok && strings.EqualFold(v, "true") {
enabled = true
}
}
if !enabled {
ctx.Next()
return
}
// Determine limits
requests := 100 // per window
window := 60 // seconds
burst := 20
if c.cfg.RateLimitRequests > 0 {
requests = c.cfg.RateLimitRequests
}
if c.cfg.RateLimitWindowSec > 0 {
window = c.cfg.RateLimitWindowSec
}
if c.cfg.RateLimitBurst > 0 {
burst = c.cfg.RateLimitBurst
}
// Check for dynamic overrides from settings (Issue #3 fix)
if val, ok := c.getSetting("security.rate_limit.requests"); ok {
if v, err := strconv.Atoi(val); err == nil && v > 0 {
requests = v
}
}
if val, ok := c.getSetting("security.rate_limit.window"); ok {
if v, err := strconv.Atoi(val); err == nil && v > 0 {
window = v
}
}
if val, ok := c.getSetting("security.rate_limit.burst"); ok {
if v, err := strconv.Atoi(val); err == nil && v > 0 {
burst = v
}
}
if window == 0 {
window = 60
}
limit := rate.Limit(float64(requests) / float64(window))
clientIP := util.CanonicalizeIPForSecurity(ctx.ClientIP())
limiter := mgr.getLimiter(clientIP, limit, burst)
if !limiter.Allow() {
logger.Log().WithField("ip", clientIP).Warn("Rate limit exceeded (Go middleware)")
ctx.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Too many requests"})
return
}
ctx.Next()
}
}

View File

@@ -0,0 +1,336 @@
package cerberus
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/Wikid82/charon/backend/internal/config"
"github.com/Wikid82/charon/backend/internal/models"
)
func init() {
gin.SetMode(gin.TestMode)
}
func setupRateLimitTestDB(t *testing.T) *gorm.DB {
t.Helper()
dsn := fmt.Sprintf("file:rate_limit_test_%d?mode=memory&cache=shared", time.Now().UnixNano())
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&models.Setting{}))
return db
}
func TestRateLimitMiddleware(t *testing.T) {
t.Run("Blocks excessive requests", func(t *testing.T) {
// Limit to 5 requests per second, with burst of 5
mw := NewRateLimitMiddleware(5, 1, 5)
r := gin.New()
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Make 5 allowed requests
for i := 0; i < 5; i++ {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Make 6th request (should fail)
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
})
t.Run("Different IPs have separate limits", func(t *testing.T) {
mw := NewRateLimitMiddleware(1, 1, 1)
r := gin.New()
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 1st User
req1, _ := http.NewRequest("GET", "/", nil)
req1.RemoteAddr = "10.0.0.1:1234"
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req1)
assert.Equal(t, http.StatusOK, w1.Code)
// 2nd User (should pass)
req2, _ := http.NewRequest("GET", "/", nil)
req2.RemoteAddr = "10.0.0.2:1234"
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
assert.Equal(t, http.StatusOK, w2.Code)
})
t.Run("Replenishes tokens over time", func(t *testing.T) {
// 1 request per second (burst 1)
mw := NewRateLimitMiddleware(1, 1, 1)
// Manually override the burst/limit for predictable testing isn't easy with wrapper
// So we rely on the implementation using x/time/rate
// Test:
// 1. Consume 1
// 2. Consume 2 (Fail)
// 3. Wait until refill
// 4. Consume 3 (Pass)
r := gin.New()
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
// 1. Consume
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req)
assert.Equal(t, http.StatusOK, w1.Code)
// 2. Consume Fail
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req)
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
// 3. Wait until refill
require.Eventually(t, func() bool {
w3 := httptest.NewRecorder()
r.ServeHTTP(w3, req)
return w3.Code == http.StatusOK
}, 1500*time.Millisecond, 25*time.Millisecond)
})
}
func TestRateLimitManager_ReconfiguresLimiter(t *testing.T) {
mgr := &rateLimitManager{
limiters: make(map[string]*rate.Limiter),
lastSeen: make(map[string]time.Time),
}
limiter := mgr.getLimiter("10.0.0.1", rate.Limit(1), 1)
assert.Equal(t, rate.Limit(1), limiter.Limit())
assert.Equal(t, 1, limiter.Burst())
limiter = mgr.getLimiter("10.0.0.1", rate.Limit(2), 2)
assert.Equal(t, rate.Limit(2), limiter.Limit())
assert.Equal(t, 2, limiter.Burst())
}
func TestRateLimitManager_CleanupRemovesStaleEntries(t *testing.T) {
mgr := &rateLimitManager{
limiters: map[string]*rate.Limiter{
"10.0.0.1": rate.NewLimiter(rate.Limit(1), 1),
},
lastSeen: map[string]time.Time{
"10.0.0.1": time.Now().Add(-11 * time.Minute),
},
}
mgr.cleanup()
assert.Empty(t, mgr.limiters)
assert.Empty(t, mgr.lastSeen)
}
func TestRateLimitMiddleware_EmergencyBypass(t *testing.T) {
mw := NewRateLimitMiddleware(1, 1, 1)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("emergency_bypass", true)
c.Next()
})
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 2; i++ {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
}
func TestCerberusRateLimitMiddleware_DisabledAllowsTraffic(t *testing.T) {
cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, nil)
r := gin.New()
r.Use(cerb.RateLimitMiddleware())
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 3; i++ {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
}
func TestCerberusRateLimitMiddleware_EnabledByConfig(t *testing.T) {
cfg := config.SecurityConfig{
RateLimitMode: "enabled",
RateLimitRequests: 1,
RateLimitWindowSec: 1,
RateLimitBurst: 1,
}
cerb := New(cfg, nil)
r := gin.New()
r.Use(cerb.RateLimitMiddleware())
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
for i := 0; i < 2; i++ {
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if i == 0 {
assert.Equal(t, http.StatusOK, w.Code)
} else {
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}
}
}
func TestCerberusRateLimitMiddleware_EmergencyBypass(t *testing.T) {
cfg := config.SecurityConfig{
RateLimitMode: "enabled",
RateLimitRequests: 1,
RateLimitWindowSec: 1,
RateLimitBurst: 1,
}
cerb := New(cfg, nil)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("emergency_bypass", true)
c.Next()
})
r.Use(cerb.RateLimitMiddleware())
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 2; i++ {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
}
func TestCerberusRateLimitMiddleware_EnabledBySetting(t *testing.T) {
db := setupRateLimitTestDB(t)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error)
cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, db)
r := gin.New()
r.Use(cerb.RateLimitMiddleware())
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req)
assert.Equal(t, http.StatusOK, w1.Code)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req)
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}
func TestCerberusRateLimitMiddleware_OverridesConfigWithSettings(t *testing.T) {
db := setupRateLimitTestDB(t)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error)
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error)
cfg := config.SecurityConfig{
RateLimitMode: "enabled",
RateLimitRequests: 10,
RateLimitWindowSec: 10,
RateLimitBurst: 10,
}
cerb := New(cfg, db)
r := gin.New()
r.Use(cerb.RateLimitMiddleware())
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req)
assert.Equal(t, http.StatusOK, w1.Code)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req)
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}
func TestCerberusRateLimitMiddleware_WindowFallback(t *testing.T) {
cfg := config.SecurityConfig{
RateLimitMode: "enabled",
RateLimitRequests: 1,
RateLimitWindowSec: 0,
RateLimitBurst: 1,
}
cerb := New(cfg, nil)
r := gin.New()
r.Use(cerb.RateLimitMiddleware())
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req)
assert.Equal(t, http.StatusOK, w1.Code)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req)
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
)
@@ -29,14 +30,17 @@ type Config struct {
// SecurityConfig holds configuration for optional security services.
type SecurityConfig struct {
CrowdSecMode string
CrowdSecAPIURL string
CrowdSecAPIKey string
CrowdSecConfigDir string
WAFMode string
RateLimitMode string
ACLMode string
CerberusEnabled bool
CrowdSecMode string
CrowdSecAPIURL string
CrowdSecAPIKey string
CrowdSecConfigDir string
WAFMode string
RateLimitMode string
RateLimitRequests int
RateLimitWindowSec int
RateLimitBurst int
ACLMode string
CerberusEnabled bool
// ManagementCIDRs defines IP ranges allowed to use emergency break glass token
// Default: RFC1918 private networks (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8)
ManagementCIDRs []string
@@ -110,14 +114,17 @@ func Load() (Config, error) {
// loadSecurityConfig loads the security configuration with proper parsing of array fields
func loadSecurityConfig() SecurityConfig {
cfg := SecurityConfig{
CrowdSecMode: getEnvAny("disabled", "CERBERUS_SECURITY_CROWDSEC_MODE", "CHARON_SECURITY_CROWDSEC_MODE", "CPM_SECURITY_CROWDSEC_MODE"),
CrowdSecAPIURL: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_URL", "CHARON_SECURITY_CROWDSEC_API_URL", "CPM_SECURITY_CROWDSEC_API_URL"),
CrowdSecAPIKey: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"),
CrowdSecConfigDir: getEnvAny(filepath.Join("data", "crowdsec"), "CHARON_CROWDSEC_CONFIG_DIR", "CPM_CROWDSEC_CONFIG_DIR"),
WAFMode: getEnvAny("disabled", "CERBERUS_SECURITY_WAF_MODE", "CHARON_SECURITY_WAF_MODE", "CPM_SECURITY_WAF_MODE"),
RateLimitMode: getEnvAny("disabled", "CERBERUS_SECURITY_RATELIMIT_MODE", "CHARON_SECURITY_RATELIMIT_MODE", "CPM_SECURITY_RATELIMIT_MODE"),
ACLMode: getEnvAny("disabled", "CERBERUS_SECURITY_ACL_MODE", "CHARON_SECURITY_ACL_MODE", "CPM_SECURITY_ACL_MODE"),
CerberusEnabled: getEnvAny("true", "CERBERUS_SECURITY_CERBERUS_ENABLED", "CHARON_SECURITY_CERBERUS_ENABLED", "CPM_SECURITY_CERBERUS_ENABLED") != "false",
CrowdSecMode: getEnvAny("disabled", "CERBERUS_SECURITY_CROWDSEC_MODE", "CHARON_SECURITY_CROWDSEC_MODE", "CPM_SECURITY_CROWDSEC_MODE"),
CrowdSecAPIURL: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_URL", "CHARON_SECURITY_CROWDSEC_API_URL", "CPM_SECURITY_CROWDSEC_API_URL"),
CrowdSecAPIKey: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"),
CrowdSecConfigDir: getEnvAny(filepath.Join("data", "crowdsec"), "CHARON_CROWDSEC_CONFIG_DIR", "CPM_CROWDSEC_CONFIG_DIR"),
WAFMode: getEnvAny("disabled", "CERBERUS_SECURITY_WAF_MODE", "CHARON_SECURITY_WAF_MODE", "CPM_SECURITY_WAF_MODE"),
RateLimitMode: getEnvAny("disabled", "CERBERUS_SECURITY_RATELIMIT_MODE", "CHARON_SECURITY_RATELIMIT_MODE", "CPM_SECURITY_RATELIMIT_MODE"),
RateLimitRequests: getEnvIntAny(100, "CERBERUS_SECURITY_RATELIMIT_REQUESTS", "CHARON_SECURITY_RATELIMIT_REQUESTS"),
RateLimitWindowSec: getEnvIntAny(60, "CERBERUS_SECURITY_RATELIMIT_WINDOW", "CHARON_SECURITY_RATELIMIT_WINDOW"),
RateLimitBurst: getEnvIntAny(20, "CERBERUS_SECURITY_RATELIMIT_BURST", "CHARON_SECURITY_RATELIMIT_BURST"),
ACLMode: getEnvAny("disabled", "CERBERUS_SECURITY_ACL_MODE", "CHARON_SECURITY_ACL_MODE", "CPM_SECURITY_ACL_MODE"),
CerberusEnabled: getEnvAny("true", "CERBERUS_SECURITY_CERBERUS_ENABLED", "CHARON_SECURITY_CERBERUS_ENABLED", "CPM_SECURITY_CERBERUS_ENABLED") != "false",
}
// Parse management CIDRs (comma-separated list)
@@ -173,3 +180,16 @@ func getEnvAny(fallback string, keys ...string) string {
}
return fallback
}
// getEnvIntAny checks a list of environment variable names, attempts to parse as int.
// Returns first successfully parsed value. Returns fallback if none found or parsing failed.
func getEnvIntAny(fallback int, keys ...string) int {
valStr := getEnvAny("", keys...)
if valStr == "" {
return fallback
}
if val, err := strconv.Atoi(valStr); err == nil {
return val
}
return fallback
}

View File

@@ -28,6 +28,14 @@ func IsIPInCIDRList(clientIP, cidrList string) bool {
}
if parsed := net.ParseIP(entry); parsed != nil {
// Fix for Issue 1: Canonicalize entry to support mixed IPv4/IPv6 loopback matching
// This ensures that "::1" in the list matches "127.0.0.1" (from canonicalized client IP)
if canonEntry := util.CanonicalizeIPForSecurity(entry); canonEntry != "" {
if p := net.ParseIP(canonEntry); p != nil {
parsed = p
}
}
if ip.Equal(parsed) {
return true
}
@@ -41,6 +49,12 @@ func IsIPInCIDRList(clientIP, cidrList string) bool {
if cidr.Contains(ip) {
return true
}
// Fix for Issue 1: Handle IPv6 loopback CIDR matching against canonicalized IPv4 localhost
// If client is 127.0.0.1 (canonical localhost) and CIDR contains ::1, allow it
if ip.Equal(net.IPv4(127, 0, 0, 1)) && cidr.Contains(net.IPv6loopback) {
return true
}
}
return false

View File

@@ -45,6 +45,18 @@ func TestIsIPInCIDRList(t *testing.T) {
list: "192.168.0.0/16",
expected: false,
},
{
name: "IPv6 loopback match",
ip: "::1",
list: "::1",
expected: true,
},
{
name: "IPv6 loopback CIDR match",
ip: "::1",
list: "::1/128",
expected: true,
},
}
for _, tt := range tests {

View File

@@ -42,8 +42,8 @@ func (m *mockCrowdsecExecutor) Status(ctx context.Context, configDir string) (ru
// mockCommandExecutor is a test mock for CommandExecutor interface
type mockCommandExecutor struct {
executeCalls [][]string // Track command invocations
executeErr error // Error to return
executeOut []byte // Output to return
executeErr error // Error to return
executeOut []byte // Output to return
}
func (m *mockCommandExecutor) Execute(ctx context.Context, name string, args ...string) ([]byte, error) {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/Wikid82/charon/backend/internal/caddy"
@@ -46,12 +47,82 @@ func (s *ProxyHostService) ValidateUniqueDomain(domainNames string, excludeID ui
return nil
}
// ValidateHostname checks if the provided string is a valid hostname or IP address.
func (s *ProxyHostService) ValidateHostname(host string) error {
// Trim protocol if present
if len(host) > 8 && host[:8] == "https://" {
host = host[8:]
} else if len(host) > 7 && host[:7] == "http://" {
host = host[7:]
}
// Remove port if present
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
host = parsedHost
}
// Basic check: is it an IP?
if net.ParseIP(host) != nil {
return nil
}
// Is it a valid hostname/domain?
// Regex for hostname validation (RFC 1123 mostly)
// Simple version: alphanumeric, dots, dashes.
// Allow underscores? Technically usually not in hostnames, but internal docker ones yes.
for _, r := range host {
if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '.' && r != '-' && r != '_' {
// Allow ":" for IPv6 literals if not parsed by ParseIP? ParseIP handles IPv6.
return errors.New("invalid hostname format")
}
}
return nil
}
func (s *ProxyHostService) validateProxyHost(host *models.ProxyHost) error {
if host.ForwardHost == "" {
return errors.New("forward host is required")
}
// Basic hostname/IP validation
target := host.ForwardHost
// Strip protocol if user accidentally typed http://10.0.0.1
target = strings.TrimPrefix(target, "http://")
target = strings.TrimPrefix(target, "https://")
// Strip port if present
if h, _, err := net.SplitHostPort(target); err == nil {
target = h
}
// Validate target
if net.ParseIP(target) == nil {
// Not a valid IP, check hostname rules
// Allow: a-z, 0-9, -, ., _ (for docker service names)
validHostname := true
for _, r := range target {
if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '.' && r != '-' && r != '_' {
validHostname = false
break
}
}
if !validHostname {
return errors.New("forward host must be a valid IP address or hostname")
}
}
return nil
}
// Create validates and creates a new proxy host.
func (s *ProxyHostService) Create(host *models.ProxyHost) error {
if err := s.ValidateUniqueDomain(host.DomainNames, 0); err != nil {
return err
}
if err := s.validateProxyHost(host); err != nil {
return err
}
// Normalize and validate advanced config (if present)
if host.AdvancedConfig != "" {
var parsed any
@@ -75,6 +146,10 @@ func (s *ProxyHostService) Update(host *models.ProxyHost) error {
return err
}
if err := s.validateProxyHost(host); err != nil {
return err
}
// Normalize and validate advanced config (if present)
if host.AdvancedConfig != "" {
var parsed any

View File

@@ -0,0 +1,95 @@
package services
import (
"testing"
"github.com/Wikid82/charon/backend/internal/models"
"github.com/stretchr/testify/assert"
)
func TestProxyHostService_ForwardHostValidation(t *testing.T) {
db := setupProxyHostTestDB(t)
service := NewProxyHostService(db)
tests := []struct {
name string
forwardHost string
wantErr bool
}{
{
name: "Valid IP",
forwardHost: "192.168.1.1",
wantErr: false,
},
{
name: "Valid Hostname",
forwardHost: "example.com",
wantErr: false,
},
{
name: "Docker Service Name",
forwardHost: "my-service",
wantErr: false,
},
{
name: "Docker Service Name with Underscore",
forwardHost: "my_db_Service",
wantErr: false,
},
{
name: "Docker Internal Host",
forwardHost: "host.docker.internal",
wantErr: false,
},
{
name: "IP with Port (Should be stripped and pass)",
forwardHost: "192.168.1.1:8080",
wantErr: false,
},
{
name: "Hostname with Port (Should be stripped and pass)",
forwardHost: "example.com:3000",
wantErr: false,
},
{
name: "Host with http scheme (Should be stripped and pass)",
forwardHost: "http://example.com",
wantErr: false,
},
{
name: "Host with https scheme (Should be stripped and pass)",
forwardHost: "https://example.com",
wantErr: false,
},
{
name: "Invalid Characters",
forwardHost: "invalid$host",
wantErr: true,
},
{
name: "Empty Host",
forwardHost: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
host := &models.ProxyHost{
DomainNames: "test-" + tt.name + ".example.com",
ForwardHost: tt.forwardHost,
ForwardPort: 8080,
}
// We only care about validation error
err := service.Create(host)
if tt.wantErr {
assert.Error(t, err)
} else if err != nil {
// Check if error is validation or something else
// If it's something else, it might be fine for this test context
// but "forward host must be..." is what we look for.
assert.NotContains(t, err.Error(), "forward host", "Should not fail validation")
}
})
}
}