diff --git a/backend/internal/api/middleware/auth.go b/backend/internal/api/middleware/auth.go index fc9ab274..90b7a3e5 100644 --- a/backend/internal/api/middleware/auth.go +++ b/backend/internal/api/middleware/auth.go @@ -49,13 +49,15 @@ func AuthMiddleware(authService *services.AuthService) gin.HandlerFunc { } func extractAuthToken(c *gin.Context) (string, bool) { - authHeader := c.GetHeader("Authorization") + authHeader := "" + + // Try cookie first for browser flows (including WebSocket upgrades) + if cookieToken := extractAuthCookieToken(c); cookieToken != "" { + authHeader = "Bearer " + cookieToken + } if authHeader == "" { - // Try cookie first for browser flows (including WebSocket upgrades) - if cookie, err := c.Cookie("auth_token"); err == nil && cookie != "" { - authHeader = "Bearer " + cookie - } + authHeader = c.GetHeader("Authorization") } // DEPRECATED: Query parameter authentication for WebSocket connections @@ -80,6 +82,27 @@ func extractAuthToken(c *gin.Context) (string, bool) { return tokenString, true } +func extractAuthCookieToken(c *gin.Context) string { + if c.Request == nil { + return "" + } + + token := "" + for _, cookie := range c.Request.Cookies() { + if cookie.Name != "auth_token" { + continue + } + + if cookie.Value == "" { + continue + } + + token = cookie.Value + } + + return token +} + func RequireRole(role string) gin.HandlerFunc { return func(c *gin.Context) { userRole, exists := c.Get("role") diff --git a/backend/internal/api/middleware/auth_test.go b/backend/internal/api/middleware/auth_test.go index bb810dc7..a39feae5 100644 --- a/backend/internal/api/middleware/auth_test.go +++ b/backend/internal/api/middleware/auth_test.go @@ -155,10 +155,37 @@ func TestAuthMiddleware_ValidToken(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) } -func TestAuthMiddleware_PrefersAuthorizationHeader(t *testing.T) { +func TestAuthMiddleware_PrefersCookieOverAuthorizationHeader(t *testing.T) { authService := setupAuthService(t) - user, _ := authService.Register("header@example.com", "password", "Header User") - token, _ := authService.GenerateToken(user) + cookieUser, _ := authService.Register("cookie-header@example.com", "password", "Cookie Header User") + cookieToken, _ := authService.GenerateToken(cookieUser) + headerUser, _ := authService.Register("header@example.com", "password", "Header User") + headerToken, _ := authService.GenerateToken(headerUser) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(AuthMiddleware(authService)) + r.GET("/test", func(c *gin.Context) { + userID, _ := c.Get("userID") + assert.Equal(t, cookieUser.ID, userID) + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest("GET", "/test", http.NoBody) + req.Header.Set("Authorization", "Bearer "+headerToken) + req.AddCookie(&http.Cookie{Name: "auth_token", Value: cookieToken}) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthMiddleware_UsesCookieWhenAuthorizationHeaderIsInvalid(t *testing.T) { + authService := setupAuthService(t) + user, err := authService.Register("cookie-valid@example.com", "password", "Cookie Valid User") + require.NoError(t, err) + token, err := authService.GenerateToken(user) + require.NoError(t, err) gin.SetMode(gin.TestMode) r := gin.New() @@ -169,9 +196,36 @@ func TestAuthMiddleware_PrefersAuthorizationHeader(t *testing.T) { c.Status(http.StatusOK) }) - req, _ := http.NewRequest("GET", "/test", http.NoBody) - req.Header.Set("Authorization", "Bearer "+token) - req.AddCookie(&http.Cookie{Name: "auth_token", Value: "stale"}) + req, err := http.NewRequest("GET", "/test", http.NoBody) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer invalid-token") + req.AddCookie(&http.Cookie{Name: "auth_token", Value: token}) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthMiddleware_UsesLastNonEmptyCookieWhenDuplicateCookiesExist(t *testing.T) { + authService := setupAuthService(t) + user, err := authService.Register("dupecookie@example.com", "password", "Dup Cookie User") + require.NoError(t, err) + token, err := authService.GenerateToken(user) + require.NoError(t, err) + + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(AuthMiddleware(authService)) + r.GET("/test", func(c *gin.Context) { + userID, _ := c.Get("userID") + assert.Equal(t, user.ID, userID) + c.Status(http.StatusOK) + }) + + req, err := http.NewRequest("GET", "/test", http.NoBody) + require.NoError(t, err) + req.AddCookie(&http.Cookie{Name: "auth_token", Value: ""}) + req.AddCookie(&http.Cookie{Name: "auth_token", Value: token}) w := httptest.NewRecorder() r.ServeHTTP(w, req)