Files
Charon/backend/internal/services/security_service.go

473 lines
14 KiB
Go

package services
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/Wikid82/charon/backend/internal/models"
"gorm.io/gorm"
)
// Sentinel errors returned by SecurityService methods. Compare with errors.Is.
var (
// ErrSecurityConfigNotFound is returned by Get and VerifyBreakGlassToken when
// no matching security config row exists.
ErrSecurityConfigNotFound = errors.New("security config not found")
// ErrInvalidAdminCIDR is returned by Upsert when an AdminWhitelist entry is
// neither a valid IP address nor a valid CIDR.
ErrInvalidAdminCIDR = errors.New("invalid admin whitelist CIDR")
// ErrBreakGlassInvalid is returned by VerifyBreakGlassToken when no hash is
// stored or the supplied token does not match the stored hash.
ErrBreakGlassInvalid = errors.New("break-glass token invalid")
)
// SecurityService provides database-backed access to the security config,
// break-glass tokens, security decisions, rule sets and audit logs.
//
// Audit entries are queued on auditChan and written by a background goroutine
// started in NewSecurityService; Close stops that goroutine and waits for it.
// The struct contains a mutex and a WaitGroup and must not be copied.
type SecurityService struct {
db *gorm.DB // Database handle used for every query and write
auditChan chan *models.SecurityAudit // Queue of audit entries awaiting persistence by the background goroutine
done chan struct{} // Closed by Close to signal the background goroutine to stop
wg sync.WaitGroup // Tracks the background goroutine so Close can wait for it to finish
closed bool // Set once Close has run; prevents closing the channels twice
mu sync.Mutex // Protects the closed flag
}
// NewSecurityService builds a SecurityService backed by db and launches the
// background worker that persists queued audit entries. Call Close to stop the
// worker when the service is no longer needed.
func NewSecurityService(db *gorm.DB) *SecurityService {
	// Number of audit entries that can be queued before LogAudit falls back to
	// writing synchronously.
	const auditQueueSize = 100

	svc := &SecurityService{db: db}
	svc.auditChan = make(chan *models.SecurityAudit, auditQueueSize)
	svc.done = make(chan struct{})

	// Register the worker with the WaitGroup before starting it so that Close
	// can never observe a zero counter while the goroutine is still launching.
	svc.wg.Add(1)
	go svc.processAuditEvents()

	return svc
}
// Close shuts the service down: it signals the audit worker to stop, closes
// the audit queue and blocks until the worker has exited. Calling Close more
// than once is safe; only the first call has any effect.
func (s *SecurityService) Close() {
	// Flip the closed flag under the mutex and learn whether this call is the
	// one responsible for the actual shutdown.
	alreadyClosed := func() bool {
		s.mu.Lock()
		defer s.mu.Unlock()
		if s.closed {
			return true
		}
		s.closed = true
		return false
	}()
	if alreadyClosed {
		return
	}

	close(s.done)      // tell the worker to stop
	close(s.auditChan) // end the worker's receive/drain loop
	s.wg.Wait()        // block until the worker has returned
}
// Flush waits, for a bounded time, until the audit queue looks empty (useful
// for testing). It polls the queue length every 10ms for at most 20 rounds
// (about 200ms) and, once the queue is observed empty, returns after one more
// 10ms pause intended to let an in-flight database write finish. It is a
// best-effort wait, not a guarantee that every entry has been persisted.
func (s *SecurityService) Flush() {
	const (
		pollInterval = 10 * time.Millisecond
		maxPolls     = 20
	)

	for polls := 0; polls < maxPolls; polls++ {
		queueEmpty := len(s.auditChan) == 0
		// The pause doubles as the poll delay and as the grace period for the
		// worker's pending write once the queue has drained.
		time.Sleep(pollInterval)
		if queueEmpty {
			return
		}
	}
}
// Get returns the singleton security config.
//
// The row named "default" is preferred, because environments holding several
// rows (tests, legacy data) would otherwise yield an arbitrary one and break
// break-glass token validation. When no "default" row exists, the first row of
// the table is used for backward compatibility. ErrSecurityConfigNotFound is
// returned when the table has no rows at all; other database errors are
// returned unchanged.
func (s *SecurityService) Get() (*models.SecurityConfig, error) {
	var cfg models.SecurityConfig

	err := s.db.Where("name = ?", "default").First(&cfg).Error
	if err == nil {
		return &cfg, nil
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}

	// No canonical row: fall back to whatever row comes first, if any.
	switch fallbackErr := s.db.First(&cfg).Error; {
	case fallbackErr == nil:
		return &cfg, nil
	case errors.Is(fallbackErr, gorm.ErrRecordNotFound):
		return nil, ErrSecurityConfigNotFound
	default:
		return nil, fallbackErr
	}
}
// Upsert validates cfg and then creates or updates the SecurityConfig row
// whose name equals cfg.Name.
//
// Validation runs before any database access:
//   - each non-blank, comma-separated AdminWhitelist entry must parse as an IP
//     address or a CIDR, otherwise ErrInvalidAdminCIDR is returned;
//   - CrowdSecMode must be empty, "local" or "disabled".
//
// When no row named cfg.Name exists, cfg is inserted as given. Otherwise the
// editable settings are copied onto the stored row, which is then saved. The
// stored BreakGlassHash is never modified here (token generation is an
// explicit, separate operation); if cfg.BreakGlassHash is empty it is
// back-filled on cfg from the stored row.
//
// A nil cfg is treated as a no-op, like the other write methods of the service.
func (s *SecurityService) Upsert(cfg *models.SecurityConfig) error {
	if cfg == nil {
		return nil
	}

	// Validate AdminWhitelist: comma-separated IPs/CIDRs, blank entries allowed.
	for _, entry := range strings.Split(cfg.AdminWhitelist, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		if !isValidCIDR(entry) {
			return ErrInvalidAdminCIDR
		}
	}

	// Only local or disabled CrowdSec is supported; reject external/remote
	// values once, up front, so no partial state is ever built from bad input.
	switch cfg.CrowdSecMode {
	case "", "local", "disabled":
	default:
		return fmt.Errorf("invalid crowdsec mode: %s", cfg.CrowdSecMode)
	}

	var existing models.SecurityConfig
	if err := s.db.Where("name = ?", cfg.Name).First(&existing).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			// No row with this name yet: insert the caller's config as-is.
			return s.db.Create(cfg).Error
		}
		return err
	}

	// Reflect the stored hash back to the caller when none was supplied. The
	// stored hash itself is deliberately left untouched below.
	if cfg.BreakGlassHash == "" {
		cfg.BreakGlassHash = existing.BreakGlassHash
	}

	existing.Enabled = cfg.Enabled
	existing.AdminWhitelist = cfg.AdminWhitelist
	existing.CrowdSecMode = cfg.CrowdSecMode
	existing.CrowdSecAPIURL = cfg.CrowdSecAPIURL
	existing.WAFMode = cfg.WAFMode
	existing.WAFRulesSource = cfg.WAFRulesSource
	existing.WAFLearning = cfg.WAFLearning
	existing.WAFParanoiaLevel = cfg.WAFParanoiaLevel
	existing.WAFExclusions = cfg.WAFExclusions
	existing.RateLimitEnable = cfg.RateLimitEnable
	existing.RateLimitBurst = cfg.RateLimitBurst
	existing.RateLimitRequests = cfg.RateLimitRequests
	existing.RateLimitWindowSec = cfg.RateLimitWindowSec
	existing.RateLimitBypassList = cfg.RateLimitBypassList

	return s.db.Save(&existing).Error
}
// GenerateBreakGlassToken creates a fresh break-glass token for the security
// config called name. The token is 24 random bytes rendered as hex; only its
// bcrypt hash is stored, on the existing row or on a newly created row when
// none exists. The plaintext token is returned to the caller and is not
// persisted anywhere.
func (s *SecurityService) GenerateBreakGlassToken(name string) (string, error) {
	raw := make([]byte, 24)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	token := hex.EncodeToString(raw)

	hashed, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}

	var cfg models.SecurityConfig
	lookupErr := s.db.Where("name = ?", name).First(&cfg).Error
	switch {
	case lookupErr == nil:
		// Row exists: replace its hash.
		cfg.BreakGlassHash = string(hashed)
		if saveErr := s.db.Save(&cfg).Error; saveErr != nil {
			return "", saveErr
		}
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		// No row yet: create one carrying just the name and the hash.
		cfg = models.SecurityConfig{Name: name, BreakGlassHash: string(hashed)}
		if createErr := s.db.Create(&cfg).Error; createErr != nil {
			return "", createErr
		}
	default:
		return "", lookupErr
	}

	return token, nil
}
// VerifyBreakGlassToken checks token against the bcrypt hash stored on the
// security config called name. It returns (true, nil) on a match,
// ErrSecurityConfigNotFound when the config row is missing, and
// ErrBreakGlassInvalid when no hash is stored or the token does not match.
// Other database errors are returned unchanged.
func (s *SecurityService) VerifyBreakGlassToken(name, token string) (bool, error) {
	var cfg models.SecurityConfig

	lookupErr := s.db.Where("name = ?", name).First(&cfg).Error
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return false, ErrSecurityConfigNotFound
	case lookupErr != nil:
		return false, lookupErr
	}

	storedHash := cfg.BreakGlassHash
	// An empty hash means no token was ever issued; treat it like a mismatch.
	if storedHash == "" || bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(token)) != nil {
		return false, ErrBreakGlassInvalid
	}
	return true, nil
}
// LogDecision persists a security decision record. A missing UUID is replaced
// by a newly generated one and a zero CreatedAt by the current time; both are
// written back onto d. A nil d is a no-op.
func (s *SecurityService) LogDecision(d *models.SecurityDecision) error {
	if d == nil {
		return nil
	}

	if d.CreatedAt.IsZero() {
		d.CreatedAt = time.Now()
	}
	if len(d.UUID) == 0 {
		d.UUID = uuid.NewString()
	}

	result := s.db.Create(d)
	return result.Error
}
// ListDecisions returns security decisions ordered newest first (created_at
// descending). A positive limit caps the number of rows returned; a zero or
// negative limit applies no cap.
func (s *SecurityService) ListDecisions(limit int) ([]models.SecurityDecision, error) {
	query := s.db.Order("created_at desc")
	if limit > 0 {
		query = query.Limit(limit)
	}

	var decisions []models.SecurityDecision
	if result := query.Find(&decisions); result.Error != nil {
		return nil, result.Error
	}
	return decisions, nil
}
// LogAudit records an audit entry. A missing UUID is replaced by a newly
// generated one and a zero CreatedAt by the current time; both are written
// back onto a. A nil a is a no-op.
//
// The entry is normally handed to the background worker through the buffered
// audit queue without blocking. When the queue is full, or the service has
// already been closed, the entry is persisted synchronously instead, and any
// error from that write is returned wrapped.
//
// LogAudit is safe to call concurrently with, and after, Close: the queue is
// only written to while the closed flag is observed false under the mutex, so
// it can never send on a channel that Close has already closed.
func (s *SecurityService) LogAudit(a *models.SecurityAudit) error {
	if a == nil {
		return nil
	}
	if a.UUID == "" {
		a.UUID = uuid.NewString()
	}
	if a.CreatedAt.IsZero() {
		a.CreatedAt = time.Now()
	}

	// Close sets s.closed under s.mu before it closes auditChan, so holding
	// the mutex across this non-blocking send guarantees the channel is still
	// open. The send never blocks, so the lock is held only momentarily.
	queued := false
	s.mu.Lock()
	if !s.closed {
		select {
		case s.auditChan <- a:
			queued = true
		default:
			// Queue full: fall through to the synchronous write below.
		}
	}
	s.mu.Unlock()

	if queued {
		return nil
	}

	// Synchronous fallback runs outside the lock so slow database writes do
	// not serialize other callers or delay Close.
	if err := s.persistAuditWithRetry(a); err != nil {
		return fmt.Errorf("persist audit synchronously: %w", err)
	}
	return nil
}
// persistAuditWithRetry writes audit to the database, making up to five
// attempts when the failure looks like a transient lock ("database is locked",
// "database table is locked" or "busy" in the lower-cased error text). The
// pause between attempts grows linearly: 10ms, 20ms, 30ms, 40ms.
//
// The write is best-effort. nil is returned not only on success but also when
// the error text mentions "no such table" or "database is closed", and when
// the lock is still held after the final attempt. Any other error is returned
// immediately without retrying.
func (s *SecurityService) persistAuditWithRetry(audit *models.SecurityAudit) error {
	const maxAttempts = 5

	attempt := 1
	for {
		createErr := s.db.Create(audit).Error
		if createErr == nil {
			return nil
		}

		lowered := strings.ToLower(createErr.Error())
		switch {
		case strings.Contains(lowered, "no such table"),
			strings.Contains(lowered, "database is closed"):
			// Schema missing or database gone: nothing useful can be done.
			return nil
		case strings.Contains(lowered, "database is locked"),
			strings.Contains(lowered, "database table is locked"),
			strings.Contains(lowered, "busy"):
			if attempt == maxAttempts {
				// Still locked after the last attempt: give up quietly.
				return nil
			}
			time.Sleep(time.Duration(attempt) * 10 * time.Millisecond)
			attempt++
		default:
			return createErr
		}
	}
}
// processAuditEvents is the body of the background worker. It persists audit
// entries received on the audit queue until the queue is closed. When the done
// channel fires first, it drains whatever is still queued (until the queue is
// closed) and then exits. The WaitGroup is released on return.
func (s *SecurityService) processAuditEvents() {
	defer s.wg.Done()

	// write persists one entry and prints a message on failure, except for
	// errors caused by a missing table or a closed database (common in tests).
	write := func(entry *models.SecurityAudit) {
		err := s.persistAuditWithRetry(entry)
		if err == nil {
			return
		}
		msg := err.Error()
		if strings.Contains(msg, "no such table") || strings.Contains(msg, "database is closed") {
			return
		}
		fmt.Printf("Failed to write audit log: %v\n", err)
	}

	for {
		select {
		case entry, open := <-s.auditChan:
			if !open {
				// Queue closed and empty: nothing left to do.
				return
			}
			write(entry)
		case <-s.done:
			// Shutdown requested: flush the remaining entries before leaving.
			for entry := range s.auditChan {
				write(entry)
			}
			return
		}
	}
}
// AuditLogFilter represents filtering criteria for audit log queries.
// ListAuditLogs combines all set fields with AND; a field left at its zero
// value (empty string or nil pointer) adds no condition.
type AuditLogFilter struct {
Actor string // Exact match on the actor column
Action string // Exact match on the action column
EventCategory string // Exact match on the event_category column
ResourceUUID string // Exact match on the resource_uuid column
StartDate *time.Time // Inclusive lower bound on created_at (created_at >= StartDate)
EndDate *time.Time // Inclusive upper bound on created_at (created_at <= EndDate)
}
// ListAuditLogs returns one page of audit logs matching filter, newest first,
// together with the total number of matching rows across all pages.
//
// Every non-zero field of filter adds an AND condition: Actor, Action,
// EventCategory and ResourceUUID are exact matches, StartDate and EndDate are
// inclusive bounds on created_at.
//
// page is 1-based and limit is the page size. A page below 1 is treated as the
// first page, and no offset is applied unless limit is positive, so an
// out-of-range page can never yield a negative or nonsensical offset. limit
// itself is passed to the database layer unchanged.
func (s *SecurityService) ListAuditLogs(filter AuditLogFilter, page, limit int) ([]models.SecurityAudit, int64, error) {
	var audits []models.SecurityAudit
	var total int64

	// Build query with filters
	query := s.db.Model(&models.SecurityAudit{})
	if filter.Actor != "" {
		query = query.Where("actor = ?", filter.Actor)
	}
	if filter.Action != "" {
		query = query.Where("action = ?", filter.Action)
	}
	if filter.EventCategory != "" {
		query = query.Where("event_category = ?", filter.EventCategory)
	}
	if filter.ResourceUUID != "" {
		query = query.Where("resource_uuid = ?", filter.ResourceUUID)
	}
	if filter.StartDate != nil {
		query = query.Where("created_at >= ?", *filter.StartDate)
	}
	if filter.EndDate != nil {
		query = query.Where("created_at <= ?", *filter.EndDate)
	}

	// Count before pagination so total reflects every matching row.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Only skip rows for a valid page beyond the first with a real page size;
	// otherwise (page-1)*limit could be negative or, with both inputs
	// negative/zero, an unintended positive skip.
	offset := 0
	if page > 1 && limit > 0 {
		offset = (page - 1) * limit
	}

	if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&audits).Error; err != nil {
		return nil, 0, err
	}
	return audits, total, nil
}
// GetAuditLogByUUID looks up the audit log whose uuid column equals auditUUID.
// It returns an "audit log not found" error when no such row exists and passes
// any other database error through unchanged.
func (s *SecurityService) GetAuditLogByUUID(auditUUID string) (*models.SecurityAudit, error) {
	var entry models.SecurityAudit

	lookupErr := s.db.Where("uuid = ?", auditUUID).First(&entry).Error
	switch {
	case lookupErr == nil:
		return &entry, nil
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, errors.New("audit log not found")
	default:
		return nil, lookupErr
	}
}
// ListAuditLogsByProvider returns one page of audit logs for a single DNS
// provider (event_category "dns_provider" and resource_id equal to
// providerID), newest first, together with the total number of matching rows.
//
// page is 1-based and limit is the page size. A page below 1 is treated as the
// first page, and no offset is applied unless limit is positive, so an
// out-of-range page can never yield a negative or nonsensical offset. limit
// itself is passed to the database layer unchanged.
func (s *SecurityService) ListAuditLogsByProvider(providerID uint, page, limit int) ([]models.SecurityAudit, int64, error) {
	var audits []models.SecurityAudit
	var total int64

	query := s.db.Model(&models.SecurityAudit{}).
		Where("event_category = ? AND resource_id = ?", "dns_provider", providerID)

	// Count before pagination so total reflects every matching row.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Only skip rows for a valid page beyond the first with a real page size;
	// otherwise (page-1)*limit could be negative or, with both inputs
	// negative/zero, an unintended positive skip.
	offset := 0
	if page > 1 && limit > 0 {
		offset = (page - 1) * limit
	}

	if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&audits).Error; err != nil {
		return nil, 0, err
	}
	return audits, total, nil
}
// UpsertRuleSet creates or updates the rule set identified by r.Name.
//
// r.Name must be non-empty and r.Content may not exceed 2 MiB, to keep huge
// payloads out of the database. When no rule set with that name exists, r is
// inserted; a missing UUID and a zero LastUpdated are filled in on r first.
// Otherwise SourceURL, Content, Mode and LastUpdated are copied onto the
// stored row. A zero r.LastUpdated is recorded as the current time on update
// as well, so an update never erases the stored timestamp.
//
// A nil r is a no-op.
func (s *SecurityService) UpsertRuleSet(r *models.SecurityRuleSet) error {
	if r == nil {
		return nil
	}

	// Basic validations
	if r.Name == "" {
		return errors.New("rule set name required")
	}
	const maxContentBytes = 2 * 1024 * 1024
	if len(r.Content) > maxContentBytes {
		return errors.New("ruleset content too large")
	}

	var existing models.SecurityRuleSet
	if err := s.db.Where("name = ?", r.Name).First(&existing).Error; err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}
		// New rule set: fill in identity and timestamp defaults.
		if r.UUID == "" {
			r.UUID = uuid.NewString()
		}
		if r.LastUpdated.IsZero() {
			r.LastUpdated = time.Now()
		}
		return s.db.Create(r).Error
	}

	// Apply the same timestamp default as the create path; copying a zero
	// time would wipe the stored LastUpdated value.
	lastUpdated := r.LastUpdated
	if lastUpdated.IsZero() {
		lastUpdated = time.Now()
	}

	existing.SourceURL = r.SourceURL
	existing.Content = r.Content
	existing.Mode = r.Mode
	existing.LastUpdated = lastUpdated
	return s.db.Save(&existing).Error
}
// DeleteRuleSet deletes the rule set with the given id. The row is loaded
// first, so a missing id surfaces as the lookup error rather than as a silent
// no-op delete.
func (s *SecurityService) DeleteRuleSet(id uint) error {
	target := models.SecurityRuleSet{}

	lookup := s.db.Where("id = ?", id).First(&target)
	if lookup.Error != nil {
		return lookup.Error
	}

	removal := s.db.Delete(&target)
	return removal.Error
}
// ListRuleSets fetches every stored rule set. On a database error the slice
// is nil and the error is returned.
func (s *SecurityService) ListRuleSets() ([]models.SecurityRuleSet, error) {
	var ruleSets []models.SecurityRuleSet

	result := s.db.Find(&ruleSets)
	if result.Error != nil {
		return nil, result.Error
	}
	return ruleSets, nil
}
// isValidCIDR reports whether cidr is either a single IP address (IPv4 or
// IPv6) or an address range in CIDR notation. It mirrors the validation used
// by the access list service.
func isValidCIDR(cidr string) bool {
	// A bare address is accepted as-is.
	if net.ParseIP(cidr) != nil {
		return true
	}
	// Otherwise it must be a well-formed prefix such as 10.0.0.0/8.
	if _, _, err := net.ParseCIDR(cidr); err != nil {
		return false
	}
	return true
}