- The security config Upsert update path copied every rate limit field from the incoming request onto the existing database record except RateLimitMode, so the seeded default value of "disabled" always survived a POST regardless of what the caller sent.
- This silently prevented the Caddy rate_limit handler from being injected on any container with a pre-existing config record (i.e., every real deployment and every CI run after migration).
- Added the missing field assignment so RateLimitMode is correctly persisted on update alongside all other rate limit settings.
- The integration test payload now also sends rate_limit_enable alongside rate_limit_mode so the handler sync logic fires via its explicit first branch, providing belt-and-suspenders correctness independent of which path the caller uses to express intent.
474 lines
14 KiB
Go
474 lines
14 KiB
Go
package services
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
"github.com/Wikid82/charon/backend/internal/models"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Sentinel errors returned by SecurityService methods; compare with errors.Is.
var (
	// ErrSecurityConfigNotFound reports that no security config row matched
	// the lookup.
	ErrSecurityConfigNotFound = errors.New("security config not found")

	// ErrInvalidAdminCIDR reports an admin whitelist entry that is neither a
	// plain IP address nor a CIDR range.
	ErrInvalidAdminCIDR = errors.New("invalid admin whitelist CIDR")

	// ErrBreakGlassInvalid reports that no break-glass hash is stored or that
	// the supplied token does not match it.
	ErrBreakGlassInvalid = errors.New("break-glass token invalid")
)
|
|
|
|
type SecurityService struct {
|
|
db *gorm.DB
|
|
auditChan chan *models.SecurityAudit
|
|
done chan struct{} // Channel to signal goroutine to stop
|
|
wg sync.WaitGroup // WaitGroup to track goroutine completion
|
|
closed bool // Flag to prevent double-close
|
|
mu sync.Mutex // Mutex to protect closed flag
|
|
}
|
|
|
|
// NewSecurityService returns a SecurityService using the provided DB
|
|
func NewSecurityService(db *gorm.DB) *SecurityService {
|
|
s := &SecurityService{
|
|
db: db,
|
|
auditChan: make(chan *models.SecurityAudit, 100), // Buffered channel with capacity 100
|
|
done: make(chan struct{}),
|
|
}
|
|
// Start background goroutine to process audit events asynchronously
|
|
s.wg.Add(1)
|
|
go s.processAuditEvents()
|
|
return s
|
|
}
|
|
|
|
// Close gracefully stops the SecurityService and waits for audit processing to complete
|
|
func (s *SecurityService) Close() {
|
|
s.mu.Lock()
|
|
if s.closed {
|
|
s.mu.Unlock()
|
|
return // Already closed
|
|
}
|
|
s.closed = true
|
|
s.mu.Unlock()
|
|
|
|
close(s.done) // Signal the goroutine to stop
|
|
close(s.auditChan) // Close the audit channel
|
|
s.wg.Wait() // Wait for the goroutine to finish
|
|
}
|
|
|
|
// Flush processes all pending audit logs synchronously (useful for testing)
|
|
func (s *SecurityService) Flush() {
|
|
// Wait for all pending audits to be processed
|
|
// In practice, we wait for the channel to be empty and then a bit more
|
|
// to ensure the database write completes
|
|
for i := 0; i < 20; i++ { // Max 200ms wait
|
|
if len(s.auditChan) == 0 {
|
|
time.Sleep(10 * time.Millisecond) // Extra wait for DB write
|
|
return
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
}
|
|
|
|
// Get returns the first SecurityConfig row (singleton config)
|
|
func (s *SecurityService) Get() (*models.SecurityConfig, error) {
|
|
var cfg models.SecurityConfig
|
|
// Prefer the canonical singleton row named "default".
|
|
// Some environments may contain multiple rows (e.g., tests or prior data);
|
|
// returning an arbitrary "first" row can break break-glass token validation.
|
|
if err := s.db.Where("name = ?", "default").First(&cfg).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Backward compatibility: if there is no explicit "default" row,
|
|
// fall back to the first row if any exists.
|
|
if err2 := s.db.First(&cfg).Error; err2 != nil {
|
|
if errors.Is(err2, gorm.ErrRecordNotFound) {
|
|
return nil, ErrSecurityConfigNotFound
|
|
}
|
|
return nil, err2
|
|
}
|
|
} else {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
// Upsert validates and saves a security config
|
|
func (s *SecurityService) Upsert(cfg *models.SecurityConfig) error {
|
|
// Validate AdminWhitelist - comma-separated list of CIDRs
|
|
if cfg.AdminWhitelist != "" {
|
|
parts := strings.Split(cfg.AdminWhitelist, ",")
|
|
for _, p := range parts {
|
|
p = strings.TrimSpace(p)
|
|
if p == "" {
|
|
continue
|
|
}
|
|
// Validate as IP or CIDR using the same helper as AccessListService
|
|
if !isValidCIDR(p) {
|
|
return ErrInvalidAdminCIDR
|
|
}
|
|
}
|
|
}
|
|
|
|
// If a breakglass token is present in BreakGlassHash as empty string,
|
|
// do not overwrite it here. Token generation should be done explicitly.
|
|
|
|
// Validate CrowdSec mode on input prior to any DB operations: only 'local' or 'disabled' supported
|
|
if cfg.CrowdSecMode != "" && cfg.CrowdSecMode != "local" && cfg.CrowdSecMode != "disabled" {
|
|
return fmt.Errorf("invalid crowdsec mode: %s", cfg.CrowdSecMode)
|
|
}
|
|
|
|
// Upsert behaviour: try to find existing record
|
|
var existing models.SecurityConfig
|
|
if err := s.db.Where("name = ?", cfg.Name).First(&existing).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// New record
|
|
return s.db.Create(cfg).Error
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Preserve existing BreakGlassHash if not provided
|
|
if cfg.BreakGlassHash == "" {
|
|
cfg.BreakGlassHash = existing.BreakGlassHash
|
|
}
|
|
existing.Enabled = cfg.Enabled
|
|
existing.AdminWhitelist = cfg.AdminWhitelist
|
|
// Validate CrowdSec mode: only 'local' or 'disabled' supported. Reject external/remote values.
|
|
if cfg.CrowdSecMode != "" && cfg.CrowdSecMode != "local" && cfg.CrowdSecMode != "disabled" {
|
|
return fmt.Errorf("invalid crowdsec mode: %s", cfg.CrowdSecMode)
|
|
}
|
|
existing.CrowdSecMode = cfg.CrowdSecMode
|
|
existing.CrowdSecAPIURL = cfg.CrowdSecAPIURL
|
|
existing.WAFMode = cfg.WAFMode
|
|
existing.WAFRulesSource = cfg.WAFRulesSource
|
|
existing.WAFLearning = cfg.WAFLearning
|
|
existing.WAFParanoiaLevel = cfg.WAFParanoiaLevel
|
|
existing.WAFExclusions = cfg.WAFExclusions
|
|
existing.RateLimitEnable = cfg.RateLimitEnable
|
|
existing.RateLimitMode = cfg.RateLimitMode
|
|
existing.RateLimitBurst = cfg.RateLimitBurst
|
|
existing.RateLimitRequests = cfg.RateLimitRequests
|
|
existing.RateLimitWindowSec = cfg.RateLimitWindowSec
|
|
existing.RateLimitBypassList = cfg.RateLimitBypassList
|
|
|
|
return s.db.Save(&existing).Error
|
|
}
|
|
|
|
// GenerateBreakGlassToken generates a token, stores its bcrypt hash, and returns the plaintext token
|
|
func (s *SecurityService) GenerateBreakGlassToken(name string) (string, error) {
|
|
tokenBytes := make([]byte, 24)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return "", err
|
|
}
|
|
token := hex.EncodeToString(tokenBytes)
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var cfg models.SecurityConfig
|
|
if err := s.db.Where("name = ?", name).First(&cfg).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
cfg = models.SecurityConfig{Name: name, BreakGlassHash: string(hash)}
|
|
if createErr := s.db.Create(&cfg).Error; createErr != nil {
|
|
return "", createErr
|
|
}
|
|
return token, nil
|
|
}
|
|
return "", err
|
|
}
|
|
|
|
cfg.BreakGlassHash = string(hash)
|
|
if err := s.db.Save(&cfg).Error; err != nil {
|
|
return "", err
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
// VerifyBreakGlassToken validates a provided token against the stored hash
|
|
func (s *SecurityService) VerifyBreakGlassToken(name, token string) (bool, error) {
|
|
var cfg models.SecurityConfig
|
|
if err := s.db.Where("name = ?", name).First(&cfg).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return false, ErrSecurityConfigNotFound
|
|
}
|
|
return false, err
|
|
}
|
|
if cfg.BreakGlassHash == "" {
|
|
return false, ErrBreakGlassInvalid
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(cfg.BreakGlassHash), []byte(token)); err != nil {
|
|
return false, ErrBreakGlassInvalid
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// LogDecision stores a security decision record
|
|
func (s *SecurityService) LogDecision(d *models.SecurityDecision) error {
|
|
if d == nil {
|
|
return nil
|
|
}
|
|
if d.UUID == "" {
|
|
d.UUID = uuid.NewString()
|
|
}
|
|
if d.CreatedAt.IsZero() {
|
|
d.CreatedAt = time.Now()
|
|
}
|
|
return s.db.Create(d).Error
|
|
}
|
|
|
|
// ListDecisions returns recent security decisions, ordered by created_at desc
|
|
func (s *SecurityService) ListDecisions(limit int) ([]models.SecurityDecision, error) {
|
|
var res []models.SecurityDecision
|
|
q := s.db.Order("created_at desc")
|
|
if limit > 0 {
|
|
q = q.Limit(limit)
|
|
}
|
|
if err := q.Find(&res).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// LogAudit stores an audit entry asynchronously via buffered channel
|
|
func (s *SecurityService) LogAudit(a *models.SecurityAudit) error {
|
|
if a == nil {
|
|
return nil
|
|
}
|
|
if a.UUID == "" {
|
|
a.UUID = uuid.NewString()
|
|
}
|
|
if a.CreatedAt.IsZero() {
|
|
a.CreatedAt = time.Now()
|
|
}
|
|
|
|
// Non-blocking send to avoid blocking main operations
|
|
select {
|
|
case s.auditChan <- a:
|
|
return nil
|
|
default:
|
|
if err := s.persistAuditWithRetry(a); err != nil {
|
|
return fmt.Errorf("persist audit synchronously: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (s *SecurityService) persistAuditWithRetry(audit *models.SecurityAudit) error {
|
|
const maxAttempts = 5
|
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
|
err := s.db.Create(audit).Error
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
errMsg := strings.ToLower(err.Error())
|
|
if strings.Contains(errMsg, "no such table") || strings.Contains(errMsg, "database is closed") {
|
|
return nil
|
|
}
|
|
|
|
isTransientLock := strings.Contains(errMsg, "database is locked") || strings.Contains(errMsg, "database table is locked") || strings.Contains(errMsg, "busy")
|
|
if isTransientLock && attempt < maxAttempts {
|
|
time.Sleep(time.Duration(attempt) * 10 * time.Millisecond)
|
|
continue
|
|
}
|
|
|
|
if isTransientLock {
|
|
return nil
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// processAuditEvents processes audit events from the channel in the background
|
|
func (s *SecurityService) processAuditEvents() {
|
|
defer s.wg.Done() // Mark goroutine as done when it exits
|
|
|
|
for {
|
|
select {
|
|
case audit, ok := <-s.auditChan:
|
|
if !ok {
|
|
// Channel closed, exit goroutine
|
|
return
|
|
}
|
|
if err := s.persistAuditWithRetry(audit); err != nil {
|
|
// Silently ignore errors from closed databases (common in tests)
|
|
// Only log for other types of errors
|
|
errMsg := err.Error()
|
|
if !strings.Contains(errMsg, "no such table") &&
|
|
!strings.Contains(errMsg, "database is closed") {
|
|
fmt.Printf("Failed to write audit log: %v\n", err)
|
|
}
|
|
}
|
|
case <-s.done:
|
|
// Service is shutting down - drain remaining audit events before exiting
|
|
for audit := range s.auditChan {
|
|
if err := s.persistAuditWithRetry(audit); err != nil {
|
|
errMsg := err.Error()
|
|
if !strings.Contains(errMsg, "no such table") &&
|
|
!strings.Contains(errMsg, "database is closed") {
|
|
fmt.Printf("Failed to write audit log: %v\n", err)
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// AuditLogFilter holds the optional criteria applied by ListAuditLogs. A
// field left at its zero value (empty string or nil pointer) adds no
// condition to the query.
type AuditLogFilter struct {
	// Actor, when non-empty, must equal the actor column.
	Actor string
	// Action, when non-empty, must equal the action column.
	Action string
	// EventCategory, when non-empty, must equal the event_category column.
	EventCategory string
	// ResourceUUID, when non-empty, must equal the resource_uuid column.
	ResourceUUID string
	// StartDate, when set, is the inclusive lower bound on created_at.
	StartDate *time.Time
	// EndDate, when set, is the inclusive upper bound on created_at.
	EndDate *time.Time
}
|
|
|
|
// ListAuditLogs retrieves audit logs with pagination and filtering
|
|
func (s *SecurityService) ListAuditLogs(filter AuditLogFilter, page, limit int) ([]models.SecurityAudit, int64, error) {
|
|
var audits []models.SecurityAudit
|
|
var total int64
|
|
|
|
// Build query with filters
|
|
query := s.db.Model(&models.SecurityAudit{})
|
|
|
|
if filter.Actor != "" {
|
|
query = query.Where("actor = ?", filter.Actor)
|
|
}
|
|
if filter.Action != "" {
|
|
query = query.Where("action = ?", filter.Action)
|
|
}
|
|
if filter.EventCategory != "" {
|
|
query = query.Where("event_category = ?", filter.EventCategory)
|
|
}
|
|
if filter.ResourceUUID != "" {
|
|
query = query.Where("resource_uuid = ?", filter.ResourceUUID)
|
|
}
|
|
if filter.StartDate != nil {
|
|
query = query.Where("created_at >= ?", *filter.StartDate)
|
|
}
|
|
if filter.EndDate != nil {
|
|
query = query.Where("created_at <= ?", *filter.EndDate)
|
|
}
|
|
|
|
// Get total count
|
|
if err := query.Count(&total).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Apply pagination
|
|
offset := (page - 1) * limit
|
|
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&audits).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
return audits, total, nil
|
|
}
|
|
|
|
// GetAuditLogByUUID retrieves a single audit log by its UUID
|
|
func (s *SecurityService) GetAuditLogByUUID(auditUUID string) (*models.SecurityAudit, error) {
|
|
var audit models.SecurityAudit
|
|
if err := s.db.Where("uuid = ?", auditUUID).First(&audit).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, errors.New("audit log not found")
|
|
}
|
|
return nil, err
|
|
}
|
|
return &audit, nil
|
|
}
|
|
|
|
// ListAuditLogsByProvider retrieves audit logs for a specific DNS provider with pagination
|
|
func (s *SecurityService) ListAuditLogsByProvider(providerID uint, page, limit int) ([]models.SecurityAudit, int64, error) {
|
|
var audits []models.SecurityAudit
|
|
var total int64
|
|
|
|
query := s.db.Model(&models.SecurityAudit{}).
|
|
Where("event_category = ? AND resource_id = ?", "dns_provider", providerID)
|
|
|
|
// Get total count
|
|
if err := query.Count(&total).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Apply pagination
|
|
offset := (page - 1) * limit
|
|
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&audits).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
return audits, total, nil
|
|
}
|
|
|
|
// UpsertRuleSet saves or updates a ruleset content
|
|
func (s *SecurityService) UpsertRuleSet(r *models.SecurityRuleSet) error {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
// Basic validations
|
|
if r.Name == "" {
|
|
return fmt.Errorf("rule set name required")
|
|
}
|
|
// Prevent huge payloads from being stored in DB (e.g., limit 2MB)
|
|
if len(r.Content) > 2*1024*1024 {
|
|
return fmt.Errorf("ruleset content too large")
|
|
}
|
|
var existing models.SecurityRuleSet
|
|
if err := s.db.Where("name = ?", r.Name).First(&existing).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
if r.UUID == "" {
|
|
r.UUID = uuid.NewString()
|
|
}
|
|
if r.LastUpdated.IsZero() {
|
|
r.LastUpdated = time.Now()
|
|
}
|
|
return s.db.Create(r).Error
|
|
}
|
|
return err
|
|
}
|
|
existing.SourceURL = r.SourceURL
|
|
existing.Content = r.Content
|
|
existing.Mode = r.Mode
|
|
existing.LastUpdated = r.LastUpdated
|
|
return s.db.Save(&existing).Error
|
|
}
|
|
|
|
// DeleteRuleSet removes a ruleset by id
|
|
func (s *SecurityService) DeleteRuleSet(id uint) error {
|
|
var rs models.SecurityRuleSet
|
|
if err := s.db.Where("id = ?", id).First(&rs).Error; err != nil {
|
|
return err
|
|
}
|
|
return s.db.Delete(&rs).Error
|
|
}
|
|
|
|
// ListRuleSets returns all known rulesets
|
|
func (s *SecurityService) ListRuleSets() ([]models.SecurityRuleSet, error) {
|
|
var res []models.SecurityRuleSet
|
|
if err := s.db.Find(&res).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// isValidCIDR reports whether cidr parses either as a single IP address or as
// a CIDR range. Surrounding whitespace is not trimmed here.
func isValidCIDR(cidr string) bool {
	// A bare address is accepted as-is.
	if net.ParseIP(cidr) != nil {
		return true
	}
	// Otherwise it has to be a well-formed address/prefix pair.
	if _, _, err := net.ParseCIDR(cidr); err != nil {
		return false
	}
	return true
}
|