chore: git cache cleanup
This commit is contained in:
125
backend/internal/api/middleware/auth.go
Normal file
125
backend/internal/api/middleware/auth.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
"github.com/Wikid82/charon/backend/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AuthMiddleware(authService *services.AuthService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if bypass, exists := c.Get("emergency_bypass"); exists {
|
||||
if bypassActive, ok := bypass.(bool); ok && bypassActive {
|
||||
c.Set("role", "admin")
|
||||
c.Set("userID", uint(0))
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if authService == nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString, ok := extractAuthToken(c)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, _, err := authService.AuthenticateToken(tokenString)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("userID", user.ID)
|
||||
c.Set("role", string(user.Role))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractAuthToken(c *gin.Context) (string, bool) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
|
||||
// Fall back to cookie for browser flows (including WebSocket upgrades)
|
||||
if authHeader == "" {
|
||||
if cookieToken := extractAuthCookieToken(c); cookieToken != "" {
|
||||
authHeader = "Bearer " + cookieToken
|
||||
}
|
||||
}
|
||||
|
||||
// DEPRECATED: Query parameter authentication for WebSocket connections
|
||||
// This fallback exists only for backward compatibility and will be removed in a future version.
|
||||
// Query parameters are logged in access logs and should not be used for sensitive data.
|
||||
// Use HttpOnly cookies instead, which are automatically sent by browsers and not logged.
|
||||
if authHeader == "" {
|
||||
if token := c.Query("token"); token != "" {
|
||||
authHeader = "Bearer " + token
|
||||
}
|
||||
}
|
||||
|
||||
if authHeader == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return tokenString, true
|
||||
}
|
||||
|
||||
func extractAuthCookieToken(c *gin.Context) string {
|
||||
if c.Request == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
token := ""
|
||||
for _, cookie := range c.Request.Cookies() {
|
||||
if cookie.Name != "auth_token" {
|
||||
continue
|
||||
}
|
||||
|
||||
if cookie.Value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
token = cookie.Value
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func RequireRole(role models.UserRole) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userRole := c.GetString("role")
|
||||
if userRole == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
if userRole != string(role) && userRole != string(models.RoleAdmin) {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Forbidden"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RequireManagementAccess() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role := c.GetString("role")
|
||||
if role == string(models.RolePassthrough) {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Pass-through users cannot access management features"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
487
backend/internal/api/middleware/auth_test.go
Normal file
487
backend/internal/api/middleware/auth_test.go
Normal file
@@ -0,0 +1,487 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/config"
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
"github.com/Wikid82/charon/backend/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupAuthService(t *testing.T) *services.AuthService {
|
||||
authService, _ := setupAuthServiceWithDB(t)
|
||||
return authService
|
||||
}
|
||||
|
||||
func setupAuthServiceWithDB(t *testing.T) (*services.AuthService, *gorm.DB) {
|
||||
dbName := "file:" + t.Name() + "?mode=memory&cache=shared"
|
||||
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
_ = db.AutoMigrate(&models.User{})
|
||||
cfg := config.Config{JWTSecret: "test-secret"}
|
||||
return services.NewAuthService(db, cfg), db
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_MissingHeader(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
// We pass nil for authService because we expect it to fail before using it
|
||||
r.Use(AuthMiddleware(nil))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Authorization header required")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_EmergencyBypass(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("emergency_bypass", true)
|
||||
c.Next()
|
||||
})
|
||||
r.Use(AuthMiddleware(nil))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
role, _ := c.Get("role")
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, "admin", role)
|
||||
assert.Equal(t, uint(0), userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRequireRole_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("role", "admin")
|
||||
c.Next()
|
||||
})
|
||||
r.Use(RequireRole("admin"))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRequireRole_Forbidden(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("role", "user")
|
||||
c.Next()
|
||||
})
|
||||
r.Use(RequireRole("admin"))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Cookie(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
user, err := authService.Register("test@example.com", "password", "Test User")
|
||||
require.NoError(t, err)
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, user.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
req.AddCookie(&http.Cookie{Name: "auth_token", Value: token})
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidToken(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
user, err := authService.Register("test@example.com", "password", "Test User")
|
||||
require.NoError(t, err)
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, user.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_PrefersCookieOverAuthorizationHeader(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
cookieUser, _ := authService.Register("cookie-header@example.com", "password", "Cookie Header User")
|
||||
cookieToken, _ := authService.GenerateToken(cookieUser)
|
||||
headerUser, _ := authService.Register("header@example.com", "password", "Header User")
|
||||
headerToken, _ := authService.GenerateToken(headerUser)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, headerUser.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
req.Header.Set("Authorization", "Bearer "+headerToken)
|
||||
req.AddCookie(&http.Cookie{Name: "auth_token", Value: cookieToken})
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_UsesCookieWhenAuthorizationHeaderIsInvalid(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
user, err := authService.Register("cookie-valid@example.com", "password", "Cookie Valid User")
|
||||
require.NoError(t, err)
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, user.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "/test", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
req.AddCookie(&http.Cookie{Name: "auth_token", Value: token})
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_UsesLastNonEmptyCookieWhenDuplicateCookiesExist(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
user, err := authService.Register("dupecookie@example.com", "password", "Dup Cookie User")
|
||||
require.NoError(t, err)
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, user.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "/test", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{Name: "auth_token", Value: ""})
|
||||
req.AddCookie(&http.Cookie{Name: "auth_token", Value: token})
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_InvalidToken(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Invalid token")
|
||||
}
|
||||
|
||||
func TestRequireRole_MissingRoleInContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
// No role set in context
|
||||
r.Use(RequireRole("admin"))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_QueryParamFallback(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
user, err := authService.Register("test@example.com", "password", "Test User")
|
||||
require.NoError(t, err)
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, user.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Test that query param auth still works (deprecated fallback)
|
||||
req, err := http.NewRequest("GET", "/test?token="+token, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_PrefersCookieOverQueryParam(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
|
||||
// Create two different users
|
||||
cookieUser, err := authService.Register("cookie@example.com", "password", "Cookie User")
|
||||
require.NoError(t, err)
|
||||
cookieToken, err := authService.GenerateToken(cookieUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
queryUser, err := authService.Register("query@example.com", "password", "Query User")
|
||||
require.NoError(t, err)
|
||||
queryToken, err := authService.GenerateToken(queryUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("userID")
|
||||
// Should use the cookie user, not the query param user
|
||||
assert.Equal(t, cookieUser.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Both cookie and query param provided - cookie should win
|
||||
req, err := http.NewRequest("GET", "/test?token="+queryToken, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{Name: "auth_token", Value: cookieToken})
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RejectsDisabledUserToken(t *testing.T) {
|
||||
authService, db := setupAuthServiceWithDB(t)
|
||||
user, err := authService.Register("disabled@example.com", "password", "Disabled User")
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.Model(&models.User{}).Where("id = ?", user.ID).Update("enabled", false).Error)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "/test", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RejectsDeletedUserToken(t *testing.T) {
|
||||
authService, db := setupAuthServiceWithDB(t)
|
||||
user, err := authService.Register("deleted@example.com", "password", "Deleted User")
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.Delete(&models.User{}, user.ID).Error)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "/test", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RejectsTokenAfterSessionInvalidation(t *testing.T) {
|
||||
authService := setupAuthService(t)
|
||||
user, err := authService.Register("session-invalidated@example.com", "password", "Session Invalidated")
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, authService.InvalidateSessions(user.ID))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(AuthMiddleware(authService))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "/test", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
|
||||
func TestExtractAuthCookieToken_ReturnsEmptyWhenRequestNil(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = nil
|
||||
|
||||
token := extractAuthCookieToken(ctx)
|
||||
assert.Equal(t, "", token)
|
||||
}
|
||||
|
||||
func TestExtractAuthCookieToken_IgnoresNonAuthCookies(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
req, err := http.NewRequest("GET", "/", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{Name: "session", Value: "abc"})
|
||||
ctx.Request = req
|
||||
|
||||
token := extractAuthCookieToken(ctx)
|
||||
assert.Equal(t, "", token)
|
||||
}
|
||||
|
||||
func TestRequireManagementAccess_PassthroughBlocked(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("role", string(models.RolePassthrough))
|
||||
c.Next()
|
||||
})
|
||||
r.Use(RequireManagementAccess())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Pass-through users cannot access management features")
|
||||
}
|
||||
|
||||
func TestRequireManagementAccess_UserAllowed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("role", string(models.RoleUser))
|
||||
c.Next()
|
||||
})
|
||||
r.Use(RequireManagementAccess())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRequireManagementAccess_AdminAllowed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("role", string(models.RoleAdmin))
|
||||
c.Next()
|
||||
})
|
||||
r.Use(RequireManagementAccess())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
5
backend/internal/api/middleware/doc.go
Normal file
5
backend/internal/api/middleware/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package middleware provides Gin middleware for the Charon backend API.
|
||||
//
|
||||
// It includes middleware for authentication, request logging, panic recovery,
|
||||
// security headers, and request ID generation.
|
||||
package middleware
|
||||
129
backend/internal/api/middleware/emergency.go
Normal file
129
backend/internal/api/middleware/emergency.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
// EmergencyTokenHeader is the HTTP header name for emergency token
|
||||
EmergencyTokenHeader = "X-Emergency-Token"
|
||||
// EmergencyTokenEnvVar is the environment variable name for emergency token
|
||||
EmergencyTokenEnvVar = "CHARON_EMERGENCY_TOKEN"
|
||||
// MinTokenLength is the minimum required length for emergency tokens
|
||||
MinTokenLength = 32
|
||||
)
|
||||
|
||||
// EmergencyBypass creates middleware that bypasses all security checks
|
||||
// when a valid emergency token is present from an authorized source.
|
||||
//
|
||||
// Security conditions (ALL must be met):
|
||||
// 1. Request from management CIDR (RFC1918 private networks by default)
|
||||
// 2. X-Emergency-Token header matches configured token (timing-safe)
|
||||
// 3. Token meets minimum length requirement (32+ chars)
|
||||
//
|
||||
// This middleware must be registered FIRST in the middleware chain.
|
||||
func EmergencyBypass(managementCIDRs []string, db *gorm.DB) gin.HandlerFunc {
|
||||
// Load emergency token from environment
|
||||
emergencyToken := os.Getenv(EmergencyTokenEnvVar)
|
||||
if emergencyToken == "" {
|
||||
logger.Log().Warn("CHARON_EMERGENCY_TOKEN not set - emergency bypass disabled")
|
||||
return func(c *gin.Context) { c.Next() } // noop
|
||||
}
|
||||
|
||||
if len(emergencyToken) < MinTokenLength {
|
||||
logger.Log().Warn("CHARON_EMERGENCY_TOKEN too short - emergency bypass disabled")
|
||||
return func(c *gin.Context) { c.Next() } // noop
|
||||
}
|
||||
|
||||
// Parse management CIDRs
|
||||
var managementNets []*net.IPNet
|
||||
for _, cidr := range managementCIDRs {
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
logger.Log().WithError(err).WithField("cidr", cidr).Warn("Invalid management CIDR")
|
||||
continue
|
||||
}
|
||||
managementNets = append(managementNets, ipnet)
|
||||
}
|
||||
|
||||
// Default to RFC1918 private networks if none specified
|
||||
if len(managementNets) == 0 {
|
||||
managementNets = []*net.IPNet{
|
||||
mustParseCIDR("10.0.0.0/8"),
|
||||
mustParseCIDR("172.16.0.0/12"),
|
||||
mustParseCIDR("192.168.0.0/16"),
|
||||
mustParseCIDR("127.0.0.0/8"), // localhost for local development
|
||||
mustParseCIDR("::1/128"), // IPv6 localhost
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Check if emergency token is present
|
||||
providedToken := c.GetHeader(EmergencyTokenHeader)
|
||||
if providedToken == "" {
|
||||
c.Next() // No emergency token - proceed normally
|
||||
return
|
||||
}
|
||||
|
||||
// Validate source IP is from management network
|
||||
clientIPStr := util.CanonicalizeIPForSecurity(c.ClientIP())
|
||||
clientIP := net.ParseIP(clientIPStr)
|
||||
if clientIP == nil {
|
||||
logger.Log().WithField("ip", util.SanitizeForLog(clientIPStr)).Warn("Emergency bypass: invalid client IP")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
inManagementNet := false
|
||||
for _, ipnet := range managementNets {
|
||||
if ipnet.Contains(clientIP) {
|
||||
inManagementNet = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !inManagementNet {
|
||||
logger.Log().WithField("ip", util.SanitizeForLog(clientIP.String())).Warn("Emergency bypass: IP not in management network")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Timing-safe token comparison
|
||||
if !constantTimeCompare(emergencyToken, providedToken) {
|
||||
logger.Log().WithField("ip", util.SanitizeForLog(clientIP.String())).Warn("Emergency bypass: invalid token")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Valid emergency token from authorized source
|
||||
logger.Log().WithFields(map[string]interface{}{
|
||||
"ip": util.SanitizeForLog(clientIP.String()),
|
||||
"path": util.SanitizeForLog(c.Request.URL.Path),
|
||||
}).Warn("EMERGENCY BYPASS ACTIVE: Request bypassing all security checks")
|
||||
|
||||
// Set flag for downstream handlers to know this is an emergency request
|
||||
c.Set("emergency_bypass", true)
|
||||
|
||||
// Strip emergency token header to prevent it from reaching application
|
||||
// This is critical for security - prevents token exposure in logs
|
||||
c.Request.Header.Del(EmergencyTokenHeader)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseCIDR(cidr string) *net.IPNet {
|
||||
_, ipnet, _ := net.ParseCIDR(cidr)
|
||||
return ipnet
|
||||
}
|
||||
|
||||
func constantTimeCompare(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
||||
277
backend/internal/api/middleware/emergency_test.go
Normal file
277
backend/internal/api/middleware/emergency_test.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEmergencyBypass_NoToken(t *testing.T) {
|
||||
// Test that requests without emergency token proceed normally
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
_, exists := c.Get("emergency_bypass")
|
||||
assert.False(t, exists, "Emergency bypass flag should not be set")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_InvalidClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
_, exists := c.Get("emergency_bypass")
|
||||
assert.False(t, exists, "Emergency bypass flag should not be set for invalid client IP")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
req.RemoteAddr = "invalid-remote-addr"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_ValidToken(t *testing.T) {
|
||||
// Test that valid token from allowed IP sets bypass flag
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
bypass, exists := c.Get("emergency_bypass")
|
||||
assert.True(t, exists, "Emergency bypass flag should be set")
|
||||
assert.True(t, bypass.(bool), "Emergency bypass flag should be true")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "bypass active"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Verify token was stripped from request
|
||||
assert.Empty(t, req.Header.Get(EmergencyTokenHeader), "Token should be stripped")
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_ValidToken_IPv6Localhost(t *testing.T) {
|
||||
// Test that valid token from IPv6 localhost is treated as localhost
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
_ = router.SetTrustedProxies(nil)
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
bypass, exists := c.Get("emergency_bypass")
|
||||
assert.True(t, exists, "Emergency bypass flag should be set")
|
||||
assert.True(t, bypass.(bool), "Emergency bypass flag should be true")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "bypass active"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
req.RemoteAddr = "[::1]:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_InvalidToken(t *testing.T) {
|
||||
// Test that invalid token does not set bypass flag
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
_, exists := c.Get("emergency_bypass")
|
||||
assert.False(t, exists, "Emergency bypass flag should not be set")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "wrong-token")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_UnauthorizedIP(t *testing.T) {
|
||||
// Test that valid token from disallowed IP does not set bypass flag
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
_, exists := c.Get("emergency_bypass")
|
||||
assert.False(t, exists, "Emergency bypass flag should not be set")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
req.RemoteAddr = "203.0.113.1:12345" // Public IP (not in management network)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_TokenStripped(t *testing.T) {
|
||||
// Test that emergency token header is removed after validation
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
var tokenInHandler string
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
tokenInHandler = c.GetHeader(EmergencyTokenHeader)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Empty(t, tokenInHandler, "Token should not be visible in downstream handlers")
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_MinimumLength(t *testing.T) {
|
||||
// Test that tokens < 32 chars are rejected
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "short-token")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
_, exists := c.Get("emergency_bypass")
|
||||
assert.False(t, exists, "Emergency bypass flag should not be set with short token")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "short-token")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_NoTokenConfigured(t *testing.T) {
|
||||
// Test that middleware is no-op when token not configured
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Don't set CHARON_EMERGENCY_TOKEN
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "")
|
||||
|
||||
router := gin.New()
|
||||
managementCIDRs := []string{"127.0.0.0/8"}
|
||||
router.Use(EmergencyBypass(managementCIDRs, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
_, exists := c.Get("emergency_bypass")
|
||||
assert.False(t, exists, "Emergency bypass flag should not be set")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "any-token")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestEmergencyBypass_DefaultCIDRs(t *testing.T) {
|
||||
// Test that RFC1918 networks are used by default
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("CHARON_EMERGENCY_TOKEN", "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
|
||||
router := gin.New()
|
||||
// Pass empty CIDR list to trigger default behavior
|
||||
router.Use(EmergencyBypass([]string{}, nil))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
bypass, exists := c.Get("emergency_bypass")
|
||||
assert.True(t, exists, "Emergency bypass flag should be set")
|
||||
assert.True(t, bypass.(bool), "Emergency bypass flag should be true")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "bypass active"})
|
||||
})
|
||||
|
||||
// Test with various RFC1918 addresses
|
||||
testIPs := []string{
|
||||
"10.0.0.1:12345",
|
||||
"172.16.0.1:12345",
|
||||
"192.168.1.1:12345",
|
||||
"127.0.0.1:12345",
|
||||
}
|
||||
|
||||
for _, remoteAddr := range testIPs {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(EmergencyTokenHeader, "test-token-that-meets-minimum-length-requirement-32-chars")
|
||||
req.RemoteAddr = remoteAddr
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Should accept IP: %s", remoteAddr)
|
||||
}
|
||||
}
|
||||
44
backend/internal/api/middleware/optional_auth.go
Normal file
44
backend/internal/api/middleware/optional_auth.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/Wikid82/charon/backend/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OptionalAuth applies best-effort authentication for downstream middleware without blocking requests.
|
||||
func OptionalAuth(authService *services.AuthService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if authService == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if bypass, exists := c.Get("emergency_bypass"); exists {
|
||||
if bypassActive, ok := bypass.(bool); ok && bypassActive {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := c.Get("role"); exists {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString, ok := extractAuthToken(c)
|
||||
if !ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
user, _, err := authService.AuthenticateToken(tokenString)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("userID", user.ID)
|
||||
c.Set("role", string(user.Role))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
167
backend/internal/api/middleware/optional_auth_test.go
Normal file
167
backend/internal/api/middleware/optional_auth_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOptionalAuth_NilServicePassThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(OptionalAuth(nil))
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
_, hasUserID := c.Get("userID")
|
||||
_, hasRole := c.Get("role")
|
||||
assert.False(t, hasUserID)
|
||||
assert.False(t, hasRole)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
res := httptest.NewRecorder()
|
||||
r.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
}
|
||||
|
||||
func TestOptionalAuth_EmergencyBypassPassThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := setupAuthService(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("emergency_bypass", true)
|
||||
c.Next()
|
||||
})
|
||||
r.Use(OptionalAuth(authService))
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
_, hasUserID := c.Get("userID")
|
||||
_, hasRole := c.Get("role")
|
||||
assert.False(t, hasUserID)
|
||||
assert.False(t, hasRole)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
res := httptest.NewRecorder()
|
||||
r.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
}
|
||||
|
||||
func TestOptionalAuth_RoleAlreadyInContextSkipsAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := setupAuthService(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("role", "admin")
|
||||
c.Set("userID", uint(42))
|
||||
c.Next()
|
||||
})
|
||||
r.Use(OptionalAuth(authService))
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
role, _ := c.Get("role")
|
||||
userID, _ := c.Get("userID")
|
||||
assert.Equal(t, "admin", role)
|
||||
assert.Equal(t, uint(42), userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
res := httptest.NewRecorder()
|
||||
r.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
}
|
||||
|
||||
func TestOptionalAuth_NoTokenPassThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := setupAuthService(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(OptionalAuth(authService))
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
_, hasUserID := c.Get("userID")
|
||||
_, hasRole := c.Get("role")
|
||||
assert.False(t, hasUserID)
|
||||
assert.False(t, hasRole)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
res := httptest.NewRecorder()
|
||||
r.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
}
|
||||
|
||||
func TestOptionalAuth_InvalidTokenPassThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := setupAuthService(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(OptionalAuth(authService))
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
_, hasUserID := c.Get("userID")
|
||||
_, hasRole := c.Get("role")
|
||||
assert.False(t, hasUserID)
|
||||
assert.False(t, hasRole)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
res := httptest.NewRecorder()
|
||||
r.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
}
|
||||
|
||||
func TestOptionalAuth_ValidTokenSetsContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService, db := setupAuthServiceWithDB(t)
|
||||
user := &models.User{Email: "optional-auth@example.com", Name: "Optional Auth", Role: models.RoleAdmin, Enabled: true}
|
||||
require.NoError(t, user.SetPassword("password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
|
||||
token, err := authService.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(OptionalAuth(authService))
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
role, roleExists := c.Get("role")
|
||||
userID, userExists := c.Get("userID")
|
||||
require.True(t, roleExists)
|
||||
require.True(t, userExists)
|
||||
assert.Equal(t, "admin", role)
|
||||
assert.Equal(t, user.ID, userID)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
res := httptest.NewRecorder()
|
||||
r.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
}
|
||||
47
backend/internal/api/middleware/recovery.go
Normal file
47
backend/internal/api/middleware/recovery.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery logs panic information. When verbose is true it logs stacktraces
|
||||
// and basic request metadata for debugging.
|
||||
func Recovery(verbose bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Try to get a request-scoped logger; fall back to global logger
|
||||
entry := GetRequestLogger(c)
|
||||
|
||||
// Sanitize panic message to prevent logging sensitive data
|
||||
panicMsg := util.SanitizeForLog(fmt.Sprintf("%v", r))
|
||||
if len(panicMsg) > 200 {
|
||||
panicMsg = panicMsg[:200] + "..."
|
||||
}
|
||||
|
||||
if verbose {
|
||||
// Log only the sanitized panic message and safe metadata.
|
||||
// Stack traces can contain sensitive data from the call context,
|
||||
// so we only log them internally without exposing raw values.
|
||||
entry.WithFields(map[string]any{
|
||||
"method": c.Request.Method,
|
||||
"path": SanitizePath(c.Request.URL.Path),
|
||||
}).Errorf("PANIC: %s", panicMsg)
|
||||
// Log stack trace separately at debug level for operators
|
||||
// who have enabled verbose logging and understand the risks
|
||||
entry.Debugf("Stack trace available for panic recovery (not logged for security)")
|
||||
_ = debug.Stack() // Capture but don't log to avoid CWE-312
|
||||
} else {
|
||||
entry.Errorf("PANIC: %s", panicMsg)
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
231
backend/internal/api/middleware/recovery_test.go
Normal file
231
backend/internal/api/middleware/recovery_test.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestRecoveryLogsStacktraceVerbose(t *testing.T) {
|
||||
old := log.Writer()
|
||||
buf := &bytes.Buffer{}
|
||||
log.SetOutput(buf)
|
||||
defer log.SetOutput(old)
|
||||
// Ensure structured logger writes to the same buffer and enable debug
|
||||
logger.Init(true, buf)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(Recovery(true))
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "PANIC: test panic") {
|
||||
t.Fatalf("log did not include panic message: %s", out)
|
||||
}
|
||||
// Stack traces are no longer logged to prevent CWE-312 (clear-text logging of sensitive data)
|
||||
// We now log a debug message indicating stack trace is available but not logged
|
||||
if !strings.Contains(out, "Stack trace available") {
|
||||
t.Fatalf("verbose log did not include stack trace availability message: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "request_id") {
|
||||
t.Fatalf("verbose log did not include request_id: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryLogsBriefWhenNotVerbose(t *testing.T) {
|
||||
old := log.Writer()
|
||||
buf := &bytes.Buffer{}
|
||||
log.SetOutput(buf)
|
||||
defer log.SetOutput(old)
|
||||
|
||||
// Ensure structured logger writes to the same buffer and keep debug off
|
||||
logger.Init(false, buf)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(Recovery(false))
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("brief panic")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "PANIC: brief panic") {
|
||||
t.Fatalf("log did not include panic message: %s", out)
|
||||
}
|
||||
// Stack traces should not appear in non-verbose mode
|
||||
if strings.Contains(out, "Stacktrace:") {
|
||||
t.Fatalf("non-verbose log unexpectedly included stacktrace: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecoveryDoesNotLogSensitiveHeaders verifies that sensitive headers
|
||||
// are no longer logged at all (not even redacted) to prevent CWE-312.
|
||||
func TestRecoveryDoesNotLogSensitiveHeaders(t *testing.T) {
|
||||
old := log.Writer()
|
||||
buf := &bytes.Buffer{}
|
||||
log.SetOutput(buf)
|
||||
defer log.SetOutput(old)
|
||||
|
||||
// Ensure structured logger writes to the same buffer and enable debug
|
||||
logger.Init(true, buf)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(Recovery(true))
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("sensitive panic")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
|
||||
// Add sensitive header that should not appear in logs at all
|
||||
req.Header.Set("Authorization", "Bearer secret-token")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
out := buf.String()
|
||||
// Verify sensitive token is not logged
|
||||
if strings.Contains(out, "secret-token") {
|
||||
t.Fatalf("log contained sensitive token: %s", out)
|
||||
}
|
||||
// Headers are no longer logged at all to prevent potential information leakage
|
||||
if strings.Contains(out, "headers") {
|
||||
t.Fatalf("log should not include headers field: %s", out)
|
||||
}
|
||||
// Verify sanitized panic message is logged
|
||||
if !strings.Contains(out, "PANIC: sensitive panic") {
|
||||
t.Fatalf("log did not include sanitized panic message: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecoveryTruncatesLongPanicMessage verifies that panic messages longer
|
||||
// than 200 characters are truncated with "..." suffix.
|
||||
func TestRecoveryTruncatesLongPanicMessage(t *testing.T) {
|
||||
old := log.Writer()
|
||||
buf := &bytes.Buffer{}
|
||||
log.SetOutput(buf)
|
||||
defer log.SetOutput(old)
|
||||
|
||||
logger.Init(false, buf)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(Recovery(false))
|
||||
|
||||
// Create a panic message longer than 200 characters
|
||||
longMessage := strings.Repeat("x", 250)
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic(longMessage)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
out := buf.String()
|
||||
// Should contain truncated message (200 chars + "...")
|
||||
expectedTruncated := strings.Repeat("x", 200) + "..."
|
||||
if !strings.Contains(out, expectedTruncated) {
|
||||
t.Fatalf("log should contain truncated panic message with '...': %s", out)
|
||||
}
|
||||
// Should NOT contain the full 250 char message
|
||||
if strings.Contains(out, longMessage) {
|
||||
t.Fatalf("log should not contain full long panic message: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecoveryNoPanicNormalFlow verifies that middleware passes through
|
||||
// normally when no panic occurs.
|
||||
func TestRecoveryNoPanicNormalFlow(t *testing.T) {
|
||||
old := log.Writer()
|
||||
buf := &bytes.Buffer{}
|
||||
log.SetOutput(buf)
|
||||
defer log.SetOutput(old)
|
||||
|
||||
logger.Init(false, buf)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(Recovery(true))
|
||||
router.GET("/ok", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ok", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
out := buf.String()
|
||||
// Should NOT contain PANIC in logs
|
||||
if strings.Contains(out, "PANIC") {
|
||||
t.Fatalf("log should not contain PANIC for normal flow: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecoveryPanicWithNilValue tests recovery from panic with a nil-like value.
|
||||
// Note: panic(nil) behavior changed in Go 1.21+ and triggers linter warnings,
|
||||
// so we use an explicit error value instead.
|
||||
func TestRecoveryPanicWithNilValue(t *testing.T) {
|
||||
old := log.Writer()
|
||||
buf := &bytes.Buffer{}
|
||||
log.SetOutput(buf)
|
||||
defer log.SetOutput(old)
|
||||
|
||||
logger.Init(false, buf)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(Recovery(false))
|
||||
router.GET("/panic-nil", func(c *gin.Context) {
|
||||
panic("intentional test panic with nil-like value")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic-nil", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Verify the panic was recovered and returned 500
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "PANIC") {
|
||||
t.Error("expected PANIC in log output")
|
||||
}
|
||||
}
|
||||
40
backend/internal/api/middleware/request_id.go
Normal file
40
backend/internal/api/middleware/request_id.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/trace"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const RequestIDHeader = "X-Request-ID"
|
||||
|
||||
// RequestID generates a uuid per request and places it in context and header.
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
rid := uuid.New().String()
|
||||
c.Set(string(trace.RequestIDKey), rid)
|
||||
c.Writer.Header().Set(RequestIDHeader, rid)
|
||||
// Add to logger fields for this request
|
||||
entry := logger.WithFields(map[string]any{"request_id": rid})
|
||||
c.Set("logger", entry)
|
||||
// Propagate into the request context so it can be used by services
|
||||
ctx := context.WithValue(c.Request.Context(), trace.RequestIDKey, rid)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestLogger retrieves the request-scoped logger from context or the global logger
|
||||
func GetRequestLogger(c *gin.Context) *logrus.Entry {
|
||||
if v, ok := c.Get("logger"); ok {
|
||||
if entry, ok := v.(*logrus.Entry); ok {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
// fallback
|
||||
return logger.Log()
|
||||
}
|
||||
37
backend/internal/api/middleware/request_id_test.go
Normal file
37
backend/internal/api/middleware/request_id_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestRequestIDAddsHeaderAndLogger(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
logger.Init(true, buf)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// Ensure logger exists in context and header is present
|
||||
if _, ok := c.Get("logger"); !ok {
|
||||
t.Fatalf("expected request-scoped logger in context")
|
||||
}
|
||||
c.String(200, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get(RequestIDHeader) == "" {
|
||||
t.Fatalf("expected response to include X-Request-ID header")
|
||||
}
|
||||
}
|
||||
26
backend/internal/api/middleware/request_logger.go
Normal file
26
backend/internal/api/middleware/request_logger.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/util"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestLogger logs basic request information along with the request_id.
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
c.Next()
|
||||
latency := time.Since(start)
|
||||
entry := GetRequestLogger(c)
|
||||
entry.WithFields(map[string]any{
|
||||
"status": c.Writer.Status(),
|
||||
"method": c.Request.Method,
|
||||
"path": SanitizePath(c.Request.URL.Path),
|
||||
"latency": latency.String(),
|
||||
"client": util.SanitizeForLog(c.ClientIP()),
|
||||
}).Info("handled request")
|
||||
}
|
||||
}
|
||||
72
backend/internal/api/middleware/request_logger_test.go
Normal file
72
backend/internal/api/middleware/request_logger_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestRequestLoggerSanitizesPath(t *testing.T) {
|
||||
old := logger.Log()
|
||||
buf := &bytes.Buffer{}
|
||||
logger.Init(true, buf)
|
||||
|
||||
longPath := "/" + strings.Repeat("a", 300)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(RequestLogger())
|
||||
router.GET(longPath, func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, longPath, http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
out := buf.String()
|
||||
if strings.Contains(out, strings.Repeat("a", 300)) {
|
||||
t.Fatalf("logged unsanitized long path")
|
||||
}
|
||||
i := strings.Index(out, "path=")
|
||||
if i == -1 {
|
||||
t.Fatalf("could not find path in logs: %s", out)
|
||||
}
|
||||
sub := out[i:]
|
||||
j := strings.Index(sub, " request_id=")
|
||||
if j == -1 {
|
||||
t.Fatalf("could not isolate path field from logs: %s", out)
|
||||
}
|
||||
pathField := sub[len("path="):j]
|
||||
if strings.Contains(pathField, "\n") || strings.Contains(pathField, "\r") {
|
||||
t.Fatalf("path field contains control characters after sanitization: %s", pathField)
|
||||
}
|
||||
_ = old // silence unused var
|
||||
}
|
||||
|
||||
func TestRequestLoggerIncludesRequestID(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
logger.Init(true, buf)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
router.Use(RequestLogger())
|
||||
router.GET("/ok", func(c *gin.Context) { c.String(200, "ok") })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ok", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status code: %d", w.Code)
|
||||
}
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "request_id") {
|
||||
t.Fatalf("expected log output to include request_id: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "handled request") {
|
||||
t.Fatalf("expected log output to indicate handled request: %s", out)
|
||||
}
|
||||
}
|
||||
62
backend/internal/api/middleware/sanitize.go
Normal file
62
backend/internal/api/middleware/sanitize.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/util"
|
||||
)
|
||||
|
||||
// SanitizeHeaders returns a map of header keys to redacted/sanitized values
|
||||
// for safe logging. Sensitive headers are redacted; other values are
|
||||
// sanitized using util.SanitizeForLog and truncated.
|
||||
func SanitizeHeaders(h http.Header) map[string][]string {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
sensitive := map[string]struct{}{
|
||||
"authorization": {},
|
||||
"cookie": {},
|
||||
"set-cookie": {},
|
||||
"proxy-authorization": {},
|
||||
"x-api-key": {},
|
||||
"x-api-token": {},
|
||||
"x-access-token": {},
|
||||
"x-auth-token": {},
|
||||
"x-api-secret": {},
|
||||
"x-forwarded-for": {},
|
||||
}
|
||||
out := make(map[string][]string, len(h))
|
||||
for k, vals := range h {
|
||||
keyLower := strings.ToLower(k)
|
||||
if _, ok := sensitive[keyLower]; ok {
|
||||
out[k] = []string{"<redacted>"}
|
||||
continue
|
||||
}
|
||||
sanitizedVals := make([]string, 0, len(vals))
|
||||
for _, v := range vals {
|
||||
v2 := util.SanitizeForLog(v)
|
||||
if len(v2) > 200 {
|
||||
v2 = v2[:200]
|
||||
}
|
||||
sanitizedVals = append(sanitizedVals, v2)
|
||||
}
|
||||
out[k] = sanitizedVals
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SanitizePath prepares a request path for safe logging by removing
|
||||
// control characters and truncating long values. It does not include
|
||||
// query parameters.
|
||||
func SanitizePath(p string) string {
|
||||
// remove query string
|
||||
if i := strings.Index(p, "?"); i != -1 {
|
||||
p = p[:i]
|
||||
}
|
||||
p = util.SanitizeForLog(p)
|
||||
if len(p) > 200 {
|
||||
p = p[:200]
|
||||
}
|
||||
return p
|
||||
}
|
||||
55
backend/internal/api/middleware/sanitize_test.go
Normal file
55
backend/internal/api/middleware/sanitize_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSanitizeHeaders(t *testing.T) {
|
||||
t.Run("nil headers", func(t *testing.T) {
|
||||
require.Nil(t, SanitizeHeaders(nil))
|
||||
})
|
||||
|
||||
t.Run("redacts sensitive headers", func(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("Authorization", "secret")
|
||||
headers.Set("X-Api-Key", "token")
|
||||
headers.Set("Cookie", "sessionid=abc")
|
||||
|
||||
sanitized := SanitizeHeaders(headers)
|
||||
|
||||
require.Equal(t, []string{"<redacted>"}, sanitized["Authorization"])
|
||||
require.Equal(t, []string{"<redacted>"}, sanitized["X-Api-Key"])
|
||||
require.Equal(t, []string{"<redacted>"}, sanitized["Cookie"])
|
||||
})
|
||||
|
||||
t.Run("sanitizes and truncates values", func(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Add("X-Trace", "line1\nline2\r\t")
|
||||
headers.Add("X-Custom", strings.Repeat("a", 210))
|
||||
|
||||
sanitized := SanitizeHeaders(headers)
|
||||
|
||||
traceValue := sanitized["X-Trace"][0]
|
||||
require.NotContains(t, traceValue, "\n")
|
||||
require.NotContains(t, traceValue, "\r")
|
||||
require.NotContains(t, traceValue, "\t")
|
||||
|
||||
customValue := sanitized["X-Custom"][0]
|
||||
require.Equal(t, 200, len(customValue))
|
||||
require.True(t, strings.HasPrefix(customValue, strings.Repeat("a", 200)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizePath(t *testing.T) {
|
||||
paddedPath := "/api/v1/resource/" + strings.Repeat("x", 210) + "?token=secret"
|
||||
|
||||
sanitized := SanitizePath(paddedPath)
|
||||
|
||||
require.NotContains(t, sanitized, "?")
|
||||
require.False(t, strings.ContainsAny(sanitized, "\n\r\t"))
|
||||
require.Equal(t, 200, len(sanitized))
|
||||
}
|
||||
130
backend/internal/api/middleware/security.go
Normal file
130
backend/internal/api/middleware/security.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SecurityHeadersConfig holds configuration for the security headers middleware.
|
||||
type SecurityHeadersConfig struct {
|
||||
// IsDevelopment enables less strict settings for local development
|
||||
IsDevelopment bool
|
||||
// CustomCSPDirectives allows adding extra CSP directives
|
||||
CustomCSPDirectives map[string]string
|
||||
}
|
||||
|
||||
// DefaultSecurityHeadersConfig returns a secure default configuration.
|
||||
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
|
||||
return SecurityHeadersConfig{
|
||||
IsDevelopment: false,
|
||||
CustomCSPDirectives: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityHeaders returns middleware that sets security-related HTTP headers.
|
||||
// This implements Phase 1 of the security hardening plan.
|
||||
func SecurityHeaders(cfg SecurityHeadersConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Build Content-Security-Policy
|
||||
csp := buildCSP(cfg)
|
||||
c.Header("Content-Security-Policy", csp)
|
||||
|
||||
// Strict-Transport-Security (HSTS)
|
||||
// max-age=31536000 = 1 year
|
||||
// includeSubDomains ensures all subdomains also use HTTPS
|
||||
// preload allows browser preload lists (requires submission to hstspreload.org)
|
||||
if !cfg.IsDevelopment {
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
|
||||
}
|
||||
|
||||
// X-Frame-Options: Prevent clickjacking
|
||||
// DENY prevents any framing; SAMEORIGIN would allow same-origin framing
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
|
||||
// X-Content-Type-Options: Prevent MIME sniffing
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// X-XSS-Protection: Enable browser XSS filtering (legacy but still useful)
|
||||
// mode=block tells browser to block the response if XSS is detected
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Referrer-Policy: Control referrer information sent with requests
|
||||
// strict-origin-when-cross-origin sends full URL for same-origin, origin only for cross-origin
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Permissions-Policy: Restrict browser features
|
||||
// Disable features that aren't needed for security
|
||||
c.Header("Permissions-Policy", buildPermissionsPolicy())
|
||||
|
||||
// Cross-Origin-Opener-Policy: Isolate browsing context
|
||||
// Skip in development mode to avoid browser warnings on HTTP
|
||||
// In production, Caddy always uses HTTPS, so safe to set unconditionally
|
||||
if !cfg.IsDevelopment {
|
||||
c.Header("Cross-Origin-Opener-Policy", "same-origin")
|
||||
}
|
||||
|
||||
// Cross-Origin-Resource-Policy: Prevent cross-origin reads
|
||||
c.Header("Cross-Origin-Resource-Policy", "same-origin")
|
||||
|
||||
// Cross-Origin-Embedder-Policy: Require CORP for cross-origin resources
|
||||
// Note: This can break some external resources, use with caution
|
||||
// c.Header("Cross-Origin-Embedder-Policy", "require-corp")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// buildCSP constructs the Content-Security-Policy header value.
|
||||
func buildCSP(cfg SecurityHeadersConfig) string {
|
||||
// Base CSP directives for a secure single-page application
|
||||
directives := map[string]string{
|
||||
"default-src": "'self'",
|
||||
"script-src": "'self'",
|
||||
"style-src": "'self' 'unsafe-inline'", // unsafe-inline needed for many CSS-in-JS solutions
|
||||
"img-src": "'self' data: https:", // Allow HTTPS images and data URIs
|
||||
"font-src": "'self' data:", // Allow self-hosted fonts and data URIs
|
||||
"connect-src": "'self'", // API connections
|
||||
"frame-src": "'none'", // No iframes
|
||||
"object-src": "'none'", // No plugins (Flash, etc.)
|
||||
"base-uri": "'self'", // Restrict base tag
|
||||
"form-action": "'self'", // Restrict form submissions
|
||||
}
|
||||
|
||||
// In development, allow more sources for hot reloading, etc.
|
||||
if cfg.IsDevelopment {
|
||||
directives["script-src"] = "'self' 'unsafe-inline' 'unsafe-eval'"
|
||||
directives["connect-src"] = "'self' ws: wss:" // WebSocket for HMR
|
||||
}
|
||||
|
||||
// Apply custom directives
|
||||
for key, value := range cfg.CustomCSPDirectives {
|
||||
directives[key] = value
|
||||
}
|
||||
|
||||
// Build the CSP string
|
||||
var parts []string
|
||||
for directive, value := range directives {
|
||||
parts = append(parts, fmt.Sprintf("%s %s", directive, value))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// buildPermissionsPolicy constructs the Permissions-Policy header value.
|
||||
func buildPermissionsPolicy() string {
|
||||
// Disable features we don't need
|
||||
policies := []string{
|
||||
"accelerometer=()",
|
||||
"camera=()",
|
||||
"geolocation=()",
|
||||
"gyroscope=()",
|
||||
"magnetometer=()",
|
||||
"microphone=()",
|
||||
"payment=()",
|
||||
"usb=()",
|
||||
}
|
||||
|
||||
return strings.Join(policies, ", ")
|
||||
}
|
||||
223
backend/internal/api/middleware/security_test.go
Normal file
223
backend/internal/api/middleware/security_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
isDevelopment bool
|
||||
checkHeaders func(t *testing.T, resp *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "production mode sets HSTS",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
hsts := resp.Header().Get("Strict-Transport-Security")
|
||||
assert.Contains(t, hsts, "max-age=31536000")
|
||||
assert.Contains(t, hsts, "includeSubDomains")
|
||||
assert.Contains(t, hsts, "preload")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "development mode skips HSTS",
|
||||
isDevelopment: true,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
hsts := resp.Header().Get("Strict-Transport-Security")
|
||||
assert.Empty(t, hsts)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets X-Frame-Options",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, "DENY", resp.Header().Get("X-Frame-Options"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets X-Content-Type-Options",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, "nosniff", resp.Header().Get("X-Content-Type-Options"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets X-XSS-Protection",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, "1; mode=block", resp.Header().Get("X-XSS-Protection"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets Referrer-Policy",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", resp.Header().Get("Referrer-Policy"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets Content-Security-Policy",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
csp := resp.Header().Get("Content-Security-Policy")
|
||||
assert.NotEmpty(t, csp)
|
||||
assert.Contains(t, csp, "default-src")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "development mode CSP allows unsafe-eval",
|
||||
isDevelopment: true,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
csp := resp.Header().Get("Content-Security-Policy")
|
||||
assert.Contains(t, csp, "unsafe-eval")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets Permissions-Policy",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
pp := resp.Header().Get("Permissions-Policy")
|
||||
assert.NotEmpty(t, pp)
|
||||
assert.Contains(t, pp, "camera=()")
|
||||
assert.Contains(t, pp, "microphone=()")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets Cross-Origin-Opener-Policy in production",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, "same-origin", resp.Header().Get("Cross-Origin-Opener-Policy"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skips Cross-Origin-Opener-Policy in development",
|
||||
isDevelopment: true,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Empty(t, resp.Header().Get("Cross-Origin-Opener-Policy"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sets Cross-Origin-Resource-Policy",
|
||||
isDevelopment: false,
|
||||
checkHeaders: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, "same-origin", resp.Header().Get("Cross-Origin-Resource-Policy"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(SecurityHeaders(SecurityHeadersConfig{
|
||||
IsDevelopment: tt.isDevelopment,
|
||||
}))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
tt.checkHeaders(t, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersCustomCSP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(SecurityHeaders(SecurityHeadersConfig{
|
||||
IsDevelopment: false,
|
||||
CustomCSPDirectives: map[string]string{
|
||||
"frame-src": "'self' https://trusted.com",
|
||||
},
|
||||
}))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", http.NoBody)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
csp := resp.Header().Get("Content-Security-Policy")
|
||||
assert.Contains(t, csp, "frame-src 'self' https://trusted.com")
|
||||
}
|
||||
|
||||
func TestDefaultSecurityHeadersConfig(t *testing.T) {
|
||||
cfg := DefaultSecurityHeadersConfig()
|
||||
assert.False(t, cfg.IsDevelopment)
|
||||
assert.Nil(t, cfg.CustomCSPDirectives)
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_COOP_DevelopmentMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
cfg := SecurityHeadersConfig{IsDevelopment: true}
|
||||
router.Use(SecurityHeaders(cfg))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
assert.Empty(t, resp.Header().Get("Cross-Origin-Opener-Policy"),
|
||||
"COOP header should not be set in development mode")
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_COOP_ProductionMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
cfg := SecurityHeadersConfig{IsDevelopment: false}
|
||||
router.Use(SecurityHeaders(cfg))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
assert.Equal(t, "same-origin", resp.Header().Get("Cross-Origin-Opener-Policy"),
|
||||
"COOP header must be set in production mode")
|
||||
}
|
||||
|
||||
func TestBuildCSP(t *testing.T) {
|
||||
t.Run("production CSP", func(t *testing.T) {
|
||||
csp := buildCSP(SecurityHeadersConfig{IsDevelopment: false})
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
assert.Contains(t, csp, "script-src 'self'")
|
||||
assert.NotContains(t, csp, "unsafe-eval")
|
||||
})
|
||||
|
||||
t.Run("development CSP", func(t *testing.T) {
|
||||
csp := buildCSP(SecurityHeadersConfig{IsDevelopment: true})
|
||||
assert.Contains(t, csp, "unsafe-eval")
|
||||
assert.Contains(t, csp, "ws:")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildPermissionsPolicy(t *testing.T) {
|
||||
pp := buildPermissionsPolicy()
|
||||
|
||||
// Check that dangerous features are disabled
|
||||
disabledFeatures := []string{"camera", "microphone", "geolocation", "payment"}
|
||||
for _, feature := range disabledFeatures {
|
||||
assert.True(t, strings.Contains(pp, feature+"=()"),
|
||||
"Expected %s to be disabled in permissions policy", feature)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user