fix: improve error handling and session management in various handlers and middleware
This commit is contained in:
@@ -58,7 +58,13 @@ func (h *AccessListHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, acl)
|
||||
createdACL, err := h.service.GetByUUID(acl.UUID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, createdACL)
|
||||
}
|
||||
|
||||
// List handles GET /api/v1/access-lists
|
||||
@@ -100,12 +106,14 @@ func (h *AccessListHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
var updates models.AccessList
|
||||
if err := c.ShouldBindJSON(&updates); err != nil {
|
||||
err = c.ShouldBindJSON(&updates)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Update(acl.ID, &updates); err != nil {
|
||||
err = h.service.Update(acl.ID, &updates)
|
||||
if err != nil {
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
@@ -114,8 +122,16 @@ func (h *AccessListHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch updated record
|
||||
updatedAcl, _ := h.service.GetByID(acl.ID)
|
||||
updatedAcl, err := h.service.GetByID(acl.ID)
|
||||
if err != nil {
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, updatedAcl)
|
||||
}
|
||||
|
||||
|
||||
@@ -323,17 +323,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
func (h *AuthHandler) Verify(c *gin.Context) {
|
||||
// Extract token from cookie or Authorization header
|
||||
var tokenString string
|
||||
|
||||
// Try cookie first (most common for browser requests)
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil && cookie != "" {
|
||||
tokenString = cookie
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
|
||||
// Fall back to Authorization header
|
||||
// Fall back to cookie (most common for browser requests)
|
||||
if tokenString == "" {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil && cookie != "" {
|
||||
tokenString = cookie
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,15 +391,14 @@ func (h *AuthHandler) Verify(c *gin.Context) {
|
||||
func (h *AuthHandler) VerifyStatus(c *gin.Context) {
|
||||
// Extract token
|
||||
var tokenString string
|
||||
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil && cookie != "" {
|
||||
tokenString = cookie
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
|
||||
if tokenString == "" {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil && cookie != "" {
|
||||
tokenString = cookie
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -101,8 +101,18 @@ func (h *SecurityHandler) GetStatus(c *gin.Context) {
|
||||
var setting struct{ Value string }
|
||||
|
||||
// Cerberus enabled override
|
||||
cerberusOverrideApplied := false
|
||||
if err := h.db.Raw("SELECT value FROM settings WHERE key = ? LIMIT 1", "feature.cerberus.enabled").Scan(&setting).Error; err == nil && setting.Value != "" {
|
||||
enabled = strings.EqualFold(setting.Value, "true")
|
||||
cerberusOverrideApplied = true
|
||||
}
|
||||
|
||||
// Backward-compatible Cerberus enabled override
|
||||
if !cerberusOverrideApplied {
|
||||
setting = struct{ Value string }{}
|
||||
if err := h.db.Raw("SELECT value FROM settings WHERE key = ? LIMIT 1", "security.cerberus.enabled").Scan(&setting).Error; err == nil && setting.Value != "" {
|
||||
enabled = strings.EqualFold(setting.Value, "true")
|
||||
}
|
||||
}
|
||||
|
||||
// WAF enabled override
|
||||
@@ -1147,6 +1157,20 @@ func (h *SecurityHandler) toggleSecurityModule(c *gin.Context, settingKey string
|
||||
return
|
||||
}
|
||||
|
||||
if settingKey == "feature.cerberus.enabled" {
|
||||
legacyCerberus := models.Setting{
|
||||
Key: "security.cerberus.enabled",
|
||||
Value: value,
|
||||
Category: "security",
|
||||
Type: "bool",
|
||||
}
|
||||
if err := h.db.Where(models.Setting{Key: legacyCerberus.Key}).Assign(legacyCerberus).FirstOrCreate(&legacyCerberus).Error; err != nil {
|
||||
log.WithError(err).Error("Failed to sync legacy Cerberus setting")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update security module"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if settingKey == "security.acl.enabled" && enabled {
|
||||
var count int64
|
||||
if err := h.db.Model(&models.SecurityConfig{}).Count(&count).Error; err != nil {
|
||||
@@ -1206,8 +1230,8 @@ func (h *SecurityHandler) toggleSecurityModule(c *gin.Context, settingKey string
|
||||
}
|
||||
|
||||
type settingSnapshot struct {
|
||||
exists bool
|
||||
setting models.Setting
|
||||
exists bool
|
||||
setting models.Setting
|
||||
}
|
||||
|
||||
func (h *SecurityHandler) snapshotSettings(keys []string) (map[string]settingSnapshot, error) {
|
||||
|
||||
@@ -43,15 +43,13 @@ func AuthMiddleware(authService *services.AuthService) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
func extractAuthToken(c *gin.Context) (string, bool) {
|
||||
authHeader := ""
|
||||
|
||||
// Try cookie first for browser flows (including WebSocket upgrades)
|
||||
if cookieToken := extractAuthCookieToken(c); cookieToken != "" {
|
||||
authHeader = "Bearer " + cookieToken
|
||||
}
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
|
||||
// Fall back to cookie for browser flows (including WebSocket upgrades)
|
||||
if authHeader == "" {
|
||||
authHeader = c.GetHeader("Authorization")
|
||||
if cookieToken := extractAuthCookieToken(c); cookieToken != "" {
|
||||
authHeader = "Bearer " + cookieToken
|
||||
}
|
||||
}
|
||||
|
||||
// DEPRECATED: Query parameter authentication for WebSocket connections
|
||||
|
||||
@@ -55,7 +55,11 @@ client.interceptors.response.use(
|
||||
console.warn('Authentication failed:', error.config?.url);
|
||||
// Skip auth error handling for login/auth endpoints to avoid redirect loops
|
||||
const url = error.config?.url || '';
|
||||
const isAuthEndpoint = url.includes('/auth/login') || url.includes('/auth/me');
|
||||
const isAuthEndpoint =
|
||||
url.includes('/auth/login') ||
|
||||
url.includes('/auth/me') ||
|
||||
url.includes('/auth/logout') ||
|
||||
url.includes('/auth/refresh');
|
||||
if (onAuthError && !isAuthEndpoint) {
|
||||
onAuthError();
|
||||
}
|
||||
|
||||
@@ -1,22 +1,41 @@
|
||||
import { useState, useEffect, useCallback, type ReactNode, type FC } from 'react';
|
||||
import { useState, useEffect, useCallback, useRef, type ReactNode, type FC } from 'react';
|
||||
import client, { setAuthToken, setAuthErrorHandler } from '../api/client';
|
||||
import { AuthContext, User } from './AuthContextValue';
|
||||
|
||||
export const AuthProvider: FC<{ children: ReactNode }> = ({ children }) => {
|
||||
const [user, setUser] = useState<User | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const authRequestVersionRef = useRef(0);
|
||||
|
||||
const fetchSessionUser = useCallback(async (): Promise<User> => {
|
||||
const response = await fetch('/api/v1/auth/me', {
|
||||
method: 'GET',
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Session validation failed');
|
||||
}
|
||||
|
||||
return response.json() as Promise<User>;
|
||||
}, []);
|
||||
|
||||
const invalidateAuthRequests = useCallback(() => {
|
||||
authRequestVersionRef.current += 1;
|
||||
}, []);
|
||||
|
||||
// Handle session expiry by clearing auth state and redirecting to login
|
||||
const handleAuthError = useCallback(() => {
|
||||
console.log('Session expired, redirecting to login');
|
||||
console.warn('Session expired, clearing auth state');
|
||||
invalidateAuthRequests();
|
||||
localStorage.removeItem('charon_auth_token');
|
||||
setAuthToken(null);
|
||||
setUser(null);
|
||||
// Use window.location for full page redirect to clear any stale state
|
||||
if (window.location.pathname !== '/login') {
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}, []);
|
||||
setIsLoading(false);
|
||||
}, [invalidateAuthRequests]);
|
||||
|
||||
// Register auth error handler on mount
|
||||
useEffect(() => {
|
||||
@@ -25,6 +44,9 @@ export const AuthProvider: FC<{ children: ReactNode }> = ({ children }) => {
|
||||
|
||||
useEffect(() => {
|
||||
const checkAuth = async () => {
|
||||
const requestVersion = authRequestVersionRef.current + 1;
|
||||
authRequestVersionRef.current = requestVersion;
|
||||
|
||||
try {
|
||||
const stored = localStorage.getItem('charon_auth_token');
|
||||
if (stored) {
|
||||
@@ -33,54 +55,72 @@ export const AuthProvider: FC<{ children: ReactNode }> = ({ children }) => {
|
||||
// No token in localStorage - don't even try to authenticate
|
||||
// This prevents re-authentication via HttpOnly cookie after logout
|
||||
setAuthToken(null);
|
||||
setUser(null);
|
||||
setIsLoading(false);
|
||||
if (authRequestVersionRef.current === requestVersion) {
|
||||
setUser(null);
|
||||
setIsLoading(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const response = await client.get('/auth/me');
|
||||
setUser(response.data);
|
||||
const response = await fetchSessionUser();
|
||||
if (authRequestVersionRef.current === requestVersion) {
|
||||
setUser(response);
|
||||
}
|
||||
} catch {
|
||||
setAuthToken(null);
|
||||
setUser(null);
|
||||
if (authRequestVersionRef.current === requestVersion) {
|
||||
setAuthToken(null);
|
||||
setUser(null);
|
||||
}
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
if (authRequestVersionRef.current === requestVersion) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
checkAuth();
|
||||
}, []);
|
||||
}, [fetchSessionUser]);
|
||||
|
||||
const login = useCallback(async (token?: string) => {
|
||||
const requestVersion = authRequestVersionRef.current + 1;
|
||||
authRequestVersionRef.current = requestVersion;
|
||||
setIsLoading(true);
|
||||
|
||||
const login = async (token?: string) => {
|
||||
if (token) {
|
||||
localStorage.setItem('charon_auth_token', token);
|
||||
setAuthToken(token);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await client.get<User>('/auth/me');
|
||||
setUser(response.data);
|
||||
const response = await fetchSessionUser();
|
||||
if (authRequestVersionRef.current === requestVersion) {
|
||||
setUser(response);
|
||||
}
|
||||
} catch (error) {
|
||||
setUser(null);
|
||||
setAuthToken(null);
|
||||
localStorage.removeItem('charon_auth_token');
|
||||
if (authRequestVersionRef.current === requestVersion) {
|
||||
setUser(null);
|
||||
setAuthToken(null);
|
||||
localStorage.removeItem('charon_auth_token');
|
||||
}
|
||||
throw error;
|
||||
} finally {
|
||||
if (authRequestVersionRef.current === requestVersion) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
}, [fetchSessionUser]);
|
||||
|
||||
const logout = async () => {
|
||||
invalidateAuthRequests();
|
||||
localStorage.removeItem('charon_auth_token');
|
||||
setAuthToken(null);
|
||||
setUser(null);
|
||||
setIsLoading(false);
|
||||
|
||||
try {
|
||||
await client.post('/auth/logout');
|
||||
} catch (error) {
|
||||
console.error("Logout failed", error);
|
||||
}
|
||||
localStorage.removeItem('charon_auth_token');
|
||||
setAuthToken(null);
|
||||
setUser(null);
|
||||
|
||||
// Force navigation to login with full page reload to clear any stale state
|
||||
// This ensures all React state and cookies are cleared
|
||||
if (window.location.pathname !== '/login') {
|
||||
window.location.href = '/login';
|
||||
}
|
||||
};
|
||||
|
||||
const changePassword = async (oldPassword: string, newPassword: string) => {
|
||||
|
||||
Reference in New Issue
Block a user