chore: git cache cleanup

This commit is contained in:
GitHub Actions
2026-03-04 18:34:49 +00:00
parent c32cce2a88
commit 27c252600a
2001 changed files with 683185 additions and 0 deletions

View File

@@ -0,0 +1,125 @@
package middleware
import (
"net/http"
"strings"
"github.com/Wikid82/charon/backend/internal/models"
"github.com/Wikid82/charon/backend/internal/services"
"github.com/gin-gonic/gin"
)
func AuthMiddleware(authService *services.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
if bypass, exists := c.Get("emergency_bypass"); exists {
if bypassActive, ok := bypass.(bool); ok && bypassActive {
c.Set("role", "admin")
c.Set("userID", uint(0))
c.Next()
return
}
}
if authService == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
return
}
tokenString, ok := extractAuthToken(c)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
return
}
user, _, err := authService.AuthenticateToken(tokenString)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
return
}
c.Set("userID", user.ID)
c.Set("role", string(user.Role))
c.Next()
}
}
func extractAuthToken(c *gin.Context) (string, bool) {
authHeader := c.GetHeader("Authorization")
// Fall back to cookie for browser flows (including WebSocket upgrades)
if authHeader == "" {
if cookieToken := extractAuthCookieToken(c); cookieToken != "" {
authHeader = "Bearer " + cookieToken
}
}
// DEPRECATED: Query parameter authentication for WebSocket connections
// This fallback exists only for backward compatibility and will be removed in a future version.
// Query parameters are logged in access logs and should not be used for sensitive data.
// Use HttpOnly cookies instead, which are automatically sent by browsers and not logged.
if authHeader == "" {
if token := c.Query("token"); token != "" {
authHeader = "Bearer " + token
}
}
if authHeader == "" {
return "", false
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == "" {
return "", false
}
return tokenString, true
}
func extractAuthCookieToken(c *gin.Context) string {
if c.Request == nil {
return ""
}
token := ""
for _, cookie := range c.Request.Cookies() {
if cookie.Name != "auth_token" {
continue
}
if cookie.Value == "" {
continue
}
token = cookie.Value
}
return token
}
func RequireRole(role models.UserRole) gin.HandlerFunc {
return func(c *gin.Context) {
userRole := c.GetString("role")
if userRole == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
if userRole != string(role) && userRole != string(models.RoleAdmin) {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Forbidden"})
return
}
c.Next()
}
}
func RequireManagementAccess() gin.HandlerFunc {
return func(c *gin.Context) {
role := c.GetString("role")
if role == string(models.RolePassthrough) {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Pass-through users cannot access management features"})
return
}
c.Next()
}
}

View File

@@ -0,0 +1,487 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wikid82/charon/backend/internal/config"
"github.com/Wikid82/charon/backend/internal/models"
"github.com/Wikid82/charon/backend/internal/services"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func setupAuthService(t *testing.T) *services.AuthService {
authService, _ := setupAuthServiceWithDB(t)
return authService
}
func setupAuthServiceWithDB(t *testing.T) (*services.AuthService, *gorm.DB) {
dbName := "file:" + t.Name() + "?mode=memory&cache=shared"
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{})
require.NoError(t, err)
_ = db.AutoMigrate(&models.User{})
cfg := config.Config{JWTSecret: "test-secret"}
return services.NewAuthService(db, cfg), db
}
func TestAuthMiddleware_MissingHeader(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
// We pass nil for authService because we expect it to fail before using it
r.Use(AuthMiddleware(nil))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "Authorization header required")
}
func TestAuthMiddleware_EmergencyBypass(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("emergency_bypass", true)
c.Next()
})
r.Use(AuthMiddleware(nil))
r.GET("/test", func(c *gin.Context) {
role, _ := c.Get("role")
userID, _ := c.Get("userID")
assert.Equal(t, "admin", role)
assert.Equal(t, uint(0), userID)
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRequireRole_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("role", "admin")
c.Next()
})
r.Use(RequireRole("admin"))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRequireRole_Forbidden(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("role", "user")
c.Next()
})
r.Use(RequireRole("admin"))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}
func TestAuthMiddleware_Cookie(t *testing.T) {
authService := setupAuthService(t)
user, err := authService.Register("test@example.com", "password", "Test User")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
userID, _ := c.Get("userID")
assert.Equal(t, user.ID, userID)
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: token})
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestAuthMiddleware_ValidToken(t *testing.T) {
authService := setupAuthService(t)
user, err := authService.Register("test@example.com", "password", "Test User")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
userID, _ := c.Get("userID")
assert.Equal(t, user.ID, userID)
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestAuthMiddleware_PrefersCookieOverAuthorizationHeader(t *testing.T) {
authService := setupAuthService(t)
cookieUser, _ := authService.Register("cookie-header@example.com", "password", "Cookie Header User")
cookieToken, _ := authService.GenerateToken(cookieUser)
headerUser, _ := authService.Register("header@example.com", "password", "Header User")
headerToken, _ := authService.GenerateToken(headerUser)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
userID, _ := c.Get("userID")
assert.Equal(t, headerUser.ID, userID)
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
req.Header.Set("Authorization", "Bearer "+headerToken)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: cookieToken})
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestAuthMiddleware_UsesCookieWhenAuthorizationHeaderIsInvalid(t *testing.T) {
authService := setupAuthService(t)
user, err := authService.Register("cookie-valid@example.com", "password", "Cookie Valid User")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
userID, _ := c.Get("userID")
assert.Equal(t, user.ID, userID)
c.Status(http.StatusOK)
})
req, err := http.NewRequest("GET", "/test", http.NoBody)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer invalid-token")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: token})
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestAuthMiddleware_UsesLastNonEmptyCookieWhenDuplicateCookiesExist(t *testing.T) {
authService := setupAuthService(t)
user, err := authService.Register("dupecookie@example.com", "password", "Dup Cookie User")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
userID, _ := c.Get("userID")
assert.Equal(t, user.ID, userID)
c.Status(http.StatusOK)
})
req, err := http.NewRequest("GET", "/test", http.NoBody)
require.NoError(t, err)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: ""})
req.AddCookie(&http.Cookie{Name: "auth_token", Value: token})
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestAuthMiddleware_InvalidToken(t *testing.T) {
authService := setupAuthService(t)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
req.Header.Set("Authorization", "Bearer invalid-token")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "Invalid token")
}
func TestRequireRole_MissingRoleInContext(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
// No role set in context
r.Use(RequireRole("admin"))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, _ := http.NewRequest("GET", "/test", http.NoBody)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestAuthMiddleware_QueryParamFallback(t *testing.T) {
authService := setupAuthService(t)
user, err := authService.Register("test@example.com", "password", "Test User")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
userID, _ := c.Get("userID")
assert.Equal(t, user.ID, userID)
c.Status(http.StatusOK)
})
// Test that query param auth still works (deprecated fallback)
req, err := http.NewRequest("GET", "/test?token="+token, http.NoBody)
require.NoError(t, err)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestAuthMiddleware_PrefersCookieOverQueryParam(t *testing.T) {
authService := setupAuthService(t)
// Create two different users
cookieUser, err := authService.Register("cookie@example.com", "password", "Cookie User")
require.NoError(t, err)
cookieToken, err := authService.GenerateToken(cookieUser)
require.NoError(t, err)
queryUser, err := authService.Register("query@example.com", "password", "Query User")
require.NoError(t, err)
queryToken, err := authService.GenerateToken(queryUser)
require.NoError(t, err)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
userID, _ := c.Get("userID")
// Should use the cookie user, not the query param user
assert.Equal(t, cookieUser.ID, userID)
c.Status(http.StatusOK)
})
// Both cookie and query param provided - cookie should win
req, err := http.NewRequest("GET", "/test?token="+queryToken, http.NoBody)
require.NoError(t, err)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: cookieToken})
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestAuthMiddleware_RejectsDisabledUserToken(t *testing.T) {
authService, db := setupAuthServiceWithDB(t)
user, err := authService.Register("disabled@example.com", "password", "Disabled User")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
require.NoError(t, db.Model(&models.User{}).Where("id = ?", user.ID).Update("enabled", false).Error)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, err := http.NewRequest("GET", "/test", http.NoBody)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestAuthMiddleware_RejectsDeletedUserToken(t *testing.T) {
authService, db := setupAuthServiceWithDB(t)
user, err := authService.Register("deleted@example.com", "password", "Deleted User")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
require.NoError(t, db.Delete(&models.User{}, user.ID).Error)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, err := http.NewRequest("GET", "/test", http.NoBody)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestAuthMiddleware_RejectsTokenAfterSessionInvalidation(t *testing.T) {
authService := setupAuthService(t)
user, err := authService.Register("session-invalidated@example.com", "password", "Session Invalidated")
require.NoError(t, err)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
require.NoError(t, authService.InvalidateSessions(user.ID))
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(AuthMiddleware(authService))
r.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req, err := http.NewRequest("GET", "/test", http.NoBody)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestExtractAuthCookieToken_ReturnsEmptyWhenRequestNil(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = nil
token := extractAuthCookieToken(ctx)
assert.Equal(t, "", token)
}
func TestExtractAuthCookieToken_IgnoresNonAuthCookies(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
req, err := http.NewRequest("GET", "/", http.NoBody)
require.NoError(t, err)
req.AddCookie(&http.Cookie{Name: "session", Value: "abc"})
ctx.Request = req
token := extractAuthCookieToken(ctx)
assert.Equal(t, "", token)
}
func TestRequireManagementAccess_PassthroughBlocked(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("role", string(models.RolePassthrough))
c.Next()
})
r.Use(RequireManagementAccess())
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.Contains(t, w.Body.String(), "Pass-through users cannot access management features")
}
func TestRequireManagementAccess_UserAllowed(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("role", string(models.RoleUser))
c.Next()
})
r.Use(RequireManagementAccess())
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRequireManagementAccess_AdminAllowed(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("role", string(models.RoleAdmin))
c.Next()
})
r.Use(RequireManagementAccess())
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}

View File

@@ -0,0 +1,5 @@
// Package middleware provides Gin middleware for the Charon backend API.
//
// It includes middleware for authentication, request logging, panic recovery,
// security headers, and request ID generation.
package middleware

View File

@@ -0,0 +1,129 @@
package middleware
import (
"crypto/subtle"
"net"
"os"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/Wikid82/charon/backend/internal/util"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
const (
// EmergencyTokenHeader is the HTTP header name for emergency token
EmergencyTokenHeader = "X-Emergency-Token"
// EmergencyTokenEnvVar is the environment variable name for emergency token
EmergencyTokenEnvVar = "CHARON_EMERGENCY_TOKEN"
// MinTokenLength is the minimum required length for emergency tokens
MinTokenLength = 32
)
// EmergencyBypass creates middleware that bypasses all security checks
// when a valid emergency token is present from an authorized source.
//
// Security conditions (ALL must be met):
// 1. Request from management CIDR (RFC1918 private networks by default)
// 2. X-Emergency-Token header matches configured token (timing-safe)
// 3. Token meets minimum length requirement (32+ chars)
//
// This middleware must be registered FIRST in the middleware chain.
func EmergencyBypass(managementCIDRs []string, db *gorm.DB) gin.HandlerFunc {
// Load emergency token from environment
emergencyToken := os.Getenv(EmergencyTokenEnvVar)
if emergencyToken == "" {
logger.Log().Warn("CHARON_EMERGENCY_TOKEN not set - emergency bypass disabled")
return func(c *gin.Context) { c.Next() } // noop
}
if len(emergencyToken) < MinTokenLength {
logger.Log().Warn("CHARON_EMERGENCY_TOKEN too short - emergency bypass disabled")
return func(c *gin.Context) { c.Next() } // noop
}
// Parse management CIDRs
var managementNets []*net.IPNet
for _, cidr := range managementCIDRs {
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
logger.Log().WithError(err).WithField("cidr", cidr).Warn("Invalid management CIDR")
continue
}
managementNets = append(managementNets, ipnet)
}
// Default to RFC1918 private networks if none specified
if len(managementNets) == 0 {
managementNets = []*net.IPNet{
mustParseCIDR("10.0.0.0/8"),
mustParseCIDR("172.16.0.0/12"),
mustParseCIDR("192.168.0.0/16"),
mustParseCIDR("127.0.0.0/8"), // localhost for local development
mustParseCIDR("::1/128"), // IPv6 localhost
}
}
return func(c *gin.Context) {
// Check if emergency token is present
providedToken := c.GetHeader(EmergencyTokenHeader)
if providedToken == "" {
c.Next() // No emergency token - proceed normally
return
}
// Validate source IP is from management network
clientIPStr := util.CanonicalizeIPForSecurity(c.ClientIP())
clientIP := net.ParseIP(clientIPStr)
if clientIP == nil {
logger.Log().WithField("ip", util.SanitizeForLog(clientIPStr)).Warn("Emergency bypass: invalid client IP")
c.Next()
return
}
inManagementNet := false
for _, ipnet := range managementNets {
if ipnet.Contains(clientIP) {
inManagementNet = true
break
}
}
if !inManagementNet {
logger.Log().WithField("ip", util.SanitizeForLog(clientIP.String())).Warn("Emergency bypass: IP not in management network")
c.Next()
return
}
// Timing-safe token comparison
if !constantTimeCompare(emergencyToken, providedToken) {
logger.Log().WithField("ip", util.SanitizeForLog(clientIP.String())).Warn("Emergency bypass: invalid token")
c.Next()
return
}
// Valid emergency token from authorized source
logger.Log().WithFields(map[string]interface{}{
"ip": util.SanitizeForLog(clientIP.String()),
"path": util.SanitizeForLog(c.Request.URL.Path),
}).Warn("EMERGENCY BYPASS ACTIVE: Request bypassing all security checks")
// Set flag for downstream handlers to know this is an emergency request
c.Set("emergency_bypass", true)
// Strip emergency token header to prevent it from reaching application
// This is critical for security - prevents token exposure in logs
c.Request.Header.Del(EmergencyTokenHeader)
c.Next()
}
}
func mustParseCIDR(cidr string) *net.IPNet {
_, ipnet, _ := net.ParseCIDR(cidr)
return ipnet
}
func constantTimeCompare(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}

View File

@@ -0,0 +1,277 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestEmergencyBypass_NoToken(t *testing.T) {
// Test that requests without emergency token proceed normally
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
_, exists := c.Get("emergency_bypass")
assert.False(t, exists, "Emergency bypass flag should not be set")
c.JSON(http.StatusOK, gin.H{"message": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestEmergencyBypass_InvalidClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
_, exists := c.Get("emergency_bypass")
assert.False(t, exists, "Emergency bypass flag should not be set for invalid client IP")
c.JSON(http.StatusOK, gin.H{"message": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
req.RemoteAddr = "invalid-remote-addr"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestEmergencyBypass_ValidToken(t *testing.T) {
// Test that valid token from allowed IP sets bypass flag
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
bypass, exists := c.Get("emergency_bypass")
assert.True(t, exists, "Emergency bypass flag should be set")
assert.True(t, bypass.(bool), "Emergency bypass flag should be true")
c.JSON(http.StatusOK, gin.H{"message": "bypass active"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
req.RemoteAddr = "127.0.0.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Verify token was stripped from request
assert.Empty(t, req.Header.Get(EmergencyTokenHeader), "Token should be stripped")
}
func TestEmergencyBypass_ValidToken_IPv6Localhost(t *testing.T) {
// Test that valid token from IPv6 localhost is treated as localhost
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
_ = router.SetTrustedProxies(nil)
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
bypass, exists := c.Get("emergency_bypass")
assert.True(t, exists, "Emergency bypass flag should be set")
assert.True(t, bypass.(bool), "Emergency bypass flag should be true")
c.JSON(http.StatusOK, gin.H{"message": "bypass active"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
req.RemoteAddr = "[::1]:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestEmergencyBypass_InvalidToken(t *testing.T) {
// Test that invalid token does not set bypass flag
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
_, exists := c.Get("emergency_bypass")
assert.False(t, exists, "Emergency bypass flag should not be set")
c.JSON(http.StatusOK, gin.H{"message": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "wrong-token")
req.RemoteAddr = "127.0.0.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestEmergencyBypass_UnauthorizedIP(t *testing.T) {
// Test that valid token from disallowed IP does not set bypass flag
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
_, exists := c.Get("emergency_bypass")
assert.False(t, exists, "Emergency bypass flag should not be set")
c.JSON(http.StatusOK, gin.H{"message": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
req.RemoteAddr = "203.0.113.1:12345" // Public IP (not in management network)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestEmergencyBypass_TokenStripped(t *testing.T) {
// Test that emergency token header is removed after validation
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
var tokenInHandler string
router.GET("/test", func(c *gin.Context) {
tokenInHandler = c.GetHeader(EmergencyTokenHeader)
c.JSON(http.StatusOK, gin.H{"message": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
req.RemoteAddr = "127.0.0.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Empty(t, tokenInHandler, "Token should not be visible in downstream handlers")
}
func TestEmergencyBypass_MinimumLength(t *testing.T) {
// Test that tokens < 32 chars are rejected
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "short-token")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
_, exists := c.Get("emergency_bypass")
assert.False(t, exists, "Emergency bypass flag should not be set with short token")
c.JSON(http.StatusOK, gin.H{"message": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "short-token")
req.RemoteAddr = "127.0.0.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestEmergencyBypass_NoTokenConfigured(t *testing.T) {
// Test that middleware is no-op when token not configured
gin.SetMode(gin.TestMode)
// Don't set CHARON_EMERGENCY_TOKEN
t.Setenv("CHARON_EMERGENCY_TOKEN", "")
router := gin.New()
managementCIDRs := []string{"127.0.0.0/8"}
router.Use(EmergencyBypass(managementCIDRs, nil))
router.GET("/test", func(c *gin.Context) {
_, exists := c.Get("emergency_bypass")
assert.False(t, exists, "Emergency bypass flag should not be set")
c.JSON(http.StatusOK, gin.H{"message": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "any-token")
req.RemoteAddr = "127.0.0.1:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestEmergencyBypass_DefaultCIDRs(t *testing.T) {
// Test that RFC1918 networks are used by default
gin.SetMode(gin.TestMode)
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
router := gin.New()
// Pass empty CIDR list to trigger default behavior
router.Use(EmergencyBypass([]string{}, nil))
router.GET("/test", func(c *gin.Context) {
bypass, exists := c.Get("emergency_bypass")
assert.True(t, exists, "Emergency bypass flag should be set")
assert.True(t, bypass.(bool), "Emergency bypass flag should be true")
c.JSON(http.StatusOK, gin.H{"message": "bypass active"})
})
// Test with various RFC1918 addresses
testIPs := []string{
"10.0.0.1:12345",
"172.16.0.1:12345",
"192.168.1.1:12345",
"127.0.0.1:12345",
}
for _, remoteAddr := range testIPs {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
req.RemoteAddr = remoteAddr
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code, "Should accept IP: %s", remoteAddr)
}
}

View File

@@ -0,0 +1,44 @@
package middleware
import (
"github.com/Wikid82/charon/backend/internal/services"
"github.com/gin-gonic/gin"
)
// OptionalAuth applies best-effort authentication for downstream middleware without blocking requests.
func OptionalAuth(authService *services.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
if authService == nil {
c.Next()
return
}
if bypass, exists := c.Get("emergency_bypass"); exists {
if bypassActive, ok := bypass.(bool); ok && bypassActive {
c.Next()
return
}
}
if _, exists := c.Get("role"); exists {
c.Next()
return
}
tokenString, ok := extractAuthToken(c)
if !ok {
c.Next()
return
}
user, _, err := authService.AuthenticateToken(tokenString)
if err != nil {
c.Next()
return
}
c.Set("userID", user.ID)
c.Set("role", string(user.Role))
c.Next()
}
}

View File

@@ -0,0 +1,167 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wikid82/charon/backend/internal/models"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOptionalAuth_NilServicePassThrough(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(OptionalAuth(nil))
r.GET("/", func(c *gin.Context) {
_, hasUserID := c.Get("userID")
_, hasRole := c.Get("role")
assert.False(t, hasUserID)
assert.False(t, hasRole)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
res := httptest.NewRecorder()
r.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
}
func TestOptionalAuth_EmergencyBypassPassThrough(t *testing.T) {
t.Parallel()
authService := setupAuthService(t)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("emergency_bypass", true)
c.Next()
})
r.Use(OptionalAuth(authService))
r.GET("/", func(c *gin.Context) {
_, hasUserID := c.Get("userID")
_, hasRole := c.Get("role")
assert.False(t, hasUserID)
assert.False(t, hasRole)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
res := httptest.NewRecorder()
r.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
}
func TestOptionalAuth_RoleAlreadyInContextSkipsAuth(t *testing.T) {
t.Parallel()
authService := setupAuthService(t)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set("role", "admin")
c.Set("userID", uint(42))
c.Next()
})
r.Use(OptionalAuth(authService))
r.GET("/", func(c *gin.Context) {
role, _ := c.Get("role")
userID, _ := c.Get("userID")
assert.Equal(t, "admin", role)
assert.Equal(t, uint(42), userID)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
res := httptest.NewRecorder()
r.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
}
func TestOptionalAuth_NoTokenPassThrough(t *testing.T) {
t.Parallel()
authService := setupAuthService(t)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(OptionalAuth(authService))
r.GET("/", func(c *gin.Context) {
_, hasUserID := c.Get("userID")
_, hasRole := c.Get("role")
assert.False(t, hasUserID)
assert.False(t, hasRole)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
res := httptest.NewRecorder()
r.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
}
func TestOptionalAuth_InvalidTokenPassThrough(t *testing.T) {
t.Parallel()
authService := setupAuthService(t)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(OptionalAuth(authService))
r.GET("/", func(c *gin.Context) {
_, hasUserID := c.Get("userID")
_, hasRole := c.Get("role")
assert.False(t, hasUserID)
assert.False(t, hasRole)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set("Authorization", "Bearer invalid-token")
res := httptest.NewRecorder()
r.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
}
func TestOptionalAuth_ValidTokenSetsContext(t *testing.T) {
t.Parallel()
authService, db := setupAuthServiceWithDB(t)
user := &models.User{Email: "optional-auth@example.com", Name: "Optional Auth", Role: models.RoleAdmin, Enabled: true}
require.NoError(t, user.SetPassword("password123"))
require.NoError(t, db.Create(user).Error)
token, err := authService.GenerateToken(user)
require.NoError(t, err)
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(OptionalAuth(authService))
r.GET("/", func(c *gin.Context) {
role, roleExists := c.Get("role")
userID, userExists := c.Get("userID")
require.True(t, roleExists)
require.True(t, userExists)
assert.Equal(t, "admin", role)
assert.Equal(t, user.ID, userID)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set("Authorization", "Bearer "+token)
res := httptest.NewRecorder()
r.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
}

View File

@@ -0,0 +1,47 @@
package middleware
import (
"fmt"
"net/http"
"runtime/debug"
"github.com/Wikid82/charon/backend/internal/util"
"github.com/gin-gonic/gin"
)
// Recovery logs panic information. When verbose is true it logs stacktraces
// and basic request metadata for debugging.
func Recovery(verbose bool) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if r := recover(); r != nil {
// Try to get a request-scoped logger; fall back to global logger
entry := GetRequestLogger(c)
// Sanitize panic message to prevent logging sensitive data
panicMsg := util.SanitizeForLog(fmt.Sprintf("%v", r))
if len(panicMsg) > 200 {
panicMsg = panicMsg[:200] + "..."
}
if verbose {
// Log only the sanitized panic message and safe metadata.
// Stack traces can contain sensitive data from the call context,
// so we only log them internally without exposing raw values.
entry.WithFields(map[string]any{
"method": c.Request.Method,
"path": SanitizePath(c.Request.URL.Path),
}).Errorf("PANIC: %s", panicMsg)
// Log stack trace separately at debug level for operators
// who have enabled verbose logging and understand the risks
entry.Debugf("Stack trace available for panic recovery (not logged for security)")
_ = debug.Stack() // Capture but don't log to avoid CWE-312
} else {
entry.Errorf("PANIC: %s", panicMsg)
}
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
}
}()
c.Next()
}
}

View File

@@ -0,0 +1,231 @@
package middleware
import (
"bytes"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/gin-gonic/gin"
)
func TestRecoveryLogsStacktraceVerbose(t *testing.T) {
old := log.Writer()
buf := &bytes.Buffer{}
log.SetOutput(buf)
defer log.SetOutput(old)
// Ensure structured logger writes to the same buffer and enable debug
logger.Init(true, buf)
router := gin.New()
router.Use(RequestID())
router.Use(Recovery(true))
router.GET("/panic", func(c *gin.Context) {
panic("test panic")
})
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected status 500, got %d", w.Code)
}
out := buf.String()
if !strings.Contains(out, "PANIC: test panic") {
t.Fatalf("log did not include panic message: %s", out)
}
// Stack traces are no longer logged to prevent CWE-312 (clear-text logging of sensitive data)
// We now log a debug message indicating stack trace is available but not logged
if !strings.Contains(out, "Stack trace available") {
t.Fatalf("verbose log did not include stack trace availability message: %s", out)
}
if !strings.Contains(out, "request_id") {
t.Fatalf("verbose log did not include request_id: %s", out)
}
}
func TestRecoveryLogsBriefWhenNotVerbose(t *testing.T) {
old := log.Writer()
buf := &bytes.Buffer{}
log.SetOutput(buf)
defer log.SetOutput(old)
// Ensure structured logger writes to the same buffer and keep debug off
logger.Init(false, buf)
router := gin.New()
router.Use(RequestID())
router.Use(Recovery(false))
router.GET("/panic", func(c *gin.Context) {
panic("brief panic")
})
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected status 500, got %d", w.Code)
}
out := buf.String()
if !strings.Contains(out, "PANIC: brief panic") {
t.Fatalf("log did not include panic message: %s", out)
}
// Stack traces should not appear in non-verbose mode
if strings.Contains(out, "Stacktrace:") {
t.Fatalf("non-verbose log unexpectedly included stacktrace: %s", out)
}
}
// TestRecoveryDoesNotLogSensitiveHeaders verifies that sensitive headers
// are no longer logged at all (not even redacted) to prevent CWE-312.
func TestRecoveryDoesNotLogSensitiveHeaders(t *testing.T) {
old := log.Writer()
buf := &bytes.Buffer{}
log.SetOutput(buf)
defer log.SetOutput(old)
// Ensure structured logger writes to the same buffer and enable debug
logger.Init(true, buf)
router := gin.New()
router.Use(RequestID())
router.Use(Recovery(true))
router.GET("/panic", func(c *gin.Context) {
panic("sensitive panic")
})
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
// Add sensitive header that should not appear in logs at all
req.Header.Set("Authorization", "Bearer secret-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected status 500, got %d", w.Code)
}
out := buf.String()
// Verify sensitive token is not logged
if strings.Contains(out, "secret-token") {
t.Fatalf("log contained sensitive token: %s", out)
}
// Headers are no longer logged at all to prevent potential information leakage
if strings.Contains(out, "headers") {
t.Fatalf("log should not include headers field: %s", out)
}
// Verify sanitized panic message is logged
if !strings.Contains(out, "PANIC: sensitive panic") {
t.Fatalf("log did not include sanitized panic message: %s", out)
}
}
// TestRecoveryTruncatesLongPanicMessage verifies that panic messages longer
// than 200 characters are truncated with "..." suffix.
func TestRecoveryTruncatesLongPanicMessage(t *testing.T) {
old := log.Writer()
buf := &bytes.Buffer{}
log.SetOutput(buf)
defer log.SetOutput(old)
logger.Init(false, buf)
router := gin.New()
router.Use(RequestID())
router.Use(Recovery(false))
// Create a panic message longer than 200 characters
longMessage := strings.Repeat("x", 250)
router.GET("/panic", func(c *gin.Context) {
panic(longMessage)
})
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected status 500, got %d", w.Code)
}
out := buf.String()
// Should contain truncated message (200 chars + "...")
expectedTruncated := strings.Repeat("x", 200) + "..."
if !strings.Contains(out, expectedTruncated) {
t.Fatalf("log should contain truncated panic message with '...': %s", out)
}
// Should NOT contain the full 250 char message
if strings.Contains(out, longMessage) {
t.Fatalf("log should not contain full long panic message: %s", out)
}
}
// TestRecoveryNoPanicNormalFlow verifies that middleware passes through
// normally when no panic occurs.
func TestRecoveryNoPanicNormalFlow(t *testing.T) {
old := log.Writer()
buf := &bytes.Buffer{}
log.SetOutput(buf)
defer log.SetOutput(old)
logger.Init(false, buf)
router := gin.New()
router.Use(RequestID())
router.Use(Recovery(true))
router.GET("/ok", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/ok", http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
out := buf.String()
// Should NOT contain PANIC in logs
if strings.Contains(out, "PANIC") {
t.Fatalf("log should not contain PANIC for normal flow: %s", out)
}
}
// TestRecoveryPanicWithNilValue tests recovery from panic with a nil-like value.
// Note: panic(nil) behavior changed in Go 1.21+ and triggers linter warnings,
// so we use an explicit error value instead.
func TestRecoveryPanicWithNilValue(t *testing.T) {
old := log.Writer()
buf := &bytes.Buffer{}
log.SetOutput(buf)
defer log.SetOutput(old)
logger.Init(false, buf)
router := gin.New()
router.Use(RequestID())
router.Use(Recovery(false))
router.GET("/panic-nil", func(c *gin.Context) {
panic("intentional test panic with nil-like value")
})
req := httptest.NewRequest(http.MethodGet, "/panic-nil", http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Verify the panic was recovered and returned 500
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status 500, got %d", w.Code)
}
out := buf.String()
if !strings.Contains(out, "PANIC") {
t.Error("expected PANIC in log output")
}
}

View File

@@ -0,0 +1,40 @@
package middleware
import (
"context"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/Wikid82/charon/backend/internal/trace"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
const RequestIDHeader = "X-Request-ID"
// RequestID generates a uuid per request and places it in context and header.
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
rid := uuid.New().String()
c.Set(string(trace.RequestIDKey), rid)
c.Writer.Header().Set(RequestIDHeader, rid)
// Add to logger fields for this request
entry := logger.WithFields(map[string]any{"request_id": rid})
c.Set("logger", entry)
// Propagate into the request context so it can be used by services
ctx := context.WithValue(c.Request.Context(), trace.RequestIDKey, rid)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
// GetRequestLogger retrieves the request-scoped logger from context or the global logger
func GetRequestLogger(c *gin.Context) *logrus.Entry {
if v, ok := c.Get("logger"); ok {
if entry, ok := v.(*logrus.Entry); ok {
return entry
}
}
// fallback
return logger.Log()
}

View File

@@ -0,0 +1,37 @@
package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/gin-gonic/gin"
)
func TestRequestIDAddsHeaderAndLogger(t *testing.T) {
buf := &bytes.Buffer{}
logger.Init(true, buf)
router := gin.New()
router.Use(RequestID())
router.GET("/test", func(c *gin.Context) {
// Ensure logger exists in context and header is present
if _, ok := c.Get("logger"); !ok {
t.Fatalf("expected request-scoped logger in context")
}
c.String(200, "ok")
})
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if w.Header().Get(RequestIDHeader) == "" {
t.Fatalf("expected response to include X-Request-ID header")
}
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"time"
"github.com/Wikid82/charon/backend/internal/util"
"github.com/gin-gonic/gin"
)
// RequestLogger logs basic request information along with the request_id.
func RequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
latency := time.Since(start)
entry := GetRequestLogger(c)
entry.WithFields(map[string]any{
"status": c.Writer.Status(),
"method": c.Request.Method,
"path": SanitizePath(c.Request.URL.Path),
"latency": latency.String(),
"client": util.SanitizeForLog(c.ClientIP()),
}).Info("handled request")
}
}

View File

@@ -0,0 +1,72 @@
package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/gin-gonic/gin"
)
func TestRequestLoggerSanitizesPath(t *testing.T) {
old := logger.Log()
buf := &bytes.Buffer{}
logger.Init(true, buf)
longPath := "/" + strings.Repeat("a", 300)
router := gin.New()
router.Use(RequestID())
router.Use(RequestLogger())
router.GET(longPath, func(c *gin.Context) { c.Status(http.StatusOK) })
req := httptest.NewRequest(http.MethodGet, longPath, http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
out := buf.String()
if strings.Contains(out, strings.Repeat("a", 300)) {
t.Fatalf("logged unsanitized long path")
}
i := strings.Index(out, "path=")
if i == -1 {
t.Fatalf("could not find path in logs: %s", out)
}
sub := out[i:]
j := strings.Index(sub, " request_id=")
if j == -1 {
t.Fatalf("could not isolate path field from logs: %s", out)
}
pathField := sub[len("path="):j]
if strings.Contains(pathField, "\n") || strings.Contains(pathField, "\r") {
t.Fatalf("path field contains control characters after sanitization: %s", pathField)
}
_ = old // silence unused var
}
func TestRequestLoggerIncludesRequestID(t *testing.T) {
buf := &bytes.Buffer{}
logger.Init(true, buf)
router := gin.New()
router.Use(RequestID())
router.Use(RequestLogger())
router.GET("/ok", func(c *gin.Context) { c.String(200, "ok") })
req := httptest.NewRequest(http.MethodGet, "/ok", http.NoBody)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("unexpected status code: %d", w.Code)
}
out := buf.String()
if !strings.Contains(out, "request_id") {
t.Fatalf("expected log output to include request_id: %s", out)
}
if !strings.Contains(out, "handled request") {
t.Fatalf("expected log output to indicate handled request: %s", out)
}
}

View File

@@ -0,0 +1,62 @@
package middleware
import (
"net/http"
"strings"
"github.com/Wikid82/charon/backend/internal/util"
)
// SanitizeHeaders returns a map of header keys to redacted/sanitized values
// for safe logging. Sensitive headers are redacted; other values are
// sanitized using util.SanitizeForLog and truncated.
func SanitizeHeaders(h http.Header) map[string][]string {
if h == nil {
return nil
}
sensitive := map[string]struct{}{
"authorization": {},
"cookie": {},
"set-cookie": {},
"proxy-authorization": {},
"x-api-key": {},
"x-api-token": {},
"x-access-token": {},
"x-auth-token": {},
"x-api-secret": {},
"x-forwarded-for": {},
}
out := make(map[string][]string, len(h))
for k, vals := range h {
keyLower := strings.ToLower(k)
if _, ok := sensitive[keyLower]; ok {
out[k] = []string{"<redacted>"}
continue
}
sanitizedVals := make([]string, 0, len(vals))
for _, v := range vals {
v2 := util.SanitizeForLog(v)
if len(v2) > 200 {
v2 = v2[:200]
}
sanitizedVals = append(sanitizedVals, v2)
}
out[k] = sanitizedVals
}
return out
}
// SanitizePath prepares a request path for safe logging by removing
// control characters and truncating long values. It does not include
// query parameters.
func SanitizePath(p string) string {
// remove query string
if i := strings.Index(p, "?"); i != -1 {
p = p[:i]
}
p = util.SanitizeForLog(p)
if len(p) > 200 {
p = p[:200]
}
return p
}

View File

@@ -0,0 +1,55 @@
package middleware
import (
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestSanitizeHeaders(t *testing.T) {
t.Run("nil headers", func(t *testing.T) {
require.Nil(t, SanitizeHeaders(nil))
})
t.Run("redacts sensitive headers", func(t *testing.T) {
headers := http.Header{}
headers.Set("Authorization", "secret")
headers.Set("X-Api-Key", "token")
headers.Set("Cookie", "sessionid=abc")
sanitized := SanitizeHeaders(headers)
require.Equal(t, []string{"<redacted>"}, sanitized["Authorization"])
require.Equal(t, []string{"<redacted>"}, sanitized["X-Api-Key"])
require.Equal(t, []string{"<redacted>"}, sanitized["Cookie"])
})
t.Run("sanitizes and truncates values", func(t *testing.T) {
headers := http.Header{}
headers.Add("X-Trace", "line1\nline2\r\t")
headers.Add("X-Custom", strings.Repeat("a", 210))
sanitized := SanitizeHeaders(headers)
traceValue := sanitized["X-Trace"][0]
require.NotContains(t, traceValue, "\n")
require.NotContains(t, traceValue, "\r")
require.NotContains(t, traceValue, "\t")
customValue := sanitized["X-Custom"][0]
require.Equal(t, 200, len(customValue))
require.True(t, strings.HasPrefix(customValue, strings.Repeat("a", 200)))
})
}
func TestSanitizePath(t *testing.T) {
paddedPath := "/api/v1/resource/" + strings.Repeat("x", 210) + "?token=secret"
sanitized := SanitizePath(paddedPath)
require.NotContains(t, sanitized, "?")
require.False(t, strings.ContainsAny(sanitized, "\n\r\t"))
require.Equal(t, 200, len(sanitized))
}

View File

@@ -0,0 +1,130 @@
package middleware
import (
"fmt"
"strings"
"github.com/gin-gonic/gin"
)
// SecurityHeadersConfig holds configuration for the security headers middleware.
type SecurityHeadersConfig struct {
// IsDevelopment enables less strict settings for local development
IsDevelopment bool
// CustomCSPDirectives allows adding extra CSP directives
CustomCSPDirectives map[string]string
}
// DefaultSecurityHeadersConfig returns a secure default configuration.
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
return SecurityHeadersConfig{
IsDevelopment: false,
CustomCSPDirectives: nil,
}
}
// SecurityHeaders returns middleware that sets security-related HTTP headers.
// This implements Phase 1 of the security hardening plan.
func SecurityHeaders(cfg SecurityHeadersConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// Build Content-Security-Policy
csp := buildCSP(cfg)
c.Header("Content-Security-Policy", csp)
// Strict-Transport-Security (HSTS)
// max-age=31536000 = 1 year
// includeSubDomains ensures all subdomains also use HTTPS
// preload allows browser preload lists (requires submission to hstspreload.org)
if !cfg.IsDevelopment {
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
}
// X-Frame-Options: Prevent clickjacking
// DENY prevents any framing; SAMEORIGIN would allow same-origin framing
c.Header("X-Frame-Options", "DENY")
// X-Content-Type-Options: Prevent MIME sniffing
c.Header("X-Content-Type-Options", "nosniff")
// X-XSS-Protection: Enable browser XSS filtering (legacy but still useful)
// mode=block tells browser to block the response if XSS is detected
c.Header("X-XSS-Protection", "1; mode=block")
// Referrer-Policy: Control referrer information sent with requests
// strict-origin-when-cross-origin sends full URL for same-origin, origin only for cross-origin
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
// Permissions-Policy: Restrict browser features
// Disable features that aren't needed for security
c.Header("Permissions-Policy", buildPermissionsPolicy())
// Cross-Origin-Opener-Policy: Isolate browsing context
// Skip in development mode to avoid browser warnings on HTTP
// In production, Caddy always uses HTTPS, so safe to set unconditionally
if !cfg.IsDevelopment {
c.Header("Cross-Origin-Opener-Policy", "same-origin")
}
// Cross-Origin-Resource-Policy: Prevent cross-origin reads
c.Header("Cross-Origin-Resource-Policy", "same-origin")
// Cross-Origin-Embedder-Policy: Require CORP for cross-origin resources
// Note: This can break some external resources, use with caution
// c.Header("Cross-Origin-Embedder-Policy", "require-corp")
c.Next()
}
}
// buildCSP constructs the Content-Security-Policy header value.
func buildCSP(cfg SecurityHeadersConfig) string {
// Base CSP directives for a secure single-page application
directives := map[string]string{
"default-src": "'self'",
"script-src": "'self'",
"style-src": "'self' 'unsafe-inline'", // unsafe-inline needed for many CSS-in-JS solutions
"img-src": "'self' data: https:", // Allow HTTPS images and data URIs
"font-src": "'self' data:", // Allow self-hosted fonts and data URIs
"connect-src": "'self'", // API connections
"frame-src": "'none'", // No iframes
"object-src": "'none'", // No plugins (Flash, etc.)
"base-uri": "'self'", // Restrict base tag
"form-action": "'self'", // Restrict form submissions
}
// In development, allow more sources for hot reloading, etc.
if cfg.IsDevelopment {
directives["script-src"] = "'self' 'unsafe-inline' 'unsafe-eval'"
directives["connect-src"] = "'self' ws: wss:" // WebSocket for HMR
}
// Apply custom directives
for key, value := range cfg.CustomCSPDirectives {
directives[key] = value
}
// Build the CSP string
var parts []string
for directive, value := range directives {
parts = append(parts, fmt.Sprintf("%s %s", directive, value))
}
return strings.Join(parts, "; ")
}
// buildPermissionsPolicy constructs the Permissions-Policy header value.
func buildPermissionsPolicy() string {
// Disable features we don't need
policies := []string{
"accelerometer=()",
"camera=()",
"geolocation=()",
"gyroscope=()",
"magnetometer=()",
"microphone=()",
"payment=()",
"usb=()",
}
return strings.Join(policies, ", ")
}

View File

@@ -0,0 +1,223 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestSecurityHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
isDevelopment bool
checkHeaders func(t *testing.T, resp *httptest.ResponseRecorder)
}{
{
name: "production mode sets HSTS",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
hsts := resp.Header().Get("Strict-Transport-Security")
assert.Contains(t, hsts, "max-age=31536000")
assert.Contains(t, hsts, "includeSubDomains")
assert.Contains(t, hsts, "preload")
},
},
{
name: "development mode skips HSTS",
isDevelopment: true,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
hsts := resp.Header().Get("Strict-Transport-Security")
assert.Empty(t, hsts)
},
},
{
name: "sets X-Frame-Options",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, "DENY", resp.Header().Get("X-Frame-Options"))
},
},
{
name: "sets X-Content-Type-Options",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, "nosniff", resp.Header().Get("X-Content-Type-Options"))
},
},
{
name: "sets X-XSS-Protection",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, "1; mode=block", resp.Header().Get("X-XSS-Protection"))
},
},
{
name: "sets Referrer-Policy",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, "strict-origin-when-cross-origin", resp.Header().Get("Referrer-Policy"))
},
},
{
name: "sets Content-Security-Policy",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
csp := resp.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
assert.Contains(t, csp, "default-src")
},
},
{
name: "development mode CSP allows unsafe-eval",
isDevelopment: true,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
csp := resp.Header().Get("Content-Security-Policy")
assert.Contains(t, csp, "unsafe-eval")
},
},
{
name: "sets Permissions-Policy",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
pp := resp.Header().Get("Permissions-Policy")
assert.NotEmpty(t, pp)
assert.Contains(t, pp, "camera=()")
assert.Contains(t, pp, "microphone=()")
},
},
{
name: "sets Cross-Origin-Opener-Policy in production",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, "same-origin", resp.Header().Get("Cross-Origin-Opener-Policy"))
},
},
{
name: "skips Cross-Origin-Opener-Policy in development",
isDevelopment: true,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Empty(t, resp.Header().Get("Cross-Origin-Opener-Policy"))
},
},
{
name: "sets Cross-Origin-Resource-Policy",
isDevelopment: false,
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, "same-origin", resp.Header().Get("Cross-Origin-Resource-Policy"))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(SecurityHeaders(SecurityHeadersConfig{
IsDevelopment: tt.isDevelopment,
}))
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
tt.checkHeaders(t, resp)
})
}
}
func TestSecurityHeadersCustomCSP(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(SecurityHeaders(SecurityHeadersConfig{
IsDevelopment: false,
CustomCSPDirectives: map[string]string{
"frame-src": "'self' https://trusted.com",
},
}))
router.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
csp := resp.Header().Get("Content-Security-Policy")
assert.Contains(t, csp, "frame-src 'self' https://trusted.com")
}
func TestDefaultSecurityHeadersConfig(t *testing.T) {
cfg := DefaultSecurityHeadersConfig()
assert.False(t, cfg.IsDevelopment)
assert.Nil(t, cfg.CustomCSPDirectives)
}
func TestSecurityHeaders_COOP_DevelopmentMode(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
cfg := SecurityHeadersConfig{IsDevelopment: true}
router.Use(SecurityHeaders(cfg))
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest("GET", "/test", http.NoBody)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
assert.Empty(t, resp.Header().Get("Cross-Origin-Opener-Policy"),
"COOP header should not be set in development mode")
}
func TestSecurityHeaders_COOP_ProductionMode(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
cfg := SecurityHeadersConfig{IsDevelopment: false}
router.Use(SecurityHeaders(cfg))
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest("GET", "/test", http.NoBody)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
assert.Equal(t, "same-origin", resp.Header().Get("Cross-Origin-Opener-Policy"),
"COOP header must be set in production mode")
}
func TestBuildCSP(t *testing.T) {
t.Run("production CSP", func(t *testing.T) {
csp := buildCSP(SecurityHeadersConfig{IsDevelopment: false})
assert.Contains(t, csp, "default-src 'self'")
assert.Contains(t, csp, "script-src 'self'")
assert.NotContains(t, csp, "unsafe-eval")
})
t.Run("development CSP", func(t *testing.T) {
csp := buildCSP(SecurityHeadersConfig{IsDevelopment: true})
assert.Contains(t, csp, "unsafe-eval")
assert.Contains(t, csp, "ws:")
})
}
func TestBuildPermissionsPolicy(t *testing.T) {
pp := buildPermissionsPolicy()
// Check that dangerous features are disabled
disabledFeatures := []string{"camera", "microphone", "geolocation", "payment"}
for _, feature := range disabledFeatures {
assert.True(t, strings.Contains(pp, feature+"=()"),
"Expected %s to be disabled in permissions policy", feature)
}
}