feat: add WebSocket connection tracking backend

Co-authored-by: Wikid82 <176516789+Wikid82@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-18 18:04:36 +00:00
parent b44064e15d
commit 854a940536
8 changed files with 658 additions and 13 deletions

View File

@@ -16,11 +16,15 @@ import (
// CerberusLogsHandler handles WebSocket connections for streaming security logs.
type CerberusLogsHandler struct {
watcher *services.LogWatcher
tracker *services.WebSocketTracker
}
// NewCerberusLogsHandler creates a new handler for Cerberus security log streaming.
func NewCerberusLogsHandler(watcher *services.LogWatcher) *CerberusLogsHandler {
return &CerberusLogsHandler{watcher: watcher}
func NewCerberusLogsHandler(watcher *services.LogWatcher, tracker *services.WebSocketTracker) *CerberusLogsHandler {
return &CerberusLogsHandler{
watcher: watcher,
tracker: tracker,
}
}
// LiveLogs handles WebSocket connections for Cerberus security log streaming.
@@ -52,6 +56,22 @@ func (h *CerberusLogsHandler) LiveLogs(c *gin.Context) {
subscriberID := uuid.New().String()
logger.Log().WithField("subscriber_id", subscriberID).Info("Cerberus logs WebSocket connected")
// Register connection with tracker if available
if h.tracker != nil {
filters := c.Request.URL.RawQuery
connInfo := &services.ConnectionInfo{
ID: subscriberID,
Type: "cerberus",
ConnectedAt: time.Now(),
LastActivityAt: time.Now(),
RemoteAddr: c.Request.RemoteAddr,
UserAgent: c.Request.UserAgent(),
Filters: filters,
}
h.tracker.Register(connInfo)
defer h.tracker.Unregister(subscriberID)
}
// Parse query filters
sourceFilter := strings.ToLower(c.Query("source")) // waf, crowdsec, ratelimit, acl, normal
levelFilter := strings.ToLower(c.Query("level")) // info, warn, error
@@ -117,6 +137,11 @@ func (h *CerberusLogsHandler) LiveLogs(c *gin.Context) {
return
}
// Update activity timestamp
if h.tracker != nil {
h.tracker.UpdateActivity(subscriberID)
}
case <-ticker.C:
// Send ping to keep connection alive
if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {

View File

@@ -29,10 +29,12 @@ func TestCerberusLogsHandler_NewHandler(t *testing.T) {
t.Parallel()
watcher := services.NewLogWatcher("/tmp/test.log")
handler := NewCerberusLogsHandler(watcher)
tracker := services.NewWebSocketTracker()
handler := NewCerberusLogsHandler(watcher, tracker)
assert.NotNil(t, handler)
assert.Equal(t, watcher, handler.watcher)
assert.Equal(t, tracker, handler.tracker)
}
// TestCerberusLogsHandler_SuccessfulConnection verifies WebSocket upgrade.
@@ -51,7 +53,7 @@ func TestCerberusLogsHandler_SuccessfulConnection(t *testing.T) {
require.NoError(t, err)
defer watcher.Stop()
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
// Create test server
router := gin.New()
@@ -88,7 +90,7 @@ func TestCerberusLogsHandler_ReceiveLogEntries(t *testing.T) {
require.NoError(t, err)
defer watcher.Stop()
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
// Create test server
router := gin.New()
@@ -157,7 +159,7 @@ func TestCerberusLogsHandler_SourceFilter(t *testing.T) {
require.NoError(t, err)
defer watcher.Stop()
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
router := gin.New()
router.GET("/ws", handler.LiveLogs)
@@ -236,7 +238,7 @@ func TestCerberusLogsHandler_BlockedOnlyFilter(t *testing.T) {
require.NoError(t, err)
defer watcher.Stop()
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
router := gin.New()
router.GET("/ws", handler.LiveLogs)
@@ -313,7 +315,7 @@ func TestCerberusLogsHandler_IPFilter(t *testing.T) {
require.NoError(t, err)
defer watcher.Stop()
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
router := gin.New()
router.GET("/ws", handler.LiveLogs)
@@ -388,7 +390,7 @@ func TestCerberusLogsHandler_ClientDisconnect(t *testing.T) {
require.NoError(t, err)
defer watcher.Stop()
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
router := gin.New()
router.GET("/ws", handler.LiveLogs)
@@ -424,7 +426,7 @@ func TestCerberusLogsHandler_MultipleClients(t *testing.T) {
require.NoError(t, err)
defer watcher.Stop()
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
router := gin.New()
router.GET("/ws", handler.LiveLogs)
@@ -486,7 +488,7 @@ func TestCerberusLogsHandler_UpgradeFailure(t *testing.T) {
t.Parallel()
watcher := services.NewLogWatcher("/tmp/test.log")
handler := NewCerberusLogsHandler(watcher)
handler := NewCerberusLogsHandler(watcher, nil)
router := gin.New()
router.GET("/ws", handler.LiveLogs)

View File

@@ -10,6 +10,7 @@ import (
"github.com/gorilla/websocket"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/Wikid82/charon/backend/internal/services"
)
var upgrader = websocket.Upgrader{
@@ -31,8 +32,26 @@ type LogEntry struct {
Fields map[string]interface{} `json:"fields"`
}
// LogsWSHandler handles WebSocket connections for live log streaming.
type LogsWSHandler struct {
tracker *services.WebSocketTracker
}
// NewLogsWSHandler creates a new handler for log streaming.
func NewLogsWSHandler(tracker *services.WebSocketTracker) *LogsWSHandler {
return &LogsWSHandler{tracker: tracker}
}
// LogsWebSocketHandler handles WebSocket connections for live log streaming.
// DEPRECATED: Use NewLogsWSHandler().HandleWebSocket instead. Kept for backward compatibility.
func LogsWebSocketHandler(c *gin.Context) {
// For backward compatibility, create a nil tracker if called directly
handler := NewLogsWSHandler(nil)
handler.HandleWebSocket(c)
}
// HandleWebSocket handles WebSocket connections for live log streaming.
func (h *LogsWSHandler) HandleWebSocket(c *gin.Context) {
logger.Log().Info("WebSocket connection attempt received")
// Upgrade HTTP connection to WebSocket
@@ -52,6 +71,22 @@ func LogsWebSocketHandler(c *gin.Context) {
logger.Log().WithField("subscriber_id", subscriberID).Info("WebSocket connection established successfully")
// Register connection with tracker if available
if h.tracker != nil {
filters := c.Request.URL.RawQuery
connInfo := &services.ConnectionInfo{
ID: subscriberID,
Type: "logs",
ConnectedAt: time.Now(),
LastActivityAt: time.Now(),
RemoteAddr: c.Request.RemoteAddr,
UserAgent: c.Request.UserAgent(),
Filters: filters,
}
h.tracker.Register(connInfo)
defer h.tracker.Unregister(subscriberID)
}
// Parse query parameters for filtering
levelFilter := strings.ToLower(c.Query("level"))
sourceFilter := strings.ToLower(c.Query("source"))
@@ -115,6 +150,11 @@ func LogsWebSocketHandler(c *gin.Context) {
return
}
// Update activity timestamp
if h.tracker != nil {
h.tracker.UpdateActivity(subscriberID)
}
case <-ticker.C:
// Send ping to keep connection alive
if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {

View File

@@ -0,0 +1,34 @@
package handlers
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/Wikid82/charon/backend/internal/services"
)
// WebSocketStatusHandler provides endpoints for WebSocket connection monitoring.
type WebSocketStatusHandler struct {
tracker *services.WebSocketTracker
}
// NewWebSocketStatusHandler creates a new handler for WebSocket status monitoring.
func NewWebSocketStatusHandler(tracker *services.WebSocketTracker) *WebSocketStatusHandler {
return &WebSocketStatusHandler{tracker: tracker}
}
// GetConnections returns a list of all active WebSocket connections.
func (h *WebSocketStatusHandler) GetConnections(c *gin.Context) {
connections := h.tracker.GetAllConnections()
c.JSON(http.StatusOK, gin.H{
"connections": connections,
"count": len(connections),
})
}
// GetStats returns aggregate statistics about WebSocket connections.
func (h *WebSocketStatusHandler) GetStats(c *gin.Context) {
stats := h.tracker.GetStats()
c.JSON(http.StatusOK, stats)
}

View File

@@ -0,0 +1,169 @@
package handlers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/Wikid82/charon/backend/internal/services"
)
func TestWebSocketStatusHandler_GetConnections(t *testing.T) {
gin.SetMode(gin.TestMode)
tracker := services.NewWebSocketTracker()
handler := NewWebSocketStatusHandler(tracker)
// Register test connections
conn1 := &services.ConnectionInfo{
ID: "conn-1",
Type: "logs",
ConnectedAt: time.Now(),
LastActivityAt: time.Now(),
RemoteAddr: "192.168.1.1:12345",
UserAgent: "Mozilla/5.0",
Filters: "level=error",
}
conn2 := &services.ConnectionInfo{
ID: "conn-2",
Type: "cerberus",
ConnectedAt: time.Now(),
LastActivityAt: time.Now(),
RemoteAddr: "192.168.1.2:54321",
UserAgent: "Chrome/90.0",
Filters: "source=waf",
}
tracker.Register(conn1)
tracker.Register(conn2)
// Create test request
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/connections", nil)
// Call handler
handler.GetConnections(c)
// Verify response
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, float64(2), response["count"])
connections, ok := response["connections"].([]interface{})
require.True(t, ok)
assert.Len(t, connections, 2)
}
func TestWebSocketStatusHandler_GetConnectionsEmpty(t *testing.T) {
gin.SetMode(gin.TestMode)
tracker := services.NewWebSocketTracker()
handler := NewWebSocketStatusHandler(tracker)
// Create test request
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/connections", nil)
// Call handler
handler.GetConnections(c)
// Verify response
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, float64(0), response["count"])
connections, ok := response["connections"].([]interface{})
require.True(t, ok)
assert.Len(t, connections, 0)
}
func TestWebSocketStatusHandler_GetStats(t *testing.T) {
gin.SetMode(gin.TestMode)
tracker := services.NewWebSocketTracker()
handler := NewWebSocketStatusHandler(tracker)
// Register test connections
conn1 := &services.ConnectionInfo{
ID: "conn-1",
Type: "logs",
ConnectedAt: time.Now(),
}
conn2 := &services.ConnectionInfo{
ID: "conn-2",
Type: "logs",
ConnectedAt: time.Now(),
}
conn3 := &services.ConnectionInfo{
ID: "conn-3",
Type: "cerberus",
ConnectedAt: time.Now(),
}
tracker.Register(conn1)
tracker.Register(conn2)
tracker.Register(conn3)
// Create test request
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/stats", nil)
// Call handler
handler.GetStats(c)
// Verify response
assert.Equal(t, http.StatusOK, w.Code)
var stats services.ConnectionStats
err := json.Unmarshal(w.Body.Bytes(), &stats)
require.NoError(t, err)
assert.Equal(t, 3, stats.TotalActive)
assert.Equal(t, 2, stats.LogsConnections)
assert.Equal(t, 1, stats.CerberusConnections)
assert.NotNil(t, stats.OldestConnection)
assert.False(t, stats.LastUpdated.IsZero())
}
func TestWebSocketStatusHandler_GetStatsEmpty(t *testing.T) {
gin.SetMode(gin.TestMode)
tracker := services.NewWebSocketTracker()
handler := NewWebSocketStatusHandler(tracker)
// Create test request
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/stats", nil)
// Call handler
handler.GetStats(c)
// Verify response
assert.Equal(t, http.StatusOK, w.Code)
var stats services.ConnectionStats
err := json.Unmarshal(w.Body.Bytes(), &stats)
require.NoError(t, err)
assert.Equal(t, 0, stats.TotalActive)
assert.Equal(t, 0, stats.LogsConnections)
assert.Equal(t, 0, stats.CerberusConnections)
assert.Nil(t, stats.OldestConnection)
assert.False(t, stats.LastUpdated.IsZero())
}

View File

@@ -119,6 +119,10 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error {
logService := services.NewLogService(&cfg)
logsHandler := handlers.NewLogsHandler(logService)
// WebSocket tracker for connection monitoring
wsTracker := services.NewWebSocketTracker()
wsStatusHandler := handlers.NewWebSocketStatusHandler(wsTracker)
// Notification Service (needed for multiple handlers)
notificationService := services.NewNotificationService(db)
@@ -160,7 +164,14 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error {
protected.GET("/logs", logsHandler.List)
protected.GET("/logs/:filename", logsHandler.Read)
protected.GET("/logs/:filename/download", logsHandler.Download)
protected.GET("/logs/live", handlers.LogsWebSocketHandler)
// WebSocket endpoints
logsWSHandler := handlers.NewLogsWSHandler(wsTracker)
protected.GET("/logs/live", logsWSHandler.HandleWebSocket)
// WebSocket status monitoring
protected.GET("/websocket/connections", wsStatusHandler.GetConnections)
protected.GET("/websocket/stats", wsStatusHandler.GetStats)
// Security Notification Settings
securityNotificationService := services.NewSecurityNotificationService(db)
@@ -395,7 +406,7 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error {
if err := logWatcher.Start(context.Background()); err != nil {
logger.Log().WithError(err).Error("Failed to start security log watcher")
}
cerberusLogsHandler := handlers.NewCerberusLogsHandler(logWatcher)
cerberusLogsHandler := handlers.NewCerberusLogsHandler(logWatcher, wsTracker)
protected.GET("/cerberus/logs/ws", cerberusLogsHandler.LiveLogs)
// Access Lists

View File

@@ -0,0 +1,140 @@
// Package services provides business logic services for the application.
package services
import (
"sync"
"time"
"github.com/Wikid82/charon/backend/internal/logger"
)
// ConnectionInfo tracks information about a single WebSocket connection.
type ConnectionInfo struct {
ID string `json:"id"`
Type string `json:"type"` // "logs" or "cerberus"
ConnectedAt time.Time `json:"connected_at"`
LastActivityAt time.Time `json:"last_activity_at"`
RemoteAddr string `json:"remote_addr,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
Filters string `json:"filters,omitempty"` // Query parameters used for filtering
}
// ConnectionStats provides aggregate statistics about WebSocket connections.
type ConnectionStats struct {
TotalActive int `json:"total_active"`
LogsConnections int `json:"logs_connections"`
CerberusConnections int `json:"cerberus_connections"`
OldestConnection *time.Time `json:"oldest_connection,omitempty"`
LastUpdated time.Time `json:"last_updated"`
}
// WebSocketTracker tracks active WebSocket connections and provides statistics.
type WebSocketTracker struct {
mu sync.RWMutex
connections map[string]*ConnectionInfo
}
// NewWebSocketTracker creates a new WebSocket connection tracker.
func NewWebSocketTracker() *WebSocketTracker {
return &WebSocketTracker{
connections: make(map[string]*ConnectionInfo),
}
}
// Register adds a new WebSocket connection to tracking.
func (t *WebSocketTracker) Register(conn *ConnectionInfo) {
t.mu.Lock()
defer t.mu.Unlock()
t.connections[conn.ID] = conn
logger.Log().WithField("connection_id", conn.ID).
WithField("type", conn.Type).
WithField("remote_addr", conn.RemoteAddr).
Debug("WebSocket connection registered")
}
// Unregister removes a WebSocket connection from tracking.
func (t *WebSocketTracker) Unregister(connectionID string) {
t.mu.Lock()
defer t.mu.Unlock()
if conn, exists := t.connections[connectionID]; exists {
duration := time.Since(conn.ConnectedAt)
logger.Log().WithField("connection_id", connectionID).
WithField("type", conn.Type).
WithField("duration", duration.String()).
Debug("WebSocket connection unregistered")
delete(t.connections, connectionID)
}
}
// UpdateActivity updates the last activity timestamp for a connection.
func (t *WebSocketTracker) UpdateActivity(connectionID string) {
t.mu.Lock()
defer t.mu.Unlock()
if conn, exists := t.connections[connectionID]; exists {
conn.LastActivityAt = time.Now()
}
}
// GetConnection retrieves information about a specific connection.
func (t *WebSocketTracker) GetConnection(connectionID string) (*ConnectionInfo, bool) {
t.mu.RLock()
defer t.mu.RUnlock()
conn, exists := t.connections[connectionID]
return conn, exists
}
// GetAllConnections returns a slice of all active connections.
func (t *WebSocketTracker) GetAllConnections() []*ConnectionInfo {
t.mu.RLock()
defer t.mu.RUnlock()
connections := make([]*ConnectionInfo, 0, len(t.connections))
for _, conn := range t.connections {
// Create a copy to avoid race conditions
connCopy := *conn
connections = append(connections, &connCopy)
}
return connections
}
// GetStats returns aggregate statistics about WebSocket connections.
func (t *WebSocketTracker) GetStats() *ConnectionStats {
t.mu.RLock()
defer t.mu.RUnlock()
stats := &ConnectionStats{
TotalActive: len(t.connections),
LogsConnections: 0,
CerberusConnections: 0,
LastUpdated: time.Now(),
}
var oldestTime *time.Time
for _, conn := range t.connections {
switch conn.Type {
case "logs":
stats.LogsConnections++
case "cerberus":
stats.CerberusConnections++
}
if oldestTime == nil || conn.ConnectedAt.Before(*oldestTime) {
t := conn.ConnectedAt
oldestTime = &t
}
}
stats.OldestConnection = oldestTime
return stats
}
// GetCount returns the total number of active connections.
func (t *WebSocketTracker) GetCount() int {
t.mu.RLock()
defer t.mu.RUnlock()
return len(t.connections)
}

View File

@@ -0,0 +1,224 @@
package services
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewWebSocketTracker(t *testing.T) {
tracker := NewWebSocketTracker()
assert.NotNil(t, tracker)
assert.NotNil(t, tracker.connections)
assert.Equal(t, 0, tracker.GetCount())
}
func TestWebSocketTracker_Register(t *testing.T) {
tracker := NewWebSocketTracker()
conn := &ConnectionInfo{
ID: "test-conn-1",
Type: "logs",
ConnectedAt: time.Now(),
LastActivityAt: time.Now(),
RemoteAddr: "192.168.1.1:12345",
UserAgent: "Mozilla/5.0",
Filters: "level=error",
}
tracker.Register(conn)
assert.Equal(t, 1, tracker.GetCount())
// Verify the connection is retrievable
retrieved, exists := tracker.GetConnection("test-conn-1")
assert.True(t, exists)
assert.Equal(t, conn.ID, retrieved.ID)
assert.Equal(t, conn.Type, retrieved.Type)
}
func TestWebSocketTracker_Unregister(t *testing.T) {
tracker := NewWebSocketTracker()
conn := &ConnectionInfo{
ID: "test-conn-1",
Type: "cerberus",
ConnectedAt: time.Now(),
}
tracker.Register(conn)
assert.Equal(t, 1, tracker.GetCount())
tracker.Unregister("test-conn-1")
assert.Equal(t, 0, tracker.GetCount())
// Verify the connection is no longer retrievable
_, exists := tracker.GetConnection("test-conn-1")
assert.False(t, exists)
}
func TestWebSocketTracker_UnregisterNonExistent(t *testing.T) {
tracker := NewWebSocketTracker()
// Should not panic
tracker.Unregister("non-existent-id")
assert.Equal(t, 0, tracker.GetCount())
}
func TestWebSocketTracker_UpdateActivity(t *testing.T) {
tracker := NewWebSocketTracker()
initialTime := time.Now().Add(-1 * time.Hour)
conn := &ConnectionInfo{
ID: "test-conn-1",
Type: "logs",
ConnectedAt: initialTime,
LastActivityAt: initialTime,
}
tracker.Register(conn)
// Wait a moment to ensure time difference
time.Sleep(10 * time.Millisecond)
tracker.UpdateActivity("test-conn-1")
retrieved, exists := tracker.GetConnection("test-conn-1")
require.True(t, exists)
assert.True(t, retrieved.LastActivityAt.After(initialTime))
}
func TestWebSocketTracker_UpdateActivityNonExistent(t *testing.T) {
tracker := NewWebSocketTracker()
// Should not panic
tracker.UpdateActivity("non-existent-id")
}
func TestWebSocketTracker_GetAllConnections(t *testing.T) {
tracker := NewWebSocketTracker()
conn1 := &ConnectionInfo{
ID: "conn-1",
Type: "logs",
ConnectedAt: time.Now(),
}
conn2 := &ConnectionInfo{
ID: "conn-2",
Type: "cerberus",
ConnectedAt: time.Now(),
}
tracker.Register(conn1)
tracker.Register(conn2)
connections := tracker.GetAllConnections()
assert.Equal(t, 2, len(connections))
// Verify both connections are present (order may vary)
ids := make(map[string]bool)
for _, conn := range connections {
ids[conn.ID] = true
}
assert.True(t, ids["conn-1"])
assert.True(t, ids["conn-2"])
}
func TestWebSocketTracker_GetStats(t *testing.T) {
tracker := NewWebSocketTracker()
now := time.Now()
oldestTime := now.Add(-10 * time.Minute)
conn1 := &ConnectionInfo{
ID: "conn-1",
Type: "logs",
ConnectedAt: now,
}
conn2 := &ConnectionInfo{
ID: "conn-2",
Type: "cerberus",
ConnectedAt: oldestTime,
}
conn3 := &ConnectionInfo{
ID: "conn-3",
Type: "logs",
ConnectedAt: now.Add(-5 * time.Minute),
}
tracker.Register(conn1)
tracker.Register(conn2)
tracker.Register(conn3)
stats := tracker.GetStats()
assert.Equal(t, 3, stats.TotalActive)
assert.Equal(t, 2, stats.LogsConnections)
assert.Equal(t, 1, stats.CerberusConnections)
assert.NotNil(t, stats.OldestConnection)
assert.True(t, stats.OldestConnection.Equal(oldestTime))
assert.False(t, stats.LastUpdated.IsZero())
}
func TestWebSocketTracker_GetStatsEmpty(t *testing.T) {
tracker := NewWebSocketTracker()
stats := tracker.GetStats()
assert.Equal(t, 0, stats.TotalActive)
assert.Equal(t, 0, stats.LogsConnections)
assert.Equal(t, 0, stats.CerberusConnections)
assert.Nil(t, stats.OldestConnection)
assert.False(t, stats.LastUpdated.IsZero())
}
func TestWebSocketTracker_ConcurrentAccess(t *testing.T) {
tracker := NewWebSocketTracker()
// Test concurrent registration
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(id int) {
conn := &ConnectionInfo{
ID: string(rune('a' + id)),
Type: "logs",
ConnectedAt: time.Now(),
}
tracker.Register(conn)
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
assert.Equal(t, 10, tracker.GetCount())
// Test concurrent read
for i := 0; i < 10; i++ {
go func() {
_ = tracker.GetAllConnections()
_ = tracker.GetStats()
done <- true
}()
}
for i := 0; i < 10; i++ {
<-done
}
// Test concurrent unregister
for i := 0; i < 10; i++ {
go func(id int) {
tracker.Unregister(string(rune('a' + id)))
done <- true
}(i)
}
for i := 0; i < 10; i++ {
<-done
}
assert.Equal(t, 0, tracker.GetCount())
}