diff --git a/backend/internal/api/handlers/cerberus_logs_ws.go b/backend/internal/api/handlers/cerberus_logs_ws.go index 62a2df1b..2157ede8 100644 --- a/backend/internal/api/handlers/cerberus_logs_ws.go +++ b/backend/internal/api/handlers/cerberus_logs_ws.go @@ -16,11 +16,15 @@ import ( // CerberusLogsHandler handles WebSocket connections for streaming security logs. type CerberusLogsHandler struct { watcher *services.LogWatcher + tracker *services.WebSocketTracker } // NewCerberusLogsHandler creates a new handler for Cerberus security log streaming. -func NewCerberusLogsHandler(watcher *services.LogWatcher) *CerberusLogsHandler { - return &CerberusLogsHandler{watcher: watcher} +func NewCerberusLogsHandler(watcher *services.LogWatcher, tracker *services.WebSocketTracker) *CerberusLogsHandler { + return &CerberusLogsHandler{ + watcher: watcher, + tracker: tracker, + } } // LiveLogs handles WebSocket connections for Cerberus security log streaming. @@ -52,6 +56,22 @@ func (h *CerberusLogsHandler) LiveLogs(c *gin.Context) { subscriberID := uuid.New().String() logger.Log().WithField("subscriber_id", subscriberID).Info("Cerberus logs WebSocket connected") + // Register connection with tracker if available + if h.tracker != nil { + filters := c.Request.URL.RawQuery + connInfo := &services.ConnectionInfo{ + ID: subscriberID, + Type: "cerberus", + ConnectedAt: time.Now(), + LastActivityAt: time.Now(), + RemoteAddr: c.Request.RemoteAddr, + UserAgent: c.Request.UserAgent(), + Filters: filters, + } + h.tracker.Register(connInfo) + defer h.tracker.Unregister(subscriberID) + } + // Parse query filters sourceFilter := strings.ToLower(c.Query("source")) // waf, crowdsec, ratelimit, acl, normal levelFilter := strings.ToLower(c.Query("level")) // info, warn, error @@ -117,6 +137,11 @@ func (h *CerberusLogsHandler) LiveLogs(c *gin.Context) { return } + // Update activity timestamp + if h.tracker != nil { + h.tracker.UpdateActivity(subscriberID) + } + case <-ticker.C: // Send ping to keep connection alive if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { diff --git a/backend/internal/api/handlers/cerberus_logs_ws_test.go b/backend/internal/api/handlers/cerberus_logs_ws_test.go index 281e732d..6ab8229f 100644 --- a/backend/internal/api/handlers/cerberus_logs_ws_test.go +++ b/backend/internal/api/handlers/cerberus_logs_ws_test.go @@ -29,10 +29,12 @@ func TestCerberusLogsHandler_NewHandler(t *testing.T) { t.Parallel() watcher := services.NewLogWatcher("/tmp/test.log") - handler := NewCerberusLogsHandler(watcher) + tracker := services.NewWebSocketTracker() + handler := NewCerberusLogsHandler(watcher, tracker) assert.NotNil(t, handler) assert.Equal(t, watcher, handler.watcher) + assert.Equal(t, tracker, handler.tracker) } // TestCerberusLogsHandler_SuccessfulConnection verifies WebSocket upgrade. @@ -51,7 +53,7 @@ func TestCerberusLogsHandler_SuccessfulConnection(t *testing.T) { require.NoError(t, err) defer watcher.Stop() - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) // Create test server router := gin.New() @@ -88,7 +90,7 @@ func TestCerberusLogsHandler_ReceiveLogEntries(t *testing.T) { require.NoError(t, err) defer watcher.Stop() - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) // Create test server router := gin.New() @@ -157,7 +159,7 @@ func TestCerberusLogsHandler_SourceFilter(t *testing.T) { require.NoError(t, err) defer watcher.Stop() - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) router := gin.New() router.GET("/ws", handler.LiveLogs) @@ -236,7 +238,7 @@ func TestCerberusLogsHandler_BlockedOnlyFilter(t *testing.T) { require.NoError(t, err) defer watcher.Stop() - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) router := gin.New() router.GET("/ws", handler.LiveLogs) @@ -313,7 +315,7 @@ func TestCerberusLogsHandler_IPFilter(t *testing.T) { require.NoError(t, err) defer watcher.Stop() - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) router := gin.New() router.GET("/ws", handler.LiveLogs) @@ -388,7 +390,7 @@ func TestCerberusLogsHandler_ClientDisconnect(t *testing.T) { require.NoError(t, err) defer watcher.Stop() - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) router := gin.New() router.GET("/ws", handler.LiveLogs) @@ -424,7 +426,7 @@ func TestCerberusLogsHandler_MultipleClients(t *testing.T) { require.NoError(t, err) defer watcher.Stop() - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) router := gin.New() router.GET("/ws", handler.LiveLogs) @@ -486,7 +488,7 @@ func TestCerberusLogsHandler_UpgradeFailure(t *testing.T) { t.Parallel() watcher := services.NewLogWatcher("/tmp/test.log") - handler := NewCerberusLogsHandler(watcher) + handler := NewCerberusLogsHandler(watcher, nil) router := gin.New() router.GET("/ws", handler.LiveLogs) diff --git a/backend/internal/api/handlers/logs_ws.go b/backend/internal/api/handlers/logs_ws.go index 47608f5d..ecb880db 100644 --- a/backend/internal/api/handlers/logs_ws.go +++ b/backend/internal/api/handlers/logs_ws.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/websocket" "github.com/Wikid82/charon/backend/internal/logger" + "github.com/Wikid82/charon/backend/internal/services" ) var upgrader = websocket.Upgrader{ @@ -31,8 +32,26 @@ type LogEntry struct { Fields map[string]interface{} `json:"fields"` } +// LogsWSHandler handles WebSocket connections for live log streaming. +type LogsWSHandler struct { + tracker *services.WebSocketTracker +} + +// NewLogsWSHandler creates a new handler for log streaming. +func NewLogsWSHandler(tracker *services.WebSocketTracker) *LogsWSHandler { + return &LogsWSHandler{tracker: tracker} +} + // LogsWebSocketHandler handles WebSocket connections for live log streaming. +// DEPRECATED: Use NewLogsWSHandler().HandleWebSocket instead. Kept for backward compatibility. func LogsWebSocketHandler(c *gin.Context) { + // For backward compatibility, create a nil tracker if called directly + handler := NewLogsWSHandler(nil) + handler.HandleWebSocket(c) +} + +// HandleWebSocket handles WebSocket connections for live log streaming. +func (h *LogsWSHandler) HandleWebSocket(c *gin.Context) { logger.Log().Info("WebSocket connection attempt received") // Upgrade HTTP connection to WebSocket @@ -52,6 +71,22 @@ func LogsWebSocketHandler(c *gin.Context) { logger.Log().WithField("subscriber_id", subscriberID).Info("WebSocket connection established successfully") + // Register connection with tracker if available + if h.tracker != nil { + filters := c.Request.URL.RawQuery + connInfo := &services.ConnectionInfo{ + ID: subscriberID, + Type: "logs", + ConnectedAt: time.Now(), + LastActivityAt: time.Now(), + RemoteAddr: c.Request.RemoteAddr, + UserAgent: c.Request.UserAgent(), + Filters: filters, + } + h.tracker.Register(connInfo) + defer h.tracker.Unregister(subscriberID) + } + // Parse query parameters for filtering levelFilter := strings.ToLower(c.Query("level")) sourceFilter := strings.ToLower(c.Query("source")) @@ -115,6 +150,11 @@ func LogsWebSocketHandler(c *gin.Context) { return } + // Update activity timestamp + if h.tracker != nil { + h.tracker.UpdateActivity(subscriberID) + } + case <-ticker.C: // Send ping to keep connection alive if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { diff --git a/backend/internal/api/handlers/websocket_status_handler.go b/backend/internal/api/handlers/websocket_status_handler.go new file mode 100644 index 00000000..dfa7d9d0 --- /dev/null +++ b/backend/internal/api/handlers/websocket_status_handler.go @@ -0,0 +1,34 @@ +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/Wikid82/charon/backend/internal/services" +) + +// WebSocketStatusHandler provides endpoints for WebSocket connection monitoring. +type WebSocketStatusHandler struct { + tracker *services.WebSocketTracker +} + +// NewWebSocketStatusHandler creates a new handler for WebSocket status monitoring. +func NewWebSocketStatusHandler(tracker *services.WebSocketTracker) *WebSocketStatusHandler { + return &WebSocketStatusHandler{tracker: tracker} +} + +// GetConnections returns a list of all active WebSocket connections. +func (h *WebSocketStatusHandler) GetConnections(c *gin.Context) { + connections := h.tracker.GetAllConnections() + c.JSON(http.StatusOK, gin.H{ + "connections": connections, + "count": len(connections), + }) +} + +// GetStats returns aggregate statistics about WebSocket connections. +func (h *WebSocketStatusHandler) GetStats(c *gin.Context) { + stats := h.tracker.GetStats() + c.JSON(http.StatusOK, stats) +} diff --git a/backend/internal/api/handlers/websocket_status_handler_test.go b/backend/internal/api/handlers/websocket_status_handler_test.go new file mode 100644 index 00000000..b0fa8abc --- /dev/null +++ b/backend/internal/api/handlers/websocket_status_handler_test.go @@ -0,0 +1,169 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/Wikid82/charon/backend/internal/services" +) + +func TestWebSocketStatusHandler_GetConnections(t *testing.T) { + gin.SetMode(gin.TestMode) + + tracker := services.NewWebSocketTracker() + handler := NewWebSocketStatusHandler(tracker) + + // Register test connections + conn1 := &services.ConnectionInfo{ + ID: "conn-1", + Type: "logs", + ConnectedAt: time.Now(), + LastActivityAt: time.Now(), + RemoteAddr: "192.168.1.1:12345", + UserAgent: "Mozilla/5.0", + Filters: "level=error", + } + conn2 := &services.ConnectionInfo{ + ID: "conn-2", + Type: "cerberus", + ConnectedAt: time.Now(), + LastActivityAt: time.Now(), + RemoteAddr: "192.168.1.2:54321", + UserAgent: "Chrome/90.0", + Filters: "source=waf", + } + + tracker.Register(conn1) + tracker.Register(conn2) + + // Create test request + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/connections", nil) + + // Call handler + handler.GetConnections(c) + + // Verify response + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, float64(2), response["count"]) + connections, ok := response["connections"].([]interface{}) + require.True(t, ok) + assert.Len(t, connections, 2) +} + +func TestWebSocketStatusHandler_GetConnectionsEmpty(t *testing.T) { + gin.SetMode(gin.TestMode) + + tracker := services.NewWebSocketTracker() + handler := NewWebSocketStatusHandler(tracker) + + // Create test request + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/connections", nil) + + // Call handler + handler.GetConnections(c) + + // Verify response + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, float64(0), response["count"]) + connections, ok := response["connections"].([]interface{}) + require.True(t, ok) + assert.Len(t, connections, 0) +} + +func TestWebSocketStatusHandler_GetStats(t *testing.T) { + gin.SetMode(gin.TestMode) + + tracker := services.NewWebSocketTracker() + handler := NewWebSocketStatusHandler(tracker) + + // Register test connections + conn1 := &services.ConnectionInfo{ + ID: "conn-1", + Type: "logs", + ConnectedAt: time.Now(), + } + conn2 := &services.ConnectionInfo{ + ID: "conn-2", + Type: "logs", + ConnectedAt: time.Now(), + } + conn3 := &services.ConnectionInfo{ + ID: "conn-3", + Type: "cerberus", + ConnectedAt: time.Now(), + } + + tracker.Register(conn1) + tracker.Register(conn2) + tracker.Register(conn3) + + // Create test request + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/stats", nil) + + // Call handler + handler.GetStats(c) + + // Verify response + assert.Equal(t, http.StatusOK, w.Code) + + var stats services.ConnectionStats + err := json.Unmarshal(w.Body.Bytes(), &stats) + require.NoError(t, err) + + assert.Equal(t, 3, stats.TotalActive) + assert.Equal(t, 2, stats.LogsConnections) + assert.Equal(t, 1, stats.CerberusConnections) + assert.NotNil(t, stats.OldestConnection) + assert.False(t, stats.LastUpdated.IsZero()) +} + +func TestWebSocketStatusHandler_GetStatsEmpty(t *testing.T) { + gin.SetMode(gin.TestMode) + + tracker := services.NewWebSocketTracker() + handler := NewWebSocketStatusHandler(tracker) + + // Create test request + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/stats", nil) + + // Call handler + handler.GetStats(c) + + // Verify response + assert.Equal(t, http.StatusOK, w.Code) + + var stats services.ConnectionStats + err := json.Unmarshal(w.Body.Bytes(), &stats) + require.NoError(t, err) + + assert.Equal(t, 0, stats.TotalActive) + assert.Equal(t, 0, stats.LogsConnections) + assert.Equal(t, 0, stats.CerberusConnections) + assert.Nil(t, stats.OldestConnection) + assert.False(t, stats.LastUpdated.IsZero()) +} diff --git a/backend/internal/api/routes/routes.go b/backend/internal/api/routes/routes.go index 12374854..d2a493af 100644 --- a/backend/internal/api/routes/routes.go +++ b/backend/internal/api/routes/routes.go @@ -119,6 +119,10 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error { logService := services.NewLogService(&cfg) logsHandler := handlers.NewLogsHandler(logService) + // WebSocket tracker for connection monitoring + wsTracker := services.NewWebSocketTracker() + wsStatusHandler := handlers.NewWebSocketStatusHandler(wsTracker) + // Notification Service (needed for multiple handlers) notificationService := services.NewNotificationService(db) @@ -160,7 +164,14 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error { protected.GET("/logs", logsHandler.List) protected.GET("/logs/:filename", logsHandler.Read) protected.GET("/logs/:filename/download", logsHandler.Download) - protected.GET("/logs/live", handlers.LogsWebSocketHandler) + + // WebSocket endpoints + logsWSHandler := handlers.NewLogsWSHandler(wsTracker) + protected.GET("/logs/live", logsWSHandler.HandleWebSocket) + + // WebSocket status monitoring + protected.GET("/websocket/connections", wsStatusHandler.GetConnections) + protected.GET("/websocket/stats", wsStatusHandler.GetStats) // Security Notification Settings securityNotificationService := services.NewSecurityNotificationService(db) @@ -395,7 +406,7 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error { if err := logWatcher.Start(context.Background()); err != nil { logger.Log().WithError(err).Error("Failed to start security log watcher") } - cerberusLogsHandler := handlers.NewCerberusLogsHandler(logWatcher) + cerberusLogsHandler := handlers.NewCerberusLogsHandler(logWatcher, wsTracker) protected.GET("/cerberus/logs/ws", cerberusLogsHandler.LiveLogs) // Access Lists diff --git a/backend/internal/services/websocket_tracker.go b/backend/internal/services/websocket_tracker.go new file mode 100644 index 00000000..1a2ed7ad --- /dev/null +++ b/backend/internal/services/websocket_tracker.go @@ -0,0 +1,140 @@ +// Package services provides business logic services for the application. +package services + +import ( + "sync" + "time" + + "github.com/Wikid82/charon/backend/internal/logger" +) + +// ConnectionInfo tracks information about a single WebSocket connection. +type ConnectionInfo struct { + ID string `json:"id"` + Type string `json:"type"` // "logs" or "cerberus" + ConnectedAt time.Time `json:"connected_at"` + LastActivityAt time.Time `json:"last_activity_at"` + RemoteAddr string `json:"remote_addr,omitempty"` + UserAgent string `json:"user_agent,omitempty"` + Filters string `json:"filters,omitempty"` // Query parameters used for filtering +} + +// ConnectionStats provides aggregate statistics about WebSocket connections. +type ConnectionStats struct { + TotalActive int `json:"total_active"` + LogsConnections int `json:"logs_connections"` + CerberusConnections int `json:"cerberus_connections"` + OldestConnection *time.Time `json:"oldest_connection,omitempty"` + LastUpdated time.Time `json:"last_updated"` +} + +// WebSocketTracker tracks active WebSocket connections and provides statistics. +type WebSocketTracker struct { + mu sync.RWMutex + connections map[string]*ConnectionInfo +} + +// NewWebSocketTracker creates a new WebSocket connection tracker. +func NewWebSocketTracker() *WebSocketTracker { + return &WebSocketTracker{ + connections: make(map[string]*ConnectionInfo), + } +} + +// Register adds a new WebSocket connection to tracking. +func (t *WebSocketTracker) Register(conn *ConnectionInfo) { + t.mu.Lock() + defer t.mu.Unlock() + + t.connections[conn.ID] = conn + logger.Log().WithField("connection_id", conn.ID). + WithField("type", conn.Type). + WithField("remote_addr", conn.RemoteAddr). + Debug("WebSocket connection registered") +} + +// Unregister removes a WebSocket connection from tracking. +func (t *WebSocketTracker) Unregister(connectionID string) { + t.mu.Lock() + defer t.mu.Unlock() + + if conn, exists := t.connections[connectionID]; exists { + duration := time.Since(conn.ConnectedAt) + logger.Log().WithField("connection_id", connectionID). + WithField("type", conn.Type). + WithField("duration", duration.String()). + Debug("WebSocket connection unregistered") + delete(t.connections, connectionID) + } +} + +// UpdateActivity updates the last activity timestamp for a connection. +func (t *WebSocketTracker) UpdateActivity(connectionID string) { + t.mu.Lock() + defer t.mu.Unlock() + + if conn, exists := t.connections[connectionID]; exists { + conn.LastActivityAt = time.Now() + } +} + +// GetConnection retrieves information about a specific connection. +func (t *WebSocketTracker) GetConnection(connectionID string) (*ConnectionInfo, bool) { + t.mu.RLock() + defer t.mu.RUnlock() + + conn, exists := t.connections[connectionID] + return conn, exists +} + +// GetAllConnections returns a slice of all active connections. +func (t *WebSocketTracker) GetAllConnections() []*ConnectionInfo { + t.mu.RLock() + defer t.mu.RUnlock() + + connections := make([]*ConnectionInfo, 0, len(t.connections)) + for _, conn := range t.connections { + // Create a copy to avoid race conditions + connCopy := *conn + connections = append(connections, &connCopy) + } + return connections +} + +// GetStats returns aggregate statistics about WebSocket connections. +func (t *WebSocketTracker) GetStats() *ConnectionStats { + t.mu.RLock() + defer t.mu.RUnlock() + + stats := &ConnectionStats{ + TotalActive: len(t.connections), + LogsConnections: 0, + CerberusConnections: 0, + LastUpdated: time.Now(), + } + + var oldestTime *time.Time + for _, conn := range t.connections { + switch conn.Type { + case "logs": + stats.LogsConnections++ + case "cerberus": + stats.CerberusConnections++ + } + + if oldestTime == nil || conn.ConnectedAt.Before(*oldestTime) { + t := conn.ConnectedAt + oldestTime = &t + } + } + + stats.OldestConnection = oldestTime + return stats +} + +// GetCount returns the total number of active connections. +func (t *WebSocketTracker) GetCount() int { + t.mu.RLock() + defer t.mu.RUnlock() + return len(t.connections) +} diff --git a/backend/internal/services/websocket_tracker_test.go b/backend/internal/services/websocket_tracker_test.go new file mode 100644 index 00000000..24526c88 --- /dev/null +++ b/backend/internal/services/websocket_tracker_test.go @@ -0,0 +1,224 @@ +package services + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewWebSocketTracker(t *testing.T) { + tracker := NewWebSocketTracker() + assert.NotNil(t, tracker) + assert.NotNil(t, tracker.connections) + assert.Equal(t, 0, tracker.GetCount()) +} + +func TestWebSocketTracker_Register(t *testing.T) { + tracker := NewWebSocketTracker() + + conn := &ConnectionInfo{ + ID: "test-conn-1", + Type: "logs", + ConnectedAt: time.Now(), + LastActivityAt: time.Now(), + RemoteAddr: "192.168.1.1:12345", + UserAgent: "Mozilla/5.0", + Filters: "level=error", + } + + tracker.Register(conn) + assert.Equal(t, 1, tracker.GetCount()) + + // Verify the connection is retrievable + retrieved, exists := tracker.GetConnection("test-conn-1") + assert.True(t, exists) + assert.Equal(t, conn.ID, retrieved.ID) + assert.Equal(t, conn.Type, retrieved.Type) +} + +func TestWebSocketTracker_Unregister(t *testing.T) { + tracker := NewWebSocketTracker() + + conn := &ConnectionInfo{ + ID: "test-conn-1", + Type: "cerberus", + ConnectedAt: time.Now(), + } + + tracker.Register(conn) + assert.Equal(t, 1, tracker.GetCount()) + + tracker.Unregister("test-conn-1") + assert.Equal(t, 0, tracker.GetCount()) + + // Verify the connection is no longer retrievable + _, exists := tracker.GetConnection("test-conn-1") + assert.False(t, exists) +} + +func TestWebSocketTracker_UnregisterNonExistent(t *testing.T) { + tracker := NewWebSocketTracker() + + // Should not panic + tracker.Unregister("non-existent-id") + assert.Equal(t, 0, tracker.GetCount()) +} + +func TestWebSocketTracker_UpdateActivity(t *testing.T) { + tracker := NewWebSocketTracker() + + initialTime := time.Now().Add(-1 * time.Hour) + conn := &ConnectionInfo{ + ID: "test-conn-1", + Type: "logs", + ConnectedAt: initialTime, + LastActivityAt: initialTime, + } + + tracker.Register(conn) + + // Wait a moment to ensure time difference + time.Sleep(10 * time.Millisecond) + + tracker.UpdateActivity("test-conn-1") + + retrieved, exists := tracker.GetConnection("test-conn-1") + require.True(t, exists) + assert.True(t, retrieved.LastActivityAt.After(initialTime)) +} + +func TestWebSocketTracker_UpdateActivityNonExistent(t *testing.T) { + tracker := NewWebSocketTracker() + + // Should not panic + tracker.UpdateActivity("non-existent-id") +} + +func TestWebSocketTracker_GetAllConnections(t *testing.T) { + tracker := NewWebSocketTracker() + + conn1 := &ConnectionInfo{ + ID: "conn-1", + Type: "logs", + ConnectedAt: time.Now(), + } + conn2 := &ConnectionInfo{ + ID: "conn-2", + Type: "cerberus", + ConnectedAt: time.Now(), + } + + tracker.Register(conn1) + tracker.Register(conn2) + + connections := tracker.GetAllConnections() + assert.Equal(t, 2, len(connections)) + + // Verify both connections are present (order may vary) + ids := make(map[string]bool) + for _, conn := range connections { + ids[conn.ID] = true + } + assert.True(t, ids["conn-1"]) + assert.True(t, ids["conn-2"]) +} + +func TestWebSocketTracker_GetStats(t *testing.T) { + tracker := NewWebSocketTracker() + + now := time.Now() + oldestTime := now.Add(-10 * time.Minute) + + conn1 := &ConnectionInfo{ + ID: "conn-1", + Type: "logs", + ConnectedAt: now, + } + conn2 := &ConnectionInfo{ + ID: "conn-2", + Type: "cerberus", + ConnectedAt: oldestTime, + } + conn3 := &ConnectionInfo{ + ID: "conn-3", + Type: "logs", + ConnectedAt: now.Add(-5 * time.Minute), + } + + tracker.Register(conn1) + tracker.Register(conn2) + tracker.Register(conn3) + + stats := tracker.GetStats() + assert.Equal(t, 3, stats.TotalActive) + assert.Equal(t, 2, stats.LogsConnections) + assert.Equal(t, 1, stats.CerberusConnections) + assert.NotNil(t, stats.OldestConnection) + assert.True(t, stats.OldestConnection.Equal(oldestTime)) + assert.False(t, stats.LastUpdated.IsZero()) +} + +func TestWebSocketTracker_GetStatsEmpty(t *testing.T) { + tracker := NewWebSocketTracker() + + stats := tracker.GetStats() + assert.Equal(t, 0, stats.TotalActive) + assert.Equal(t, 0, stats.LogsConnections) + assert.Equal(t, 0, stats.CerberusConnections) + assert.Nil(t, stats.OldestConnection) + assert.False(t, stats.LastUpdated.IsZero()) +} + +func TestWebSocketTracker_ConcurrentAccess(t *testing.T) { + tracker := NewWebSocketTracker() + + // Test concurrent registration + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(id int) { + conn := &ConnectionInfo{ + ID: string(rune('a' + id)), + Type: "logs", + ConnectedAt: time.Now(), + } + tracker.Register(conn) + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + assert.Equal(t, 10, tracker.GetCount()) + + // Test concurrent read + for i := 0; i < 10; i++ { + go func() { + _ = tracker.GetAllConnections() + _ = tracker.GetStats() + done <- true + }() + } + + for i := 0; i < 10; i++ { + <-done + } + + // Test concurrent unregister + for i := 0; i < 10; i++ { + go func(id int) { + tracker.Unregister(string(rune('a' + id))) + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } + + assert.Equal(t, 0, tracker.GetCount()) +}