565 lines
15 KiB
Go
565 lines
15 KiB
Go
package cerberus
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/time/rate"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/Wikid82/charon/backend/internal/config"
|
|
"github.com/Wikid82/charon/backend/internal/models"
|
|
)
|
|
|
|
func init() {
|
|
gin.SetMode(gin.TestMode)
|
|
}
|
|
|
|
func setupRateLimitTestDB(t *testing.T) *gorm.DB {
|
|
t.Helper()
|
|
dsn := fmt.Sprintf("file:rate_limit_test_%d?mode=memory&cache=shared", time.Now().UnixNano())
|
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
|
require.NoError(t, err)
|
|
require.NoError(t, db.AutoMigrate(&models.Setting{}))
|
|
return db
|
|
}
|
|
|
|
func TestRateLimitMiddleware(t *testing.T) {
|
|
t.Run("Blocks excessive requests", func(t *testing.T) {
|
|
// Limit to 5 requests per second, with burst of 5
|
|
mw := NewRateLimitMiddleware(5, 1, 5)
|
|
|
|
r := gin.New()
|
|
r.Use(mw)
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
// Make 5 allowed requests
|
|
for i := 0; i < 5; i++ {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
// Make 6th request (should fail)
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
|
})
|
|
|
|
t.Run("Different IPs have separate limits", func(t *testing.T) {
|
|
mw := NewRateLimitMiddleware(1, 1, 1)
|
|
|
|
r := gin.New()
|
|
r.Use(mw)
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
// 1st User
|
|
req1, _ := http.NewRequest("GET", "/", nil)
|
|
req1.RemoteAddr = "10.0.0.1:1234"
|
|
w1 := httptest.NewRecorder()
|
|
r.ServeHTTP(w1, req1)
|
|
assert.Equal(t, http.StatusOK, w1.Code)
|
|
|
|
// 2nd User (should pass)
|
|
req2, _ := http.NewRequest("GET", "/", nil)
|
|
req2.RemoteAddr = "10.0.0.2:1234"
|
|
w2 := httptest.NewRecorder()
|
|
r.ServeHTTP(w2, req2)
|
|
assert.Equal(t, http.StatusOK, w2.Code)
|
|
})
|
|
|
|
t.Run("Replenishes tokens over time", func(t *testing.T) {
|
|
// 1 request per second (burst 1)
|
|
mw := NewRateLimitMiddleware(1, 1, 1)
|
|
// Manually override the burst/limit for predictable testing isn't easy with wrapper
|
|
// So we rely on the implementation using x/time/rate
|
|
// Test:
|
|
// 1. Consume 1
|
|
// 2. Consume 2 (Fail)
|
|
// 3. Wait until refill
|
|
// 4. Consume 3 (Pass)
|
|
|
|
r := gin.New()
|
|
r.Use(mw)
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "1.2.3.4:1234"
|
|
|
|
// 1. Consume
|
|
w1 := httptest.NewRecorder()
|
|
r.ServeHTTP(w1, req)
|
|
assert.Equal(t, http.StatusOK, w1.Code)
|
|
|
|
// 2. Consume Fail
|
|
w2 := httptest.NewRecorder()
|
|
r.ServeHTTP(w2, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
|
|
|
// 3. Wait until refill
|
|
require.Eventually(t, func() bool {
|
|
w3 := httptest.NewRecorder()
|
|
r.ServeHTTP(w3, req)
|
|
return w3.Code == http.StatusOK
|
|
}, 1500*time.Millisecond, 25*time.Millisecond)
|
|
})
|
|
}
|
|
|
|
func TestRateLimitManager_ReconfiguresLimiter(t *testing.T) {
|
|
mgr := &rateLimitManager{
|
|
limiters: make(map[string]*rate.Limiter),
|
|
lastSeen: make(map[string]time.Time),
|
|
}
|
|
|
|
limiter := mgr.getLimiter("10.0.0.1", rate.Limit(1), 1)
|
|
assert.Equal(t, rate.Limit(1), limiter.Limit())
|
|
assert.Equal(t, 1, limiter.Burst())
|
|
|
|
limiter = mgr.getLimiter("10.0.0.1", rate.Limit(2), 2)
|
|
assert.Equal(t, rate.Limit(2), limiter.Limit())
|
|
assert.Equal(t, 2, limiter.Burst())
|
|
}
|
|
|
|
func TestRateLimitManager_CleanupRemovesStaleEntries(t *testing.T) {
|
|
mgr := &rateLimitManager{
|
|
limiters: map[string]*rate.Limiter{
|
|
"10.0.0.1": rate.NewLimiter(rate.Limit(1), 1),
|
|
},
|
|
lastSeen: map[string]time.Time{
|
|
"10.0.0.1": time.Now().Add(-11 * time.Minute),
|
|
},
|
|
}
|
|
|
|
mgr.cleanup()
|
|
assert.Empty(t, mgr.limiters)
|
|
assert.Empty(t, mgr.lastSeen)
|
|
}
|
|
|
|
func TestRateLimitMiddleware_EmergencyBypass(t *testing.T) {
|
|
mw := NewRateLimitMiddleware(1, 1, 1)
|
|
|
|
r := gin.New()
|
|
r.Use(func(c *gin.Context) {
|
|
c.Set("emergency_bypass", true)
|
|
c.Next()
|
|
})
|
|
r.Use(mw)
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 2; i++ {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_DisabledAllowsTraffic(t *testing.T) {
|
|
cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 3; i++ {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_EnabledByConfig(t *testing.T) {
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 1,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
for i := 0; i < 2; i++ {
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
if i == 0 {
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
} else {
|
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_EmergencyBypass(t *testing.T) {
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 1,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(func(c *gin.Context) {
|
|
c.Set("emergency_bypass", true)
|
|
c.Next()
|
|
})
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 2; i++ {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_EnabledBySetting(t *testing.T) {
|
|
db := setupRateLimitTestDB(t)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error)
|
|
|
|
cerb := New(config.SecurityConfig{RateLimitMode: "disabled"}, db)
|
|
|
|
r := gin.New()
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
|
|
w1 := httptest.NewRecorder()
|
|
r.ServeHTTP(w1, req)
|
|
assert.Equal(t, http.StatusOK, w1.Code)
|
|
|
|
w2 := httptest.NewRecorder()
|
|
r.ServeHTTP(w2, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_OverridesConfigWithSettings(t *testing.T) {
|
|
db := setupRateLimitTestDB(t)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "true"}).Error)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.requests", Value: "1"}).Error)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.window", Value: "1"}).Error)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.burst", Value: "1"}).Error)
|
|
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 10,
|
|
RateLimitWindowSec: 10,
|
|
RateLimitBurst: 10,
|
|
}
|
|
cerb := New(cfg, db)
|
|
|
|
r := gin.New()
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
|
|
w1 := httptest.NewRecorder()
|
|
r.ServeHTTP(w1, req)
|
|
assert.Equal(t, http.StatusOK, w1.Code)
|
|
|
|
w2 := httptest.NewRecorder()
|
|
r.ServeHTTP(w2, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_SettingsDisableOverride(t *testing.T) {
|
|
db := setupRateLimitTestDB(t)
|
|
require.NoError(t, db.Create(&models.Setting{Key: "security.rate_limit.enabled", Value: "false"}).Error)
|
|
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 60,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, db)
|
|
|
|
r := gin.New()
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
|
|
for i := 0; i < 3; i++ {
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_WindowFallback(t *testing.T) {
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 0,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
|
|
w1 := httptest.NewRecorder()
|
|
r.ServeHTTP(w1, req)
|
|
assert.Equal(t, http.StatusOK, w1.Code)
|
|
|
|
w2 := httptest.NewRecorder()
|
|
r.ServeHTTP(w2, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_AdminSecurityControlPlaneBypass(t *testing.T) {
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 60,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(func(c *gin.Context) {
|
|
c.Set("role", "admin")
|
|
c.Set("userID", uint(1))
|
|
c.Next()
|
|
})
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/api/v1/security/status", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 3; i++ {
|
|
req, _ := http.NewRequest("GET", "/api/v1/security/status", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestIsAdminSecurityControlPlaneRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
t.Run("admin role bypasses control plane", func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
ctx, _ := gin.CreateTestContext(rec)
|
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/security/rules", http.NoBody)
|
|
ctx.Set("role", "admin")
|
|
assert.True(t, isAdminSecurityControlPlaneRequest(ctx))
|
|
})
|
|
|
|
t.Run("bearer token bypasses control plane", func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
ctx, _ := gin.CreateTestContext(rec)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/settings", http.NoBody)
|
|
req.Header.Set("Authorization", "Bearer token")
|
|
ctx.Request = req
|
|
assert.True(t, isAdminSecurityControlPlaneRequest(ctx))
|
|
})
|
|
|
|
t.Run("non control plane path is not bypassed", func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
ctx, _ := gin.CreateTestContext(rec)
|
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/proxy-hosts", http.NoBody)
|
|
ctx.Set("role", "admin")
|
|
assert.False(t, isAdminSecurityControlPlaneRequest(ctx))
|
|
})
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_AdminSettingsBypass(t *testing.T) {
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 60,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(func(c *gin.Context) {
|
|
c.Set("role", "admin")
|
|
c.Set("userID", uint(1))
|
|
c.Next()
|
|
})
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.POST("/api/v1/settings", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 3; i++ {
|
|
req, _ := http.NewRequest("POST", "/api/v1/settings", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_ControlPlaneBypassWithBearerWithoutRoleContext(t *testing.T) {
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 60,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.POST("/api/v1/settings", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 3; i++ {
|
|
req, _ := http.NewRequest("POST", "/api/v1/settings", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
req.Header.Set("Authorization", "Bearer test-token")
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCerberusRateLimitMiddleware_AdminNonSecurityPathStillLimited(t *testing.T) {
|
|
cfg := config.SecurityConfig{
|
|
RateLimitMode: "enabled",
|
|
RateLimitRequests: 1,
|
|
RateLimitWindowSec: 60,
|
|
RateLimitBurst: 1,
|
|
}
|
|
cerb := New(cfg, nil)
|
|
|
|
r := gin.New()
|
|
r.Use(func(c *gin.Context) {
|
|
c.Set("role", "admin")
|
|
c.Set("userID", uint(1))
|
|
c.Next()
|
|
})
|
|
r.Use(cerb.RateLimitMiddleware())
|
|
r.GET("/api/v1/users", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/api/v1/users", nil)
|
|
req.RemoteAddr = "10.0.0.1:1234"
|
|
|
|
w1 := httptest.NewRecorder()
|
|
r.ServeHTTP(w1, req)
|
|
assert.Equal(t, http.StatusOK, w1.Code)
|
|
|
|
w2 := httptest.NewRecorder()
|
|
r.ServeHTTP(w2, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
|
}
|
|
|
|
func TestIsAdminSecurityControlPlaneRequest_UsesDecodedRawPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
recorder := httptest.NewRecorder()
|
|
ctx, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/security%2Frules", http.NoBody)
|
|
req.URL.Path = "/api/v1/security%2Frules"
|
|
req.URL.RawPath = "/api/v1/security%2Frules"
|
|
req.Header.Set("Authorization", "Bearer token")
|
|
ctx.Request = req
|
|
|
|
assert.True(t, isAdminSecurityControlPlaneRequest(ctx))
|
|
}
|
|
|
|
func TestNewRateLimitMiddleware_UsesWindowFallbackWhenNonPositive(t *testing.T) {
|
|
mw := NewRateLimitMiddleware(1, 0, 1)
|
|
|
|
r := gin.New()
|
|
r.Use(mw)
|
|
r.GET("/", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "10.10.10.10:1234"
|
|
|
|
w1 := httptest.NewRecorder()
|
|
r.ServeHTTP(w1, req)
|
|
assert.Equal(t, http.StatusOK, w1.Code)
|
|
|
|
w2 := httptest.NewRecorder()
|
|
r.ServeHTTP(w2, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, w2.Code)
|
|
}
|
|
|
|
func TestNewRateLimitMiddleware_BypassesControlPlaneBearerRequests(t *testing.T) {
|
|
mw := NewRateLimitMiddleware(1, 1, 1)
|
|
|
|
r := gin.New()
|
|
r.Use(mw)
|
|
r.GET("/api/v1/settings", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
for i := 0; i < 3; i++ {
|
|
req, _ := http.NewRequest(http.MethodGet, "/api/v1/settings", nil)
|
|
req.RemoteAddr = "10.10.10.11:1234"
|
|
req.Header.Set("Authorization", "Bearer admin-token")
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
}
|