fix: enhance security features
- Updated `crowdsec_handler.go` to log inaccessible paths during config export and handle permission errors gracefully. - Modified `emergency_handler.go` to clear admin whitelist during security reset and ensure proper updates to security configurations. - Enhanced user password update functionality in `user_handler.go` to reset failed login attempts and lockout status. - Introduced rate limiting middleware in `cerberus` to manage request rates and prevent abuse, with comprehensive tests for various scenarios. - Added validation for proxy host entries in `proxyhost_service.go` to ensure valid hostnames and IP addresses, including tests for various cases. - Improved IP matching logic in `whitelist.go` to support both IPv4 and IPv6 loopback addresses. - Updated configuration loading in `config.go` to include rate limiting parameters from environment variables. - Added tests for new functionalities and validations to ensure robustness and reliability.
This commit is contained in:
@@ -9,14 +9,6 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSeedMain_CreatesDatabaseFile(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
@@ -44,42 +36,3 @@ func TestSeedMain_CreatesDatabaseFile(t *testing.T) {
|
||||
t.Fatalf("expected db file to be non-empty")
|
||||
}
|
||||
}
|
||||
package main
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
} } t.Fatalf("expected db file to be non-empty") if info.Size() == 0 { } t.Fatalf("expected db file to exist at %s: %v", dbPath, err) if err != nil { info, err := os.Stat(dbPath) dbPath := filepath.Join("data", "charon.db") main() } t.Fatalf("mkdir data: %v", err) if err := os.MkdirAll("data", 0o755); err != nil { t.Cleanup(func() { _ = os.Chdir(wd) }) } t.Fatalf("chdir: %v", err) if err := os.Chdir(tmp); err != nil { tmp := t.TempDir() } t.Fatalf("getwd: %v", err) if err != nil { wd, err := os.Getwd() t.Parallel()func TestSeedMain_CreatesDatabaseFile(t *testing.T) {) "testing" "path/filepath" "os"
|
||||
|
||||
@@ -754,7 +754,8 @@ func (h *CrowdsecHandler) ExportConfig(c *gin.Context) {
|
||||
// Walk the DataDir and add files to the archive
|
||||
err := filepath.Walk(h.DataDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
logger.Log().WithError(err).Warnf("failed to access path %s during export walk", path)
|
||||
return nil // Skip files we cannot access
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
@@ -798,13 +799,18 @@ func (h *CrowdsecHandler) ExportConfig(c *gin.Context) {
|
||||
|
||||
// ListFiles returns a flat list of files under the CrowdSec DataDir.
|
||||
func (h *CrowdsecHandler) ListFiles(c *gin.Context) {
|
||||
var files []string
|
||||
files := []string{}
|
||||
if _, err := os.Stat(h.DataDir); os.IsNotExist(err) {
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
return
|
||||
}
|
||||
err := filepath.Walk(h.DataDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
// Permission errors (e.g. lost+found) should not abort the walk
|
||||
if os.IsPermission(err) {
|
||||
logger.Log().WithError(err).WithField("path", path).Debug("Skipping inaccessible path during list")
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
@@ -1754,7 +1760,9 @@ func (h *CrowdsecHandler) GetKeyStatus(c *gin.Context) {
|
||||
// No key available
|
||||
response.KeySource = "none"
|
||||
response.Valid = false
|
||||
response.Message = "No CrowdSec API key configured. Start CrowdSec to auto-generate one."
|
||||
if response.Message == "" {
|
||||
response.Message = "No CrowdSec API key configured. Start CrowdSec to auto-generate one."
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
@@ -2002,13 +2010,14 @@ func (h *CrowdsecHandler) GetBouncerInfo(c *gin.Context) {
|
||||
fileKey := readKeyFromFile(bouncerKeyFile)
|
||||
|
||||
var fullKey string
|
||||
if envKey != "" {
|
||||
switch {
|
||||
case envKey != "":
|
||||
info.KeySource = "env_var"
|
||||
fullKey = envKey
|
||||
} else if fileKey != "" {
|
||||
case fileKey != "":
|
||||
info.KeySource = "file"
|
||||
fullKey = fileKey
|
||||
} else {
|
||||
default:
|
||||
info.KeySource = "none"
|
||||
}
|
||||
|
||||
|
||||
@@ -245,10 +245,22 @@ func (h *EmergencyHandler) disableAllSecurityModules() ([]string, error) {
|
||||
disabledModules = append(disabledModules, key)
|
||||
}
|
||||
|
||||
// Clear admin whitelist to prevent bypass persistence after reset
|
||||
adminWhitelistSetting := models.Setting{
|
||||
Key: "security.admin_whitelist",
|
||||
Value: "",
|
||||
Category: "security",
|
||||
Type: "string",
|
||||
}
|
||||
if err := h.db.Where(models.Setting{Key: adminWhitelistSetting.Key}).Assign(adminWhitelistSetting).FirstOrCreate(&adminWhitelistSetting).Error; err != nil {
|
||||
return disabledModules, fmt.Errorf("failed to clear admin whitelist: %w", err)
|
||||
}
|
||||
|
||||
// Also update the SecurityConfig record if it exists
|
||||
var securityConfig models.SecurityConfig
|
||||
if err := h.db.Where("name = ?", "default").First(&securityConfig).Error; err == nil {
|
||||
securityConfig.Enabled = false
|
||||
securityConfig.AdminWhitelist = ""
|
||||
securityConfig.WAFMode = "disabled"
|
||||
securityConfig.RateLimitMode = "disabled"
|
||||
securityConfig.RateLimitEnable = false
|
||||
|
||||
@@ -125,12 +125,19 @@ func TestEmergencySecurityReset_Success(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "disabled", crowdsecMode.Value)
|
||||
|
||||
// Verify admin whitelist is cleared
|
||||
var adminWhitelist models.Setting
|
||||
err = db.Where("key = ?", "security.admin_whitelist").First(&adminWhitelist).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", adminWhitelist.Value)
|
||||
|
||||
// Verify SecurityConfig was updated
|
||||
var updatedConfig models.SecurityConfig
|
||||
err = db.Where("name = ?", "default").First(&updatedConfig).Error
|
||||
require.NoError(t, err)
|
||||
assert.False(t, updatedConfig.Enabled)
|
||||
assert.Equal(t, "disabled", updatedConfig.WAFMode)
|
||||
assert.Equal(t, "", updatedConfig.AdminWhitelist)
|
||||
|
||||
// Note: Audit logging is async via SecurityService channel, tested separately
|
||||
}
|
||||
|
||||
@@ -601,6 +601,7 @@ func (h *UserHandler) GetUser(c *gin.Context) {
|
||||
type UpdateUserRequest struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Password *string `json:"password" binding:"omitempty,min=8"`
|
||||
Role string `json:"role"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
}
|
||||
@@ -653,6 +654,16 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
updates["role"] = req.Role
|
||||
}
|
||||
|
||||
if req.Password != nil {
|
||||
if err := user.SetPassword(*req.Password); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to hash password"})
|
||||
return
|
||||
}
|
||||
updates["password_hash"] = user.PasswordHash
|
||||
updates["failed_login_attempts"] = 0
|
||||
updates["locked_until"] = nil
|
||||
}
|
||||
|
||||
if req.Enabled != nil {
|
||||
updates["enabled"] = *req.Enabled
|
||||
}
|
||||
|
||||
@@ -754,6 +754,43 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdateUser_PasswordReset(t *testing.T) {
|
||||
handler, db := setupUserHandlerWithProxyHosts(t)
|
||||
|
||||
user := &models.User{UUID: uuid.NewString(), Email: "reset@example.com", Name: "Reset User", Role: "user"}
|
||||
require.NoError(t, user.SetPassword("oldpassword123"))
|
||||
lockUntil := time.Now().Add(10 * time.Minute)
|
||||
user.FailedLoginAttempts = 4
|
||||
user.LockedUntil = &lockUntil
|
||||
db.Create(user)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("role", "admin")
|
||||
c.Next()
|
||||
})
|
||||
r.PUT("/users/:id", handler.UpdateUser)
|
||||
|
||||
body := map[string]any{
|
||||
"password": "newpassword123",
|
||||
}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest("PUT", "/users/1", bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var updated models.User
|
||||
db.First(&updated, user.ID)
|
||||
assert.True(t, updated.CheckPassword("newpassword123"))
|
||||
assert.False(t, updated.CheckPassword("oldpassword123"))
|
||||
assert.Equal(t, 0, updated.FailedLoginAttempts)
|
||||
assert.Nil(t, updated.LockedUntil)
|
||||
}
|
||||
|
||||
func TestUserHandler_DeleteUser_NonAdmin(t *testing.T) {
|
||||
handler, _ := setupUserHandlerWithProxyHosts(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -130,6 +130,7 @@ func RegisterWithDeps(router *gin.Engine, db *gorm.DB, cfg config.Config, caddyM
|
||||
// Emergency endpoint
|
||||
emergencyHandler := handlers.NewEmergencyHandlerWithDeps(db, caddyManager, cerb)
|
||||
emergency := router.Group("/api/v1/emergency")
|
||||
// Emergency endpoints must stay responsive and should not be rate limited.
|
||||
emergency.POST("/security-reset", emergencyHandler.SecurityReset)
|
||||
|
||||
// Emergency token management (admin-only, protected by EmergencyBypass middleware)
|
||||
@@ -146,7 +147,11 @@ func RegisterWithDeps(router *gin.Engine, db *gorm.DB, cfg config.Config, caddyM
|
||||
authMiddleware := middleware.AuthMiddleware(authService)
|
||||
|
||||
api := router.Group("/api/v1")
|
||||
// Rate Limiting (Emergency/Go-layer) MUST run before Auth to prevent 401 masking 429
|
||||
api.Use(cerb.RateLimitMiddleware())
|
||||
api.Use(middleware.OptionalAuth(authService))
|
||||
// Cerberus middleware (ACL, WAF Stats, CrowdSec Tracking) runs after Auth
|
||||
// because ACLs need to know if user is authenticated admin to apply whitelist bypass
|
||||
api.Use(cerb.Middleware())
|
||||
|
||||
// Backup routes
|
||||
|
||||
179
backend/internal/cerberus/rate_limit.go
Normal file
179
backend/internal/cerberus/rate_limit.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package cerberus
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/util"
|
||||
)
|
||||
|
||||
// rateLimitManager manages per-IP rate limiters.
|
||||
type rateLimitManager struct {
|
||||
mu sync.Mutex
|
||||
limiters map[string]*rate.Limiter
|
||||
lastSeen map[string]time.Time
|
||||
}
|
||||
|
||||
func newRateLimitManager() *rateLimitManager {
|
||||
rl := &rateLimitManager{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
lastSeen: make(map[string]time.Time),
|
||||
}
|
||||
// Start cleanup goroutine
|
||||
go rl.cleanupLoop()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *rateLimitManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
rl.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimitManager) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
cutoff := time.Now().Add(-10 * time.Minute)
|
||||
for ip, seen := range rl.lastSeen {
|
||||
if seen.Before(cutoff) {
|
||||
delete(rl.limiters, ip)
|
||||
delete(rl.lastSeen, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimitManager) getLimiter(ip string, r rate.Limit, b int) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
lim, exists := rl.limiters[ip]
|
||||
if !exists {
|
||||
lim = rate.NewLimiter(r, b)
|
||||
rl.limiters[ip] = lim
|
||||
}
|
||||
rl.lastSeen[ip] = time.Now()
|
||||
|
||||
// Check if limit changed (re-config)
|
||||
if lim.Limit() != r || lim.Burst() != b {
|
||||
lim = rate.NewLimiter(r, b)
|
||||
rl.limiters[ip] = lim
|
||||
}
|
||||
|
||||
return lim
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware creates a new rate limit middleware with fixed parameters.
|
||||
// Useful for testing or when Cerberus context is not available.
|
||||
func NewRateLimitMiddleware(requests int, windowSec int, burst int) gin.HandlerFunc {
|
||||
mgr := newRateLimitManager()
|
||||
|
||||
if windowSec <= 0 {
|
||||
windowSec = 1
|
||||
}
|
||||
limit := rate.Limit(float64(requests) / float64(windowSec))
|
||||
|
||||
return func(ctx *gin.Context) {
|
||||
// Check for emergency bypass flag
|
||||
if bypass, exists := ctx.Get("emergency_bypass"); exists && bypass.(bool) {
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := util.CanonicalizeIPForSecurity(ctx.ClientIP())
|
||||
limiter := mgr.getLimiter(clientIP, limit, burst)
|
||||
|
||||
if !limiter.Allow() {
|
||||
logger.Log().WithField("ip", clientIP).Warn("Rate limit exceeded (Go middleware)")
|
||||
ctx.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Too many requests"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware enforces rate limiting based on security config.
|
||||
func (c *Cerberus) RateLimitMiddleware() gin.HandlerFunc {
|
||||
mgr := newRateLimitManager()
|
||||
|
||||
return func(ctx *gin.Context) {
|
||||
// Check for emergency bypass flag
|
||||
if bypass, exists := ctx.Get("emergency_bypass"); exists && bypass.(bool) {
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Check config enabled status
|
||||
enabled := false
|
||||
if c.cfg.RateLimitMode == "enabled" {
|
||||
enabled = true
|
||||
} else {
|
||||
// Check dynamic setting
|
||||
if v, ok := c.getSetting("security.rate_limit.enabled"); ok && strings.EqualFold(v, "true") {
|
||||
enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Determine limits
|
||||
requests := 100 // per window
|
||||
window := 60 // seconds
|
||||
burst := 20
|
||||
|
||||
if c.cfg.RateLimitRequests > 0 {
|
||||
requests = c.cfg.RateLimitRequests
|
||||
}
|
||||
if c.cfg.RateLimitWindowSec > 0 {
|
||||
window = c.cfg.RateLimitWindowSec
|
||||
}
|
||||
if c.cfg.RateLimitBurst > 0 {
|
||||
burst = c.cfg.RateLimitBurst
|
||||
}
|
||||
|
||||
// Check for dynamic overrides from settings (Issue #3 fix)
|
||||
if val, ok := c.getSetting("security.rate_limit.requests"); ok {
|
||||
if v, err := strconv.Atoi(val); err == nil && v > 0 {
|
||||
requests = v
|
||||
}
|
||||
}
|
||||
if val, ok := c.getSetting("security.rate_limit.window"); ok {
|
||||
if v, err := strconv.Atoi(val); err == nil && v > 0 {
|
||||
window = v
|
||||
}
|
||||
}
|
||||
if val, ok := c.getSetting("security.rate_limit.burst"); ok {
|
||||
if v, err := strconv.Atoi(val); err == nil && v > 0 {
|
||||
burst = v
|
||||
}
|
||||
}
|
||||
|
||||
if window == 0 {
|
||||
window = 60
|
||||
}
|
||||
limit := rate.Limit(float64(requests) / float64(window))
|
||||
|
||||
clientIP := util.CanonicalizeIPForSecurity(ctx.ClientIP())
|
||||
limiter := mgr.getLimiter(clientIP, limit, burst)
|
||||
|
||||
if !limiter.Allow() {
|
||||
logger.Log().WithField("ip", clientIP).Warn("Rate limit exceeded (Go middleware)")
|
||||
ctx.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Too many requests"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
||||
336
backend/internal/cerberus/rate_limit_test.go
Normal file
336
backend/internal/cerberus/rate_limit_test.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package cerberus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/config"
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func setupRateLimitTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dsn := fmt.Sprintf("file:rate_limit_test_%d?mode=memory&cache=shared", time.Now().UnixNano())
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&models.Setting{}))
|
||||
return db
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware(t *testing.T) {
|
||||
t.Run("Blocks excessive requests", func(t *testing.T) {
|
||||
// Limit to 5 requests per second, with burst of 5
|
||||
mw := NewRateLimitMiddleware(5, 1, 5)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(mw)
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Make 5 allowed requests
|
||||
for i := 0; i < 5; i++ {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Make 6th request (should fail)
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
})
|
||||
|
||||
t.Run("Different IPs have separate limits", func(t *testing.T) {
|
||||
mw := NewRateLimitMiddleware(1, 1, 1)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(mw)
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 1st User
|
||||
req1, _ := http.NewRequest("GET", "/", nil)
|
||||
req1.RemoteAddr = "10.0.0.1:1234"
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
// 2nd User (should pass)
|
||||
req2, _ := http.NewRequest("GET", "/", nil)
|
||||
req2.RemoteAddr = "10.0.0.2:1234"
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
})
|
||||
|
||||
t.Run("Replenishes tokens over time", func(t *testing.T) {
|
||||
// 1 request per second (burst 1)
|
||||
mw := NewRateLimitMiddleware(1, 1, 1)
|
||||
// Manually override the burst/limit for predictable testing isn't easy with wrapper
|
||||
// So we rely on the implementation using x/time/rate
|
||||
// Test:
|
||||
// 1. Consume 1
|
||||
// 2. Consume 2 (Fail)
|
||||
// 3. Wait until refill
|
||||
// 4. Consume 3 (Pass)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(mw)
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "1.2.3.4:1234"
|
||||
|
||||
// 1. Consume
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, req)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
// 2. Consume Fail
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
||||
|
||||
// 3. Wait until refill
|
||||
require.Eventually(t, func() bool {
|
||||
w3 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w3, req)
|
||||
return w3.Code == http.StatusOK
|
||||
}, 1500*time.Millisecond, 25*time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimitManager_ReconfiguresLimiter(t *testing.T) {
|
||||
mgr := &rateLimitManager{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
lastSeen: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
limiter := mgr.getLimiter("10.0.0.1", rate.Limit(1), 1)
|
||||
assert.Equal(t, rate.Limit(1), limiter.Limit())
|
||||
assert.Equal(t, 1, limiter.Burst())
|
||||
|
||||
limiter = mgr.getLimiter("10.0.0.1", rate.Limit(2), 2)
|
||||
assert.Equal(t, rate.Limit(2), limiter.Limit())
|
||||
assert.Equal(t, 2, limiter.Burst())
|
||||
}
|
||||
|
||||
func TestRateLimitManager_CleanupRemovesStaleEntries(t *testing.T) {
|
||||
mgr := &rateLimitManager{
|
||||
limiters: map[string]*rate.Limiter{
|
||||
"10.0.0.1": rate.NewLimiter(rate.Limit(1), 1),
|
||||
},
|
||||
lastSeen: map[string]time.Time{
|
||||
"10.0.0.1": time.Now().Add(-11 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
mgr.cleanup()
|
||||
assert.Empty(t, mgr.limiters)
|
||||
assert.Empty(t, mgr.lastSeen)
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_EmergencyBypass(t *testing.T) {
|
||||
mw := NewRateLimitMiddleware(1, 1, 1)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("emergency_bypass", true)
|
||||
c.Next()
|
||||
})
|
||||
r.Use(mw)
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCerberusRateLimitMiddleware_DisabledAllowsTraffic(t *testing.T) {
|
||||
cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(cerb.RateLimitMiddleware())
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCerberusRateLimitMiddleware_EnabledByConfig(t *testing.T) {
|
||||
cfg := config.SecurityConfig{
|
||||
RateLimitMode: "enabled",
|
||||
RateLimitRequests: 1,
|
||||
RateLimitWindowSec: 1,
|
||||
RateLimitBurst: 1,
|
||||
}
|
||||
cerb := New(cfg, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(cerb.RateLimitMiddleware())
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
for i := 0; i < 2; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if i == 0 {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCerberusRateLimitMiddleware_EmergencyBypass(t *testing.T) {
|
||||
cfg := config.SecurityConfig{
|
||||
RateLimitMode: "enabled",
|
||||
RateLimitRequests: 1,
|
||||
RateLimitWindowSec: 1,
|
||||
RateLimitBurst: 1,
|
||||
}
|
||||
cerb := New(cfg, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("emergency_bypass", true)
|
||||
c.Next()
|
||||
})
|
||||
r.Use(cerb.RateLimitMiddleware())
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCerberusRateLimitMiddleware_EnabledBySetting(t *testing.T) {
|
||||
db := setupRateLimitTestDB(t)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error)
|
||||
|
||||
cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, db)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(cerb.RateLimitMiddleware())
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, req)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
||||
}
|
||||
|
||||
func TestCerberusRateLimitMiddleware_OverridesConfigWithSettings(t *testing.T) {
|
||||
db := setupRateLimitTestDB(t)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error)
|
||||
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error)
|
||||
|
||||
cfg := config.SecurityConfig{
|
||||
RateLimitMode: "enabled",
|
||||
RateLimitRequests: 10,
|
||||
RateLimitWindowSec: 10,
|
||||
RateLimitBurst: 10,
|
||||
}
|
||||
cerb := New(cfg, db)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(cerb.RateLimitMiddleware())
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, req)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
||||
}
|
||||
|
||||
func TestCerberusRateLimitMiddleware_WindowFallback(t *testing.T) {
|
||||
cfg := config.SecurityConfig{
|
||||
RateLimitMode: "enabled",
|
||||
RateLimitRequests: 1,
|
||||
RateLimitWindowSec: 0,
|
||||
RateLimitBurst: 1,
|
||||
}
|
||||
cerb := New(cfg, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(cerb.RateLimitMiddleware())
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, req)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -29,14 +30,17 @@ type Config struct {
|
||||
|
||||
// SecurityConfig holds configuration for optional security services.
|
||||
type SecurityConfig struct {
|
||||
CrowdSecMode string
|
||||
CrowdSecAPIURL string
|
||||
CrowdSecAPIKey string
|
||||
CrowdSecConfigDir string
|
||||
WAFMode string
|
||||
RateLimitMode string
|
||||
ACLMode string
|
||||
CerberusEnabled bool
|
||||
CrowdSecMode string
|
||||
CrowdSecAPIURL string
|
||||
CrowdSecAPIKey string
|
||||
CrowdSecConfigDir string
|
||||
WAFMode string
|
||||
RateLimitMode string
|
||||
RateLimitRequests int
|
||||
RateLimitWindowSec int
|
||||
RateLimitBurst int
|
||||
ACLMode string
|
||||
CerberusEnabled bool
|
||||
// ManagementCIDRs defines IP ranges allowed to use emergency break glass token
|
||||
// Default: RFC1918 private networks (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8)
|
||||
ManagementCIDRs []string
|
||||
@@ -110,14 +114,17 @@ func Load() (Config, error) {
|
||||
// loadSecurityConfig loads the security configuration with proper parsing of array fields
|
||||
func loadSecurityConfig() SecurityConfig {
|
||||
cfg := SecurityConfig{
|
||||
CrowdSecMode: getEnvAny("disabled", "CERBERUS_SECURITY_CROWDSEC_MODE", "CHARON_SECURITY_CROWDSEC_MODE", "CPM_SECURITY_CROWDSEC_MODE"),
|
||||
CrowdSecAPIURL: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_URL", "CHARON_SECURITY_CROWDSEC_API_URL", "CPM_SECURITY_CROWDSEC_API_URL"),
|
||||
CrowdSecAPIKey: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"),
|
||||
CrowdSecConfigDir: getEnvAny(filepath.Join("data", "crowdsec"), "CHARON_CROWDSEC_CONFIG_DIR", "CPM_CROWDSEC_CONFIG_DIR"),
|
||||
WAFMode: getEnvAny("disabled", "CERBERUS_SECURITY_WAF_MODE", "CHARON_SECURITY_WAF_MODE", "CPM_SECURITY_WAF_MODE"),
|
||||
RateLimitMode: getEnvAny("disabled", "CERBERUS_SECURITY_RATELIMIT_MODE", "CHARON_SECURITY_RATELIMIT_MODE", "CPM_SECURITY_RATELIMIT_MODE"),
|
||||
ACLMode: getEnvAny("disabled", "CERBERUS_SECURITY_ACL_MODE", "CHARON_SECURITY_ACL_MODE", "CPM_SECURITY_ACL_MODE"),
|
||||
CerberusEnabled: getEnvAny("true", "CERBERUS_SECURITY_CERBERUS_ENABLED", "CHARON_SECURITY_CERBERUS_ENABLED", "CPM_SECURITY_CERBERUS_ENABLED") != "false",
|
||||
CrowdSecMode: getEnvAny("disabled", "CERBERUS_SECURITY_CROWDSEC_MODE", "CHARON_SECURITY_CROWDSEC_MODE", "CPM_SECURITY_CROWDSEC_MODE"),
|
||||
CrowdSecAPIURL: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_URL", "CHARON_SECURITY_CROWDSEC_API_URL", "CPM_SECURITY_CROWDSEC_API_URL"),
|
||||
CrowdSecAPIKey: getEnvAny("", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"),
|
||||
CrowdSecConfigDir: getEnvAny(filepath.Join("data", "crowdsec"), "CHARON_CROWDSEC_CONFIG_DIR", "CPM_CROWDSEC_CONFIG_DIR"),
|
||||
WAFMode: getEnvAny("disabled", "CERBERUS_SECURITY_WAF_MODE", "CHARON_SECURITY_WAF_MODE", "CPM_SECURITY_WAF_MODE"),
|
||||
RateLimitMode: getEnvAny("disabled", "CERBERUS_SECURITY_RATELIMIT_MODE", "CHARON_SECURITY_RATELIMIT_MODE", "CPM_SECURITY_RATELIMIT_MODE"),
|
||||
RateLimitRequests: getEnvIntAny(100, "CERBERUS_SECURITY_RATELIMIT_REQUESTS", "CHARON_SECURITY_RATELIMIT_REQUESTS"),
|
||||
RateLimitWindowSec: getEnvIntAny(60, "CERBERUS_SECURITY_RATELIMIT_WINDOW", "CHARON_SECURITY_RATELIMIT_WINDOW"),
|
||||
RateLimitBurst: getEnvIntAny(20, "CERBERUS_SECURITY_RATELIMIT_BURST", "CHARON_SECURITY_RATELIMIT_BURST"),
|
||||
ACLMode: getEnvAny("disabled", "CERBERUS_SECURITY_ACL_MODE", "CHARON_SECURITY_ACL_MODE", "CPM_SECURITY_ACL_MODE"),
|
||||
CerberusEnabled: getEnvAny("true", "CERBERUS_SECURITY_CERBERUS_ENABLED", "CHARON_SECURITY_CERBERUS_ENABLED", "CPM_SECURITY_CERBERUS_ENABLED") != "false",
|
||||
}
|
||||
|
||||
// Parse management CIDRs (comma-separated list)
|
||||
@@ -173,3 +180,16 @@ func getEnvAny(fallback string, keys ...string) string {
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// getEnvIntAny checks a list of environment variable names, attempts to parse as int.
|
||||
// Returns first successfully parsed value. Returns fallback if none found or parsing failed.
|
||||
func getEnvIntAny(fallback int, keys ...string) int {
|
||||
valStr := getEnvAny("", keys...)
|
||||
if valStr == "" {
|
||||
return fallback
|
||||
}
|
||||
if val, err := strconv.Atoi(valStr); err == nil {
|
||||
return val
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
@@ -28,6 +28,14 @@ func IsIPInCIDRList(clientIP, cidrList string) bool {
|
||||
}
|
||||
|
||||
if parsed := net.ParseIP(entry); parsed != nil {
|
||||
// Fix for Issue 1: Canonicalize entry to support mixed IPv4/IPv6 loopback matching
|
||||
// This ensures that "::1" in the list matches "127.0.0.1" (from canonicalized client IP)
|
||||
if canonEntry := util.CanonicalizeIPForSecurity(entry); canonEntry != "" {
|
||||
if p := net.ParseIP(canonEntry); p != nil {
|
||||
parsed = p
|
||||
}
|
||||
}
|
||||
|
||||
if ip.Equal(parsed) {
|
||||
return true
|
||||
}
|
||||
@@ -41,6 +49,12 @@ func IsIPInCIDRList(clientIP, cidrList string) bool {
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fix for Issue 1: Handle IPv6 loopback CIDR matching against canonicalized IPv4 localhost
|
||||
// If client is 127.0.0.1 (canonical localhost) and CIDR contains ::1, allow it
|
||||
if ip.Equal(net.IPv4(127, 0, 0, 1)) && cidr.Contains(net.IPv6loopback) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -45,6 +45,18 @@ func TestIsIPInCIDRList(t *testing.T) {
|
||||
list: "192.168.0.0/16",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback match",
|
||||
ip: "::1",
|
||||
list: "::1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback CIDR match",
|
||||
ip: "::1",
|
||||
list: "::1/128",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -42,8 +42,8 @@ func (m *mockCrowdsecExecutor) Status(ctx context.Context, configDir string) (ru
|
||||
// mockCommandExecutor is a test mock for CommandExecutor interface
|
||||
type mockCommandExecutor struct {
|
||||
executeCalls [][]string // Track command invocations
|
||||
executeErr error // Error to return
|
||||
executeOut []byte // Output to return
|
||||
executeErr error // Error to return
|
||||
executeOut []byte // Output to return
|
||||
}
|
||||
|
||||
func (m *mockCommandExecutor) Execute(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/caddy"
|
||||
@@ -46,12 +47,82 @@ func (s *ProxyHostService) ValidateUniqueDomain(domainNames string, excludeID ui
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateHostname checks if the provided string is a valid hostname or IP address.
|
||||
func (s *ProxyHostService) ValidateHostname(host string) error {
|
||||
// Trim protocol if present
|
||||
if len(host) > 8 && host[:8] == "https://" {
|
||||
host = host[8:]
|
||||
} else if len(host) > 7 && host[:7] == "http://" {
|
||||
host = host[7:]
|
||||
}
|
||||
|
||||
// Remove port if present
|
||||
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = parsedHost
|
||||
}
|
||||
|
||||
// Basic check: is it an IP?
|
||||
if net.ParseIP(host) != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Is it a valid hostname/domain?
|
||||
// Regex for hostname validation (RFC 1123 mostly)
|
||||
// Simple version: alphanumeric, dots, dashes.
|
||||
// Allow underscores? Technically usually not in hostnames, but internal docker ones yes.
|
||||
for _, r := range host {
|
||||
if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '.' && r != '-' && r != '_' {
|
||||
// Allow ":" for IPv6 literals if not parsed by ParseIP? ParseIP handles IPv6.
|
||||
return errors.New("invalid hostname format")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyHostService) validateProxyHost(host *models.ProxyHost) error {
|
||||
if host.ForwardHost == "" {
|
||||
return errors.New("forward host is required")
|
||||
}
|
||||
|
||||
// Basic hostname/IP validation
|
||||
target := host.ForwardHost
|
||||
// Strip protocol if user accidentally typed http://10.0.0.1
|
||||
target = strings.TrimPrefix(target, "http://")
|
||||
target = strings.TrimPrefix(target, "https://")
|
||||
// Strip port if present
|
||||
if h, _, err := net.SplitHostPort(target); err == nil {
|
||||
target = h
|
||||
}
|
||||
|
||||
// Validate target
|
||||
if net.ParseIP(target) == nil {
|
||||
// Not a valid IP, check hostname rules
|
||||
// Allow: a-z, 0-9, -, ., _ (for docker service names)
|
||||
validHostname := true
|
||||
for _, r := range target {
|
||||
if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '.' && r != '-' && r != '_' {
|
||||
validHostname = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !validHostname {
|
||||
return errors.New("forward host must be a valid IP address or hostname")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create validates and creates a new proxy host.
|
||||
func (s *ProxyHostService) Create(host *models.ProxyHost) error {
|
||||
if err := s.ValidateUniqueDomain(host.DomainNames, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.validateProxyHost(host); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Normalize and validate advanced config (if present)
|
||||
if host.AdvancedConfig != "" {
|
||||
var parsed any
|
||||
@@ -75,6 +146,10 @@ func (s *ProxyHostService) Update(host *models.ProxyHost) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.validateProxyHost(host); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Normalize and validate advanced config (if present)
|
||||
if host.AdvancedConfig != "" {
|
||||
var parsed any
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProxyHostService_ForwardHostValidation(t *testing.T) {
|
||||
db := setupProxyHostTestDB(t)
|
||||
service := NewProxyHostService(db)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
forwardHost string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid IP",
|
||||
forwardHost: "192.168.1.1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Hostname",
|
||||
forwardHost: "example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Docker Service Name",
|
||||
forwardHost: "my-service",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Docker Service Name with Underscore",
|
||||
forwardHost: "my_db_Service",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Docker Internal Host",
|
||||
forwardHost: "host.docker.internal",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IP with Port (Should be stripped and pass)",
|
||||
forwardHost: "192.168.1.1:8080",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Hostname with Port (Should be stripped and pass)",
|
||||
forwardHost: "example.com:3000",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Host with http scheme (Should be stripped and pass)",
|
||||
forwardHost: "http://example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Host with https scheme (Should be stripped and pass)",
|
||||
forwardHost: "https://example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid Characters",
|
||||
forwardHost: "invalid$host",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Empty Host",
|
||||
forwardHost: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
host := &models.ProxyHost{
|
||||
DomainNames: "test-" + tt.name + ".example.com",
|
||||
ForwardHost: tt.forwardHost,
|
||||
ForwardPort: 8080,
|
||||
}
|
||||
// We only care about validation error
|
||||
err := service.Create(host)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else if err != nil {
|
||||
// Check if error is validation or something else
|
||||
// If it's something else, it might be fine for this test context
|
||||
// but "forward host must be..." is what we look for.
|
||||
assert.NotContains(t, err.Error(), "forward host", "Should not fail validation")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user