Files
Charon/backend/internal/cerberus/rate_limit_test.go
2026-03-04 18:34:49 +00:00

565 lines
15 KiB
Go

package cerberus
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/Wikid82/charon/backend/internal/config"
"github.com/Wikid82/charon/backend/internal/models"
)
// init switches gin into test mode once for every test compiled into this
// package's test binary, so individual tests do not need to set it themselves.
func init() {
	gin.SetMode(gin.TestMode)
}
// setupRateLimitTestDB opens a fresh in-memory SQLite database, migrates the
// models.Setting table into it, and returns the gorm handle.
//
// The database is a named shared-cache in-memory DB (name derived from the
// current time). Such a database lives as long as any connection to it is
// open, so the connection pool is closed via t.Cleanup when the test ends;
// otherwise every call would keep its database alive for the rest of the test
// binary.
func setupRateLimitTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	dsn := fmt.Sprintf("file:rate_limit_test_%d?mode=memory&cache=shared", time.Now().UnixNano())
	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
	require.NoError(t, err)

	sqlDB, err := db.DB()
	require.NoError(t, err)
	// Registered before migrating so the pool is released even if migration fails.
	t.Cleanup(func() {
		// Best-effort teardown: a close failure cannot change the test outcome.
		_ = sqlDB.Close()
	})

	require.NoError(t, db.AutoMigrate(&models.Setting{}))
	return db
}
func TestRateLimitMiddleware(t *testing.T) {
t.Run("Blocks excessive requests", func(t *testing.T) {
// Limit to 5 requests per second, with burst of 5
mw := NewRateLimitMiddleware(5, 1, 5)
r := gin.New()
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Make 5 allowed requests
for i := 0; i < 5; i++ {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
// Make 6th request (should fail)
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusTooManyRequests, w.Code)
})
t.Run("Different IPs have separate limits", func(t *testing.T) {
mw := NewRateLimitMiddleware(1, 1, 1)
r := gin.New()
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 1st User
req1, _ := http.NewRequest("GET", "/", nil)
req1.RemoteAddr = "10.0.0.1:1234"
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req1)
assert.Equal(t, http.StatusOK, w1.Code)
// 2nd User (should pass)
req2, _ := http.NewRequest("GET", "/", nil)
req2.RemoteAddr = "10.0.0.2:1234"
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
assert.Equal(t, http.StatusOK, w2.Code)
})
t.Run("Replenishes tokens over time", func(t *testing.T) {
// 1 request per second (burst 1)
mw := NewRateLimitMiddleware(1, 1, 1)
// Manually override the burst/limit for predictable testing isn't easy with wrapper
// So we rely on the implementation using x/time/rate
// Test:
// 1. Consume 1
// 2. Consume 2 (Fail)
// 3. Wait until refill
// 4. Consume 3 (Pass)
r := gin.New()
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "1.2.3.4:1234"
// 1. Consume
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req)
assert.Equal(t, http.StatusOK, w1.Code)
// 2. Consume Fail
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req)
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
// 3. Wait until refill
require.Eventually(t, func() bool {
w3 := httptest.NewRecorder()
r.ServeHTTP(w3, req)
return w3.Code == http.StatusOK
}, 1500*time.Millisecond, 25*time.Millisecond)
})
}
// TestRateLimitManager_ReconfiguresLimiter asks the manager twice for the same
// client key with different parameters and expects the second result to carry
// the new limit and burst.
func TestRateLimitManager_ReconfiguresLimiter(t *testing.T) {
	const clientIP = "10.0.0.1"

	manager := &rateLimitManager{
		limiters: make(map[string]*rate.Limiter),
		lastSeen: make(map[string]time.Time),
	}

	initial := manager.getLimiter(clientIP, rate.Limit(1), 1)
	assert.Equal(t, rate.Limit(1), initial.Limit())
	assert.Equal(t, 1, initial.Burst())

	updated := manager.getLimiter(clientIP, rate.Limit(2), 2)
	assert.Equal(t, rate.Limit(2), updated.Limit())
	assert.Equal(t, 2, updated.Burst())
}
// TestRateLimitManager_CleanupRemovesStaleEntries seeds a single client that
// was last seen 11 minutes ago and expects cleanup to evict it from both the
// limiter map and the last-seen map.
func TestRateLimitManager_CleanupRemovesStaleEntries(t *testing.T) {
	const staleIP = "10.0.0.1"
	lastActivity := time.Now().Add(-11 * time.Minute)

	manager := &rateLimitManager{
		limiters: map[string]*rate.Limiter{staleIP: rate.NewLimiter(rate.Limit(1), 1)},
		lastSeen: map[string]time.Time{staleIP: lastActivity},
	}

	manager.cleanup()

	assert.Empty(t, manager.limiters, "stale limiter should be evicted")
	assert.Empty(t, manager.lastSeen, "stale last-seen entry should be evicted")
}
func TestRateLimitMiddleware_EmergencyBypass(t *testing.T) {
mw := NewRateLimitMiddleware(1, 1, 1)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("emergency_bypass", true)
c.Next()
})
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 2; i++ {
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.1:1234"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
}
// TestCerberusRateLimitMiddleware_DisabledAllowsTraffic checks that with
// RateLimitMode "disabled" and no settings database, repeated requests from a
// single client are never rejected.
func TestCerberusRateLimitMiddleware_DisabledAllowsTraffic(t *testing.T) {
	cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, nil)

	r := gin.New()
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for i := 0; i < 3; i++ {
		req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
		req.RemoteAddr = "10.0.0.1:1234"
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "request %d", i+1)
	}
}
// TestCerberusRateLimitMiddleware_EnabledByConfig checks that the static
// config alone (no settings database) enables limiting: with a burst of 1, the
// first request passes and the second from the same client is rejected.
func TestCerberusRateLimitMiddleware_EnabledByConfig(t *testing.T) {
	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 1,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, nil)

	r := gin.New()
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
	req.RemoteAddr = "10.0.0.1:1234"

	w1 := httptest.NewRecorder()
	r.ServeHTTP(w1, req)
	assert.Equal(t, http.StatusOK, w1.Code, "first request fits in the burst")

	w2 := httptest.NewRecorder()
	r.ServeHTTP(w2, req)
	assert.Equal(t, http.StatusTooManyRequests, w2.Code, "second request exceeds the burst")
}
// TestCerberusRateLimitMiddleware_EmergencyBypass verifies that the Cerberus
// rate-limit middleware honours the "emergency_bypass" context key even when
// the config enables limiting with a burst of 1.
func TestCerberusRateLimitMiddleware_EmergencyBypass(t *testing.T) {
	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 1,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, nil)

	r := gin.New()
	// Simulates an upstream middleware that has granted an emergency bypass.
	r.Use(func(c *gin.Context) {
		c.Set("emergency_bypass", true)
		c.Next()
	})
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for i := 0; i < 2; i++ {
		req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
		req.RemoteAddr = "10.0.0.1:1234"
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "request %d", i+1)
	}
}
// TestCerberusRateLimitMiddleware_EnabledBySetting checks that database
// settings can turn rate limiting on even though the static config says
// "disabled": with the stored burst of 1, the second request is rejected.
func TestCerberusRateLimitMiddleware_EnabledBySetting(t *testing.T) {
	db := setupRateLimitTestDB(t)
	settings := []models.Setting{
		{Key: "security.rate_limit.enabled", Value: "true"},
		{Key: "security.rate_limit.requests", Value: "1"},
		{Key: "security.rate_limit.window", Value: "1"},
		{Key: "security.rate_limit.burst", Value: "1"},
	}
	for i := range settings {
		// Index into the slice so each row is inserted from its own element.
		require.NoError(t, db.Create(&settings[i]).Error)
	}

	cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, db)

	r := gin.New()
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
	req.RemoteAddr = "10.0.0.1:1234"

	w1 := httptest.NewRecorder()
	r.ServeHTTP(w1, req)
	assert.Equal(t, http.StatusOK, w1.Code)

	w2 := httptest.NewRecorder()
	r.ServeHTTP(w2, req)
	assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}
// TestCerberusRateLimitMiddleware_OverridesConfigWithSettings checks that
// stored settings (burst 1) take precedence over a more permissive static
// config (burst 10): the second request from a client is already rejected.
func TestCerberusRateLimitMiddleware_OverridesConfigWithSettings(t *testing.T) {
	db := setupRateLimitTestDB(t)
	settings := []models.Setting{
		{Key: "security.rate_limit.enabled", Value: "true"},
		{Key: "security.rate_limit.requests", Value: "1"},
		{Key: "security.rate_limit.window", Value: "1"},
		{Key: "security.rate_limit.burst", Value: "1"},
	}
	for i := range settings {
		require.NoError(t, db.Create(&settings[i]).Error)
	}

	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  10,
		RateLimitWindowSec: 10,
		RateLimitBurst:     10,
	}
	cerb := New(cfg, db)

	r := gin.New()
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
	req.RemoteAddr = "10.0.0.1:1234"

	w1 := httptest.NewRecorder()
	r.ServeHTTP(w1, req)
	assert.Equal(t, http.StatusOK, w1.Code)

	// Would still be 200 if the config's burst of 10 were in effect.
	w2 := httptest.NewRecorder()
	r.ServeHTTP(w2, req)
	assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}
// TestCerberusRateLimitMiddleware_SettingsDisableOverride checks that a stored
// "security.rate_limit.enabled" = "false" setting turns limiting off even when
// the static config enables it with a burst of 1.
func TestCerberusRateLimitMiddleware_SettingsDisableOverride(t *testing.T) {
	db := setupRateLimitTestDB(t)
	require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "false"}).Error)

	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 60,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, db)

	r := gin.New()
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
	req.RemoteAddr = "10.0.0.1:1234"

	for i := 0; i < 3; i++ {
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "request %d", i+1)
	}
}
// TestCerberusRateLimitMiddleware_WindowFallback checks that a zero
// RateLimitWindowSec does not disable limiting: the configured burst of 1 is
// still enforced, so the second request is rejected.
func TestCerberusRateLimitMiddleware_WindowFallback(t *testing.T) {
	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 0,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, nil)

	r := gin.New()
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
	req.RemoteAddr = "10.0.0.1:1234"

	w1 := httptest.NewRecorder()
	r.ServeHTTP(w1, req)
	assert.Equal(t, http.StatusOK, w1.Code)

	w2 := httptest.NewRecorder()
	r.ServeHTTP(w2, req)
	assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}
// TestCerberusRateLimitMiddleware_AdminSecurityControlPlaneBypass checks that
// an admin-role context hitting /api/v1/security/status is not limited, even
// though the config allows only a burst of 1.
func TestCerberusRateLimitMiddleware_AdminSecurityControlPlaneBypass(t *testing.T) {
	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 60,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, nil)

	r := gin.New()
	// Simulates an auth middleware that has identified an admin user.
	r.Use(func(c *gin.Context) {
		c.Set("role", "admin")
		c.Set("userID", uint(1))
		c.Next()
	})
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/api/v1/security/status", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for i := 0; i < 3; i++ {
		req := httptest.NewRequest(http.MethodGet, "/api/v1/security/status", http.NoBody)
		req.RemoteAddr = "10.0.0.1:1234"
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "request %d", i+1)
	}
}
// TestIsAdminSecurityControlPlaneRequest checks the bypass predicate directly:
//   - an admin role on /api/v1/security/rules qualifies,
//   - a bearer Authorization header on /api/v1/settings qualifies,
//   - an admin role on /api/v1/proxy-hosts does not.
func TestIsAdminSecurityControlPlaneRequest(t *testing.T) {
	t.Parallel()
	// No gin.SetMode call here: init() already selects test mode, and setting
	// package-global gin state from a parallel test is racy and redundant.

	cases := []struct {
		name   string
		path   string
		role   string // stored under the "role" context key when non-empty
		bearer string // sent as "Authorization: Bearer <bearer>" when non-empty
		want   bool
	}{
		{name: "admin role bypasses control plane", path: "/api/v1/security/rules", role: "admin", want: true},
		{name: "bearer token bypasses control plane", path: "/api/v1/settings", bearer: "token", want: true},
		{name: "non control plane path is not bypassed", path: "/api/v1/proxy-hosts", role: "admin", want: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, _ := gin.CreateTestContext(httptest.NewRecorder())

			req := httptest.NewRequest(http.MethodGet, tc.path, http.NoBody)
			if tc.bearer != "" {
				req.Header.Set("Authorization", "Bearer "+tc.bearer)
			}
			ctx.Request = req

			if tc.role != "" {
				ctx.Set("role", tc.role)
			}

			assert.Equal(t, tc.want, isAdminSecurityControlPlaneRequest(ctx))
		})
	}
}
// TestCerberusRateLimitMiddleware_AdminSettingsBypass checks that an
// admin-role context POSTing to /api/v1/settings is not limited, even though
// the config allows only a burst of 1.
func TestCerberusRateLimitMiddleware_AdminSettingsBypass(t *testing.T) {
	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 60,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, nil)

	r := gin.New()
	// Simulates an auth middleware that has identified an admin user.
	r.Use(func(c *gin.Context) {
		c.Set("role", "admin")
		c.Set("userID", uint(1))
		c.Next()
	})
	r.Use(cerb.RateLimitMiddleware())
	r.POST("/api/v1/settings", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for i := 0; i < 3; i++ {
		req := httptest.NewRequest(http.MethodPost, "/api/v1/settings", http.NoBody)
		req.RemoteAddr = "10.0.0.1:1234"
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "request %d", i+1)
	}
}
// TestCerberusRateLimitMiddleware_ControlPlaneBypassWithBearerWithoutRoleContext
// checks that a bearer Authorization header alone (no role set on the gin
// context) is enough for /api/v1/settings to bypass the burst-of-1 limit.
func TestCerberusRateLimitMiddleware_ControlPlaneBypassWithBearerWithoutRoleContext(t *testing.T) {
	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 60,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, nil)

	r := gin.New()
	r.Use(cerb.RateLimitMiddleware())
	r.POST("/api/v1/settings", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for i := 0; i < 3; i++ {
		req := httptest.NewRequest(http.MethodPost, "/api/v1/settings", http.NoBody)
		req.RemoteAddr = "10.0.0.1:1234"
		req.Header.Set("Authorization", "Bearer test-token")
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		assert.Equal(t, http.StatusOK, w.Code, "request %d", i+1)
	}
}
// TestCerberusRateLimitMiddleware_AdminNonSecurityPathStillLimited checks
// that the admin bypass is path-scoped: an admin hitting /api/v1/users is
// still limited to the configured burst of 1.
func TestCerberusRateLimitMiddleware_AdminNonSecurityPathStillLimited(t *testing.T) {
	cfg := config.SecurityConfig{
		RateLimitMode:      "enabled",
		RateLimitRequests:  1,
		RateLimitWindowSec: 60,
		RateLimitBurst:     1,
	}
	cerb := New(cfg, nil)

	r := gin.New()
	// Simulates an auth middleware that has identified an admin user.
	r.Use(func(c *gin.Context) {
		c.Set("role", "admin")
		c.Set("userID", uint(1))
		c.Next()
	})
	r.Use(cerb.RateLimitMiddleware())
	r.GET("/api/v1/users", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/api/v1/users", http.NoBody)
	req.RemoteAddr = "10.0.0.1:1234"

	w1 := httptest.NewRecorder()
	r.ServeHTTP(w1, req)
	assert.Equal(t, http.StatusOK, w1.Code)

	w2 := httptest.NewRecorder()
	r.ServeHTTP(w2, req)
	assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}
// TestIsAdminSecurityControlPlaneRequest_UsesDecodedRawPath checks that a
// bearer request whose URL.Path and URL.RawPath both hold the still-encoded
// "/api/v1/security%2Frules" is nevertheless recognised as a control-plane
// request.
func TestIsAdminSecurityControlPlaneRequest_UsesDecodedRawPath(t *testing.T) {
	t.Parallel()

	const encodedPath = "/api/v1/security%2Frules"

	req := httptest.NewRequest(http.MethodGet, encodedPath, http.NoBody)
	// Force both URL fields to the encoded form so the predicate only ever sees
	// "%2F". NOTE(review): presumably it must percent-decode the path itself — confirm.
	req.URL.Path = encodedPath
	req.URL.RawPath = encodedPath
	req.Header.Set("Authorization", "Bearer token")

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = req

	assert.True(t, isAdminSecurityControlPlaneRequest(ctx))
}
func TestNewRateLimitMiddleware_UsesWindowFallbackWhenNonPositive(t *testing.T) {
mw := NewRateLimitMiddleware(1, 0, 1)
r := gin.New()
r.Use(mw)
r.GET("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.10.10.10:1234"
w1 := httptest.NewRecorder()
r.ServeHTTP(w1, req)
assert.Equal(t, http.StatusOK, w1.Code)
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req)
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
}
func TestNewRateLimitMiddleware_BypassesControlPlaneBearerRequests(t *testing.T) {
mw := NewRateLimitMiddleware(1, 1, 1)
r := gin.New()
r.Use(mw)
r.GET("/api/v1/settings", func(c *gin.Context) {
c.Status(http.StatusOK)
})
for i := 0; i < 3; i++ {
req, _ := http.NewRequest(http.MethodGet, "/api/v1/settings", nil)
req.RemoteAddr = "10.10.10.11:1234"
req.Header.Set("Authorization", "Bearer admin-token")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
}