Files
Charon/backend/internal/api/handlers/feature_flags_handler_test.go
GitHub Actions e6c4e46dd8 chore: Refactor test setup for Gin framework
- Removed redundant `gin.SetMode(gin.TestMode)` calls from individual test files.
- Introduced a centralized `TestMain` function in `testmain_test.go` to set the Gin mode for all tests.
- Ensured consistent test environment setup across various handler test files.
2026-03-25 22:00:07 +00:00

285 lines
8.5 KiB
Go

package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/Wikid82/charon/backend/internal/models"
)
// setupFlagsDB returns a test database obtained from OpenTestDB with the
// Setting model auto-migrated. The test fails immediately if migration fails.
// It is marked as a helper so failures are reported at the caller's line.
func setupFlagsDB(t *testing.T) *gorm.DB {
	t.Helper()
	db := OpenTestDB(t)
	if err := db.AutoMigrate(&models.Setting{}); err != nil {
		t.Fatalf("auto migrate failed: %v", err)
	}
	return db
}
func TestFeatureFlags_GetAndUpdate(t *testing.T) {
db := setupFlagsDB(t)
h := NewFeatureFlagsHandler(db)
r := gin.New()
r.GET("/api/v1/feature-flags", h.GetFlags)
r.PUT("/api/v1/feature-flags", h.UpdateFlags)
// 1) GET should return all default flags (as keys)
req := httptest.NewRequest(http.MethodGet, "/api/v1/feature-flags", http.NoBody)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 got %d body=%s", w.Code, w.Body.String())
}
var flags map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &flags); err != nil {
t.Fatalf("invalid json: %v", err)
}
// ensure keys present
for _, k := range defaultFlags {
if _, ok := flags[k]; !ok {
t.Fatalf("missing default flag key: %s", k)
}
}
// 2) PUT update a single flag
payload := map[string]bool{
defaultFlags[0]: true,
}
b, _ := json.Marshal(payload)
req2 := httptest.NewRequest(http.MethodPut, "/api/v1/feature-flags", bytes.NewReader(b))
req2.Header.Set("Content-Type", "application/json")
w2 := httptest.NewRecorder()
r.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("expected 200 on update got %d body=%s", w2.Code, w2.Body.String())
}
// confirm DB persisted
var s models.Setting
if err := db.Where("key = ?", defaultFlags[0]).First(&s).Error; err != nil {
t.Fatalf("expected setting persisted, db error: %v", err)
}
if s.Value != "true" {
t.Fatalf("expected stored value 'true' got '%s'", s.Value)
}
}
// TestFeatureFlags_EnvFallback verifies the environment fallback. When the
// settings table contains no feature-flag rows, GetFlags should report
// feature.cerberus.enabled as true because FEATURE_CERBERUS_ENABLED=true is
// set in the environment.
func TestFeatureFlags_EnvFallback(t *testing.T) {
	// Set the variable before anything else runs; t.Setenv restores the
	// previous value when the test finishes.
	t.Setenv("FEATURE_CERBERUS_ENABLED", "true")

	// The Setting table is migrated but deliberately left empty, so the DB
	// lookup finds no row and the environment value has to be used.
	emptyDB := setupFlagsDB(t)
	router := gin.New()
	router.GET("/api/v1/feature-flags", NewFeatureFlagsHandler(emptyDB).GetFlags)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/feature-flags", http.NoBody))
	if status := rec.Code; status != http.StatusOK {
		t.Fatalf("expected 200 got %d body=%s", status, rec.Body.String())
	}

	var got map[string]bool
	if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
		t.Fatalf("invalid json: %v", err)
	}
	if enabled := got["feature.cerberus.enabled"]; !enabled {
		t.Fatalf("expected feature.cerberus.enabled to be true via env fallback")
	}
}
// setupBenchmarkFlagsDB creates an in-memory SQLite database with the Setting
// model migrated, for the feature-flag benchmarks. GORM logging is silenced
// so it does not distort timings.
//
// With a plain ":memory:" DSN, every connection in the database/sql pool
// opens its own, separate, empty database. The pool is therefore pinned to a
// single connection so every query sees the migrated schema and seeded rows.
// The database is closed when the benchmark finishes.
func setupBenchmarkFlagsDB(b *testing.B) *gorm.DB {
	b.Helper()
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		b.Fatal(err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		b.Fatal(err)
	}
	sqlDB.SetMaxOpenConns(1)
	// Best-effort teardown; a close error cannot affect benchmark results.
	b.Cleanup(func() { _ = sqlDB.Close() })
	if err := db.AutoMigrate(&models.Setting{}); err != nil {
		b.Fatal(err)
	}
	return db
}
// BenchmarkGetFlags measures the end-to-end cost of
// GET /api/v1/feature-flags through a Gin router, with the three feature
// settings pre-seeded in the database.
//
// Gin runs in release mode for the duration of the benchmark. The previous
// global mode (set centrally by TestMain) is restored on cleanup so it does
// not leak into later tests or benchmarks.
func BenchmarkGetFlags(b *testing.B) {
	db := setupBenchmarkFlagsDB(b)

	// Seed the database with all default flags, one row at a time.
	seed := []models.Setting{
		{Key: "feature.cerberus.enabled", Value: "true", Type: "bool", Category: "feature"},
		{Key: "feature.uptime.enabled", Value: "false", Type: "bool", Category: "feature"},
		{Key: "feature.crowdsec.console_enrollment", Value: "true", Type: "bool", Category: "feature"},
	}
	for i := range seed {
		if err := db.Create(&seed[i]).Error; err != nil {
			b.Fatalf("seed %s: %v", seed[i].Key, err)
		}
	}

	h := NewFeatureFlagsHandler(db)
	prevMode := gin.Mode()
	gin.SetMode(gin.ReleaseMode)
	b.Cleanup(func() { gin.SetMode(prevMode) })
	r := gin.New()
	r.GET("/api/v1/feature-flags", h.GetFlags)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest(http.MethodGet, "/api/v1/feature-flags", http.NoBody)
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			b.Fatalf("expected 200 got %d", w.Code)
		}
	}
}
// BenchmarkUpdateFlags measures the end-to-end cost of
// PUT /api/v1/feature-flags through a Gin router, updating three flags per
// request.
//
// Gin runs in release mode for the duration of the benchmark. The previous
// global mode (set centrally by TestMain) is restored on cleanup so it does
// not leak into later tests or benchmarks.
func BenchmarkUpdateFlags(b *testing.B) {
	db := setupBenchmarkFlagsDB(b)
	h := NewFeatureFlagsHandler(db)

	prevMode := gin.Mode()
	gin.SetMode(gin.ReleaseMode)
	b.Cleanup(func() { gin.SetMode(prevMode) })
	r := gin.New()
	r.PUT("/api/v1/feature-flags", h.UpdateFlags)

	payload := map[string]bool{
		"feature.cerberus.enabled":            true,
		"feature.uptime.enabled":              false,
		"feature.crowdsec.console_enrollment": true,
	}
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		b.Fatalf("marshal payload: %v", err)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// A fresh reader is needed per request because the body is consumed.
		req := httptest.NewRequest(http.MethodPut, "/api/v1/feature-flags", bytes.NewReader(payloadBytes))
		req.Header.Set("Content-Type", "application/json")
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			b.Fatalf("expected 200 got %d", w.Code)
		}
	}
}
// TestGetFlags_BatchQuery seeds the three feature settings and verifies that
// GetFlags returns each one, with the boolean matching its stored value.
//
// NOTE(review): despite the name, this test does not count SQL statements.
// It would not catch a regression from a batch query back to per-flag
// queries; it only checks the returned values.
func TestGetFlags_BatchQuery(t *testing.T) {
	db := setupFlagsDB(t)

	cases := []struct {
		key    string
		stored string
		want   bool
	}{
		{"feature.cerberus.enabled", "true", true},
		{"feature.uptime.enabled", "false", false},
		{"feature.crowdsec.console_enrollment", "true", true},
	}
	for _, c := range cases {
		s := models.Setting{Key: c.key, Value: c.stored, Type: "bool", Category: "feature"}
		if err := db.Create(&s).Error; err != nil {
			t.Fatalf("seed %s: %v", c.key, err)
		}
	}

	h := NewFeatureFlagsHandler(db)
	r := gin.New()
	r.GET("/api/v1/feature-flags", h.GetFlags)

	req := httptest.NewRequest(http.MethodGet, "/api/v1/feature-flags", http.NoBody)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 got %d body=%s", w.Code, w.Body.String())
	}
	var flags map[string]bool
	if err := json.Unmarshal(w.Body.Bytes(), &flags); err != nil {
		t.Fatalf("invalid json: %v", err)
	}

	// Check presence explicitly: a missing key reads as false from the map
	// and would otherwise satisfy the expectation for a false flag.
	for _, c := range cases {
		got, ok := flags[c.key]
		if !ok {
			t.Errorf("missing flag key %s in response", c.key)
			continue
		}
		if got != c.want {
			t.Errorf("expected %s to be %v, got %v", c.key, c.want, got)
		}
	}
}
// TestUpdateFlags_TransactionRollback verifies that UpdateFlags responds with
// 500 Internal Server Error when its database work fails. The failure is
// forced by closing the underlying connection pool before the request.
//
// NOTE(review): the closed database cannot be inspected afterwards, so this
// test checks error propagation only. It does not prove that partial writes
// were rolled back.
func TestUpdateFlags_TransactionRollback(t *testing.T) {
	db := setupFlagsDB(t)

	// Close the DB to force an error during the transaction.
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("failed to get sql.DB: %v", err)
	}
	if err := sqlDB.Close(); err != nil {
		t.Fatalf("failed to close sql.DB: %v", err)
	}

	h := NewFeatureFlagsHandler(db)
	r := gin.New()
	r.PUT("/api/v1/feature-flags", h.UpdateFlags)

	payload := map[string]bool{
		"feature.cerberus.enabled": true,
	}
	body, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal payload: %v", err)
	}
	req := httptest.NewRequest(http.MethodPut, "/api/v1/feature-flags", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// The handler should return an error because the DB is closed.
	if w.Code != http.StatusInternalServerError {
		t.Errorf("expected 500 got %d body=%s", w.Code, w.Body.String())
	}
}
// TestUpdateFlags_TransactionAtomic sends one PUT that updates three flags
// and verifies that every flag is persisted with the expected string value.
//
// NOTE(review): only the "all succeed" half of the all-or-nothing contract
// is exercised here. No failure is injected mid-transaction.
func TestUpdateFlags_TransactionAtomic(t *testing.T) {
	db := setupFlagsDB(t)
	h := NewFeatureFlagsHandler(db)
	r := gin.New()
	r.PUT("/api/v1/feature-flags", h.UpdateFlags)

	// A slice keeps the verification order, and therefore the failure
	// output, deterministic.
	cases := []struct {
		key  string
		val  bool
		want string
	}{
		{"feature.cerberus.enabled", true, "true"},
		{"feature.uptime.enabled", false, "false"},
		{"feature.crowdsec.console_enrollment", true, "true"},
	}

	// Update multiple flags in a single request.
	payload := make(map[string]bool, len(cases))
	for _, c := range cases {
		payload[c.key] = c.val
	}
	b, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal payload: %v", err)
	}
	req := httptest.NewRequest(http.MethodPut, "/api/v1/feature-flags", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 got %d body=%s", w.Code, w.Body.String())
	}

	// Verify all flags were persisted.
	for _, c := range cases {
		var s models.Setting
		if err := db.Where("key = ?", c.key).First(&s).Error; err != nil {
			t.Errorf("expected %s to be persisted: %v", c.key, err)
			continue
		}
		if s.Value != c.want {
			t.Errorf("expected %s to be %s, got %s", c.key, c.want, s.Value)
		}
	}
}