feat: add WebSocket connection tracking backend
Co-authored-by: Wikid82 <176516789+Wikid82@users.noreply.github.com>
This commit is contained in:
@@ -16,11 +16,15 @@ import (
|
||||
// CerberusLogsHandler handles WebSocket connections for streaming security logs.
|
||||
type CerberusLogsHandler struct {
|
||||
watcher *services.LogWatcher
|
||||
tracker *services.WebSocketTracker
|
||||
}
|
||||
|
||||
// NewCerberusLogsHandler creates a new handler for Cerberus security log streaming.
|
||||
func NewCerberusLogsHandler(watcher *services.LogWatcher) *CerberusLogsHandler {
|
||||
return &CerberusLogsHandler{watcher: watcher}
|
||||
func NewCerberusLogsHandler(watcher *services.LogWatcher, tracker *services.WebSocketTracker) *CerberusLogsHandler {
|
||||
return &CerberusLogsHandler{
|
||||
watcher: watcher,
|
||||
tracker: tracker,
|
||||
}
|
||||
}
|
||||
|
||||
// LiveLogs handles WebSocket connections for Cerberus security log streaming.
|
||||
@@ -52,6 +56,22 @@ func (h *CerberusLogsHandler) LiveLogs(c *gin.Context) {
|
||||
subscriberID := uuid.New().String()
|
||||
logger.Log().WithField("subscriber_id", subscriberID).Info("Cerberus logs WebSocket connected")
|
||||
|
||||
// Register connection with tracker if available
|
||||
if h.tracker != nil {
|
||||
filters := c.Request.URL.RawQuery
|
||||
connInfo := &services.ConnectionInfo{
|
||||
ID: subscriberID,
|
||||
Type: "cerberus",
|
||||
ConnectedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
RemoteAddr: c.Request.RemoteAddr,
|
||||
UserAgent: c.Request.UserAgent(),
|
||||
Filters: filters,
|
||||
}
|
||||
h.tracker.Register(connInfo)
|
||||
defer h.tracker.Unregister(subscriberID)
|
||||
}
|
||||
|
||||
// Parse query filters
|
||||
sourceFilter := strings.ToLower(c.Query("source")) // waf, crowdsec, ratelimit, acl, normal
|
||||
levelFilter := strings.ToLower(c.Query("level")) // info, warn, error
|
||||
@@ -117,6 +137,11 @@ func (h *CerberusLogsHandler) LiveLogs(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update activity timestamp
|
||||
if h.tracker != nil {
|
||||
h.tracker.UpdateActivity(subscriberID)
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
// Send ping to keep connection alive
|
||||
if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||
|
||||
@@ -29,10 +29,12 @@ func TestCerberusLogsHandler_NewHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
watcher := services.NewLogWatcher("/tmp/test.log")
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
tracker := services.NewWebSocketTracker()
|
||||
handler := NewCerberusLogsHandler(watcher, tracker)
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
assert.Equal(t, watcher, handler.watcher)
|
||||
assert.Equal(t, tracker, handler.tracker)
|
||||
}
|
||||
|
||||
// TestCerberusLogsHandler_SuccessfulConnection verifies WebSocket upgrade.
|
||||
@@ -51,7 +53,7 @@ func TestCerberusLogsHandler_SuccessfulConnection(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer watcher.Stop()
|
||||
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
// Create test server
|
||||
router := gin.New()
|
||||
@@ -88,7 +90,7 @@ func TestCerberusLogsHandler_ReceiveLogEntries(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer watcher.Stop()
|
||||
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
// Create test server
|
||||
router := gin.New()
|
||||
@@ -157,7 +159,7 @@ func TestCerberusLogsHandler_SourceFilter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer watcher.Stop()
|
||||
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/ws", handler.LiveLogs)
|
||||
@@ -236,7 +238,7 @@ func TestCerberusLogsHandler_BlockedOnlyFilter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer watcher.Stop()
|
||||
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/ws", handler.LiveLogs)
|
||||
@@ -313,7 +315,7 @@ func TestCerberusLogsHandler_IPFilter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer watcher.Stop()
|
||||
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/ws", handler.LiveLogs)
|
||||
@@ -388,7 +390,7 @@ func TestCerberusLogsHandler_ClientDisconnect(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer watcher.Stop()
|
||||
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/ws", handler.LiveLogs)
|
||||
@@ -424,7 +426,7 @@ func TestCerberusLogsHandler_MultipleClients(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer watcher.Stop()
|
||||
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/ws", handler.LiveLogs)
|
||||
@@ -486,7 +488,7 @@ func TestCerberusLogsHandler_UpgradeFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
watcher := services.NewLogWatcher("/tmp/test.log")
|
||||
handler := NewCerberusLogsHandler(watcher)
|
||||
handler := NewCerberusLogsHandler(watcher, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/ws", handler.LiveLogs)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/services"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
@@ -31,8 +32,26 @@ type LogEntry struct {
|
||||
Fields map[string]interface{} `json:"fields"`
|
||||
}
|
||||
|
||||
// LogsWSHandler handles WebSocket connections for live log streaming.
|
||||
type LogsWSHandler struct {
|
||||
tracker *services.WebSocketTracker
|
||||
}
|
||||
|
||||
// NewLogsWSHandler creates a new handler for log streaming.
|
||||
func NewLogsWSHandler(tracker *services.WebSocketTracker) *LogsWSHandler {
|
||||
return &LogsWSHandler{tracker: tracker}
|
||||
}
|
||||
|
||||
// LogsWebSocketHandler handles WebSocket connections for live log streaming.
|
||||
// DEPRECATED: Use NewLogsWSHandler().HandleWebSocket instead. Kept for backward compatibility.
|
||||
func LogsWebSocketHandler(c *gin.Context) {
|
||||
// For backward compatibility, create a nil tracker if called directly
|
||||
handler := NewLogsWSHandler(nil)
|
||||
handler.HandleWebSocket(c)
|
||||
}
|
||||
|
||||
// HandleWebSocket handles WebSocket connections for live log streaming.
|
||||
func (h *LogsWSHandler) HandleWebSocket(c *gin.Context) {
|
||||
logger.Log().Info("WebSocket connection attempt received")
|
||||
|
||||
// Upgrade HTTP connection to WebSocket
|
||||
@@ -52,6 +71,22 @@ func LogsWebSocketHandler(c *gin.Context) {
|
||||
|
||||
logger.Log().WithField("subscriber_id", subscriberID).Info("WebSocket connection established successfully")
|
||||
|
||||
// Register connection with tracker if available
|
||||
if h.tracker != nil {
|
||||
filters := c.Request.URL.RawQuery
|
||||
connInfo := &services.ConnectionInfo{
|
||||
ID: subscriberID,
|
||||
Type: "logs",
|
||||
ConnectedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
RemoteAddr: c.Request.RemoteAddr,
|
||||
UserAgent: c.Request.UserAgent(),
|
||||
Filters: filters,
|
||||
}
|
||||
h.tracker.Register(connInfo)
|
||||
defer h.tracker.Unregister(subscriberID)
|
||||
}
|
||||
|
||||
// Parse query parameters for filtering
|
||||
levelFilter := strings.ToLower(c.Query("level"))
|
||||
sourceFilter := strings.ToLower(c.Query("source"))
|
||||
@@ -115,6 +150,11 @@ func LogsWebSocketHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update activity timestamp
|
||||
if h.tracker != nil {
|
||||
h.tracker.UpdateActivity(subscriberID)
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
// Send ping to keep connection alive
|
||||
if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
|
||||
|
||||
34
backend/internal/api/handlers/websocket_status_handler.go
Normal file
34
backend/internal/api/handlers/websocket_status_handler.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/services"
|
||||
)
|
||||
|
||||
// WebSocketStatusHandler provides endpoints for WebSocket connection monitoring.
|
||||
type WebSocketStatusHandler struct {
|
||||
tracker *services.WebSocketTracker
|
||||
}
|
||||
|
||||
// NewWebSocketStatusHandler creates a new handler for WebSocket status monitoring.
|
||||
func NewWebSocketStatusHandler(tracker *services.WebSocketTracker) *WebSocketStatusHandler {
|
||||
return &WebSocketStatusHandler{tracker: tracker}
|
||||
}
|
||||
|
||||
// GetConnections returns a list of all active WebSocket connections.
|
||||
func (h *WebSocketStatusHandler) GetConnections(c *gin.Context) {
|
||||
connections := h.tracker.GetAllConnections()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"connections": connections,
|
||||
"count": len(connections),
|
||||
})
|
||||
}
|
||||
|
||||
// GetStats returns aggregate statistics about WebSocket connections.
|
||||
func (h *WebSocketStatusHandler) GetStats(c *gin.Context) {
|
||||
stats := h.tracker.GetStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
169
backend/internal/api/handlers/websocket_status_handler_test.go
Normal file
169
backend/internal/api/handlers/websocket_status_handler_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/services"
|
||||
)
|
||||
|
||||
func TestWebSocketStatusHandler_GetConnections(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tracker := services.NewWebSocketTracker()
|
||||
handler := NewWebSocketStatusHandler(tracker)
|
||||
|
||||
// Register test connections
|
||||
conn1 := &services.ConnectionInfo{
|
||||
ID: "conn-1",
|
||||
Type: "logs",
|
||||
ConnectedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
RemoteAddr: "192.168.1.1:12345",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
Filters: "level=error",
|
||||
}
|
||||
conn2 := &services.ConnectionInfo{
|
||||
ID: "conn-2",
|
||||
Type: "cerberus",
|
||||
ConnectedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
RemoteAddr: "192.168.1.2:54321",
|
||||
UserAgent: "Chrome/90.0",
|
||||
Filters: "source=waf",
|
||||
}
|
||||
|
||||
tracker.Register(conn1)
|
||||
tracker.Register(conn2)
|
||||
|
||||
// Create test request
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/connections", nil)
|
||||
|
||||
// Call handler
|
||||
handler.GetConnections(c)
|
||||
|
||||
// Verify response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, float64(2), response["count"])
|
||||
connections, ok := response["connections"].([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, connections, 2)
|
||||
}
|
||||
|
||||
func TestWebSocketStatusHandler_GetConnectionsEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tracker := services.NewWebSocketTracker()
|
||||
handler := NewWebSocketStatusHandler(tracker)
|
||||
|
||||
// Create test request
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/connections", nil)
|
||||
|
||||
// Call handler
|
||||
handler.GetConnections(c)
|
||||
|
||||
// Verify response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, float64(0), response["count"])
|
||||
connections, ok := response["connections"].([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, connections, 0)
|
||||
}
|
||||
|
||||
func TestWebSocketStatusHandler_GetStats(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tracker := services.NewWebSocketTracker()
|
||||
handler := NewWebSocketStatusHandler(tracker)
|
||||
|
||||
// Register test connections
|
||||
conn1 := &services.ConnectionInfo{
|
||||
ID: "conn-1",
|
||||
Type: "logs",
|
||||
ConnectedAt: time.Now(),
|
||||
}
|
||||
conn2 := &services.ConnectionInfo{
|
||||
ID: "conn-2",
|
||||
Type: "logs",
|
||||
ConnectedAt: time.Now(),
|
||||
}
|
||||
conn3 := &services.ConnectionInfo{
|
||||
ID: "conn-3",
|
||||
Type: "cerberus",
|
||||
ConnectedAt: time.Now(),
|
||||
}
|
||||
|
||||
tracker.Register(conn1)
|
||||
tracker.Register(conn2)
|
||||
tracker.Register(conn3)
|
||||
|
||||
// Create test request
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/stats", nil)
|
||||
|
||||
// Call handler
|
||||
handler.GetStats(c)
|
||||
|
||||
// Verify response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var stats services.ConnectionStats
|
||||
err := json.Unmarshal(w.Body.Bytes(), &stats)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, stats.TotalActive)
|
||||
assert.Equal(t, 2, stats.LogsConnections)
|
||||
assert.Equal(t, 1, stats.CerberusConnections)
|
||||
assert.NotNil(t, stats.OldestConnection)
|
||||
assert.False(t, stats.LastUpdated.IsZero())
|
||||
}
|
||||
|
||||
func TestWebSocketStatusHandler_GetStatsEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tracker := services.NewWebSocketTracker()
|
||||
handler := NewWebSocketStatusHandler(tracker)
|
||||
|
||||
// Create test request
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/websocket/stats", nil)
|
||||
|
||||
// Call handler
|
||||
handler.GetStats(c)
|
||||
|
||||
// Verify response
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var stats services.ConnectionStats
|
||||
err := json.Unmarshal(w.Body.Bytes(), &stats)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, stats.TotalActive)
|
||||
assert.Equal(t, 0, stats.LogsConnections)
|
||||
assert.Equal(t, 0, stats.CerberusConnections)
|
||||
assert.Nil(t, stats.OldestConnection)
|
||||
assert.False(t, stats.LastUpdated.IsZero())
|
||||
}
|
||||
@@ -119,6 +119,10 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error {
|
||||
logService := services.NewLogService(&cfg)
|
||||
logsHandler := handlers.NewLogsHandler(logService)
|
||||
|
||||
// WebSocket tracker for connection monitoring
|
||||
wsTracker := services.NewWebSocketTracker()
|
||||
wsStatusHandler := handlers.NewWebSocketStatusHandler(wsTracker)
|
||||
|
||||
// Notification Service (needed for multiple handlers)
|
||||
notificationService := services.NewNotificationService(db)
|
||||
|
||||
@@ -160,7 +164,14 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error {
|
||||
protected.GET("/logs", logsHandler.List)
|
||||
protected.GET("/logs/:filename", logsHandler.Read)
|
||||
protected.GET("/logs/:filename/download", logsHandler.Download)
|
||||
protected.GET("/logs/live", handlers.LogsWebSocketHandler)
|
||||
|
||||
// WebSocket endpoints
|
||||
logsWSHandler := handlers.NewLogsWSHandler(wsTracker)
|
||||
protected.GET("/logs/live", logsWSHandler.HandleWebSocket)
|
||||
|
||||
// WebSocket status monitoring
|
||||
protected.GET("/websocket/connections", wsStatusHandler.GetConnections)
|
||||
protected.GET("/websocket/stats", wsStatusHandler.GetStats)
|
||||
|
||||
// Security Notification Settings
|
||||
securityNotificationService := services.NewSecurityNotificationService(db)
|
||||
@@ -395,7 +406,7 @@ func Register(router *gin.Engine, db *gorm.DB, cfg config.Config) error {
|
||||
if err := logWatcher.Start(context.Background()); err != nil {
|
||||
logger.Log().WithError(err).Error("Failed to start security log watcher")
|
||||
}
|
||||
cerberusLogsHandler := handlers.NewCerberusLogsHandler(logWatcher)
|
||||
cerberusLogsHandler := handlers.NewCerberusLogsHandler(logWatcher, wsTracker)
|
||||
protected.GET("/cerberus/logs/ws", cerberusLogsHandler.LiveLogs)
|
||||
|
||||
// Access Lists
|
||||
|
||||
140
backend/internal/services/websocket_tracker.go
Normal file
140
backend/internal/services/websocket_tracker.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Package services provides business logic services for the application.
|
||||
package services
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
)
|
||||
|
||||
// ConnectionInfo tracks information about a single WebSocket connection.
|
||||
type ConnectionInfo struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "logs" or "cerberus"
|
||||
ConnectedAt time.Time `json:"connected_at"`
|
||||
LastActivityAt time.Time `json:"last_activity_at"`
|
||||
RemoteAddr string `json:"remote_addr,omitempty"`
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
Filters string `json:"filters,omitempty"` // Query parameters used for filtering
|
||||
}
|
||||
|
||||
// ConnectionStats provides aggregate statistics about WebSocket connections.
|
||||
type ConnectionStats struct {
|
||||
TotalActive int `json:"total_active"`
|
||||
LogsConnections int `json:"logs_connections"`
|
||||
CerberusConnections int `json:"cerberus_connections"`
|
||||
OldestConnection *time.Time `json:"oldest_connection,omitempty"`
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
}
|
||||
|
||||
// WebSocketTracker tracks active WebSocket connections and provides statistics.
|
||||
type WebSocketTracker struct {
|
||||
mu sync.RWMutex
|
||||
connections map[string]*ConnectionInfo
|
||||
}
|
||||
|
||||
// NewWebSocketTracker creates a new WebSocket connection tracker.
|
||||
func NewWebSocketTracker() *WebSocketTracker {
|
||||
return &WebSocketTracker{
|
||||
connections: make(map[string]*ConnectionInfo),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a new WebSocket connection to tracking.
|
||||
func (t *WebSocketTracker) Register(conn *ConnectionInfo) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.connections[conn.ID] = conn
|
||||
logger.Log().WithField("connection_id", conn.ID).
|
||||
WithField("type", conn.Type).
|
||||
WithField("remote_addr", conn.RemoteAddr).
|
||||
Debug("WebSocket connection registered")
|
||||
}
|
||||
|
||||
// Unregister removes a WebSocket connection from tracking.
|
||||
func (t *WebSocketTracker) Unregister(connectionID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if conn, exists := t.connections[connectionID]; exists {
|
||||
duration := time.Since(conn.ConnectedAt)
|
||||
logger.Log().WithField("connection_id", connectionID).
|
||||
WithField("type", conn.Type).
|
||||
WithField("duration", duration.String()).
|
||||
Debug("WebSocket connection unregistered")
|
||||
delete(t.connections, connectionID)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp for a connection.
|
||||
func (t *WebSocketTracker) UpdateActivity(connectionID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if conn, exists := t.connections[connectionID]; exists {
|
||||
conn.LastActivityAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnection retrieves information about a specific connection.
|
||||
func (t *WebSocketTracker) GetConnection(connectionID string) (*ConnectionInfo, bool) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
conn, exists := t.connections[connectionID]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// GetAllConnections returns a slice of all active connections.
|
||||
func (t *WebSocketTracker) GetAllConnections() []*ConnectionInfo {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
connections := make([]*ConnectionInfo, 0, len(t.connections))
|
||||
for _, conn := range t.connections {
|
||||
// Create a copy to avoid race conditions
|
||||
connCopy := *conn
|
||||
connections = append(connections, &connCopy)
|
||||
}
|
||||
return connections
|
||||
}
|
||||
|
||||
// GetStats returns aggregate statistics about WebSocket connections.
|
||||
func (t *WebSocketTracker) GetStats() *ConnectionStats {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
stats := &ConnectionStats{
|
||||
TotalActive: len(t.connections),
|
||||
LogsConnections: 0,
|
||||
CerberusConnections: 0,
|
||||
LastUpdated: time.Now(),
|
||||
}
|
||||
|
||||
var oldestTime *time.Time
|
||||
for _, conn := range t.connections {
|
||||
switch conn.Type {
|
||||
case "logs":
|
||||
stats.LogsConnections++
|
||||
case "cerberus":
|
||||
stats.CerberusConnections++
|
||||
}
|
||||
|
||||
if oldestTime == nil || conn.ConnectedAt.Before(*oldestTime) {
|
||||
t := conn.ConnectedAt
|
||||
oldestTime = &t
|
||||
}
|
||||
}
|
||||
|
||||
stats.OldestConnection = oldestTime
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetCount returns the total number of active connections.
|
||||
func (t *WebSocketTracker) GetCount() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return len(t.connections)
|
||||
}
|
||||
224
backend/internal/services/websocket_tracker_test.go
Normal file
224
backend/internal/services/websocket_tracker_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewWebSocketTracker(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
assert.NotNil(t, tracker)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
assert.Equal(t, 0, tracker.GetCount())
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_Register(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
conn := &ConnectionInfo{
|
||||
ID: "test-conn-1",
|
||||
Type: "logs",
|
||||
ConnectedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
RemoteAddr: "192.168.1.1:12345",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
Filters: "level=error",
|
||||
}
|
||||
|
||||
tracker.Register(conn)
|
||||
assert.Equal(t, 1, tracker.GetCount())
|
||||
|
||||
// Verify the connection is retrievable
|
||||
retrieved, exists := tracker.GetConnection("test-conn-1")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, conn.ID, retrieved.ID)
|
||||
assert.Equal(t, conn.Type, retrieved.Type)
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_Unregister(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
conn := &ConnectionInfo{
|
||||
ID: "test-conn-1",
|
||||
Type: "cerberus",
|
||||
ConnectedAt: time.Now(),
|
||||
}
|
||||
|
||||
tracker.Register(conn)
|
||||
assert.Equal(t, 1, tracker.GetCount())
|
||||
|
||||
tracker.Unregister("test-conn-1")
|
||||
assert.Equal(t, 0, tracker.GetCount())
|
||||
|
||||
// Verify the connection is no longer retrievable
|
||||
_, exists := tracker.GetConnection("test-conn-1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_UnregisterNonExistent(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
// Should not panic
|
||||
tracker.Unregister("non-existent-id")
|
||||
assert.Equal(t, 0, tracker.GetCount())
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_UpdateActivity(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
initialTime := time.Now().Add(-1 * time.Hour)
|
||||
conn := &ConnectionInfo{
|
||||
ID: "test-conn-1",
|
||||
Type: "logs",
|
||||
ConnectedAt: initialTime,
|
||||
LastActivityAt: initialTime,
|
||||
}
|
||||
|
||||
tracker.Register(conn)
|
||||
|
||||
// Wait a moment to ensure time difference
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
tracker.UpdateActivity("test-conn-1")
|
||||
|
||||
retrieved, exists := tracker.GetConnection("test-conn-1")
|
||||
require.True(t, exists)
|
||||
assert.True(t, retrieved.LastActivityAt.After(initialTime))
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_UpdateActivityNonExistent(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
// Should not panic
|
||||
tracker.UpdateActivity("non-existent-id")
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_GetAllConnections(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
conn1 := &ConnectionInfo{
|
||||
ID: "conn-1",
|
||||
Type: "logs",
|
||||
ConnectedAt: time.Now(),
|
||||
}
|
||||
conn2 := &ConnectionInfo{
|
||||
ID: "conn-2",
|
||||
Type: "cerberus",
|
||||
ConnectedAt: time.Now(),
|
||||
}
|
||||
|
||||
tracker.Register(conn1)
|
||||
tracker.Register(conn2)
|
||||
|
||||
connections := tracker.GetAllConnections()
|
||||
assert.Equal(t, 2, len(connections))
|
||||
|
||||
// Verify both connections are present (order may vary)
|
||||
ids := make(map[string]bool)
|
||||
for _, conn := range connections {
|
||||
ids[conn.ID] = true
|
||||
}
|
||||
assert.True(t, ids["conn-1"])
|
||||
assert.True(t, ids["conn-2"])
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_GetStats(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
now := time.Now()
|
||||
oldestTime := now.Add(-10 * time.Minute)
|
||||
|
||||
conn1 := &ConnectionInfo{
|
||||
ID: "conn-1",
|
||||
Type: "logs",
|
||||
ConnectedAt: now,
|
||||
}
|
||||
conn2 := &ConnectionInfo{
|
||||
ID: "conn-2",
|
||||
Type: "cerberus",
|
||||
ConnectedAt: oldestTime,
|
||||
}
|
||||
conn3 := &ConnectionInfo{
|
||||
ID: "conn-3",
|
||||
Type: "logs",
|
||||
ConnectedAt: now.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
tracker.Register(conn1)
|
||||
tracker.Register(conn2)
|
||||
tracker.Register(conn3)
|
||||
|
||||
stats := tracker.GetStats()
|
||||
assert.Equal(t, 3, stats.TotalActive)
|
||||
assert.Equal(t, 2, stats.LogsConnections)
|
||||
assert.Equal(t, 1, stats.CerberusConnections)
|
||||
assert.NotNil(t, stats.OldestConnection)
|
||||
assert.True(t, stats.OldestConnection.Equal(oldestTime))
|
||||
assert.False(t, stats.LastUpdated.IsZero())
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_GetStatsEmpty(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
stats := tracker.GetStats()
|
||||
assert.Equal(t, 0, stats.TotalActive)
|
||||
assert.Equal(t, 0, stats.LogsConnections)
|
||||
assert.Equal(t, 0, stats.CerberusConnections)
|
||||
assert.Nil(t, stats.OldestConnection)
|
||||
assert.False(t, stats.LastUpdated.IsZero())
|
||||
}
|
||||
|
||||
func TestWebSocketTracker_ConcurrentAccess(t *testing.T) {
|
||||
tracker := NewWebSocketTracker()
|
||||
|
||||
// Test concurrent registration
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
conn := &ConnectionInfo{
|
||||
ID: string(rune('a' + id)),
|
||||
Type: "logs",
|
||||
ConnectedAt: time.Now(),
|
||||
}
|
||||
tracker.Register(conn)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, tracker.GetCount())
|
||||
|
||||
// Test concurrent read
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
_ = tracker.GetAllConnections()
|
||||
_ = tracker.GetStats()
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Test concurrent unregister
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
tracker.Unregister(string(rune('a' + id)))
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, tracker.GetCount())
|
||||
}
|
||||
Reference in New Issue
Block a user