chore: clean .gitignore cache
This commit is contained in:
@@ -1,492 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
consoleStatusNotEnrolled = "not_enrolled"
|
||||
consoleStatusEnrolling = "enrolling"
|
||||
consoleStatusPendingAcceptance = "pending_acceptance"
|
||||
consoleStatusEnrolled = "enrolled"
|
||||
consoleStatusFailed = "failed"
|
||||
|
||||
defaultEnrollTimeout = 45 * time.Second
|
||||
)
|
||||
|
||||
var namePattern = regexp.MustCompile(`^[A-Za-z0-9_.\-]{1,64}$`)
|
||||
var enrollmentTokenPattern = regexp.MustCompile(`^[A-Za-z0-9]{10,64}$`)
|
||||
|
||||
// EnvCommandExecutor executes commands with optional environment overrides.
|
||||
type EnvCommandExecutor interface {
|
||||
ExecuteWithEnv(ctx context.Context, name string, args []string, env map[string]string) ([]byte, error)
|
||||
}
|
||||
|
||||
// SecureCommandExecutor is the production executor that avoids leaking args by passing secrets via env.
|
||||
type SecureCommandExecutor struct{}
|
||||
|
||||
// ExecuteWithEnv runs the command with provided env merged onto the current environment.
|
||||
func (r *SecureCommandExecutor) ExecuteWithEnv(ctx context.Context, name string, args []string, env map[string]string) ([]byte, error) {
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
cmd.Env = append(os.Environ(), formatEnv(env)...)
|
||||
return cmd.CombinedOutput()
|
||||
}
|
||||
|
||||
func formatEnv(env map[string]string) []string {
|
||||
if len(env) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(env))
|
||||
for k, v := range env {
|
||||
result = append(result, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConsoleEnrollRequest captures enrollment input.
|
||||
type ConsoleEnrollRequest struct {
|
||||
EnrollmentKey string
|
||||
Tenant string
|
||||
AgentName string
|
||||
Force bool
|
||||
}
|
||||
|
||||
// ConsoleEnrollmentStatus is the safe, redacted status view.
|
||||
type ConsoleEnrollmentStatus struct {
|
||||
Status string `json:"status"`
|
||||
Tenant string `json:"tenant"`
|
||||
AgentName string `json:"agent_name"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
LastAttemptAt *time.Time `json:"last_attempt_at,omitempty"`
|
||||
EnrolledAt *time.Time `json:"enrolled_at,omitempty"`
|
||||
LastHeartbeatAt *time.Time `json:"last_heartbeat_at,omitempty"`
|
||||
KeyPresent bool `json:"key_present"`
|
||||
CorrelationID string `json:"correlation_id,omitempty"`
|
||||
}
|
||||
|
||||
// ConsoleEnrollmentService manages console enrollment lifecycle and persistence.
|
||||
type ConsoleEnrollmentService struct {
|
||||
db *gorm.DB
|
||||
exec EnvCommandExecutor
|
||||
dataDir string
|
||||
key []byte
|
||||
nowFn func() time.Time
|
||||
mu sync.Mutex
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewConsoleEnrollmentService constructs a service using the supplied secret material for encryption.
|
||||
func NewConsoleEnrollmentService(db *gorm.DB, executor EnvCommandExecutor, dataDir, secret string) *ConsoleEnrollmentService {
|
||||
return &ConsoleEnrollmentService{
|
||||
db: db,
|
||||
exec: executor,
|
||||
dataDir: dataDir,
|
||||
key: deriveKey(secret),
|
||||
nowFn: time.Now,
|
||||
timeout: defaultEnrollTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current enrollment state.
|
||||
func (s *ConsoleEnrollmentService) Status(ctx context.Context) (ConsoleEnrollmentStatus, error) {
|
||||
rec, err := s.load(ctx)
|
||||
if err != nil {
|
||||
return ConsoleEnrollmentStatus{}, err
|
||||
}
|
||||
return s.statusFromModel(rec), nil
|
||||
}
|
||||
|
||||
// Enroll performs an enrollment attempt. It is idempotent when already enrolled unless Force is set.
|
||||
func (s *ConsoleEnrollmentService) Enroll(ctx context.Context, req ConsoleEnrollRequest) (ConsoleEnrollmentStatus, error) {
|
||||
agent := strings.TrimSpace(req.AgentName)
|
||||
if agent == "" {
|
||||
return ConsoleEnrollmentStatus{}, fmt.Errorf("agent_name required")
|
||||
}
|
||||
if !namePattern.MatchString(agent) {
|
||||
return ConsoleEnrollmentStatus{}, fmt.Errorf("agent_name may only include letters, numbers, dot, dash, underscore")
|
||||
}
|
||||
tenant := strings.TrimSpace(req.Tenant)
|
||||
if tenant != "" && !namePattern.MatchString(tenant) {
|
||||
return ConsoleEnrollmentStatus{}, fmt.Errorf("tenant may only include letters, numbers, dot, dash, underscore")
|
||||
}
|
||||
token, err := normalizeEnrollmentKey(req.EnrollmentKey)
|
||||
if err != nil {
|
||||
return ConsoleEnrollmentStatus{}, err
|
||||
}
|
||||
if s.exec == nil {
|
||||
return ConsoleEnrollmentStatus{}, fmt.Errorf("executor unavailable")
|
||||
}
|
||||
|
||||
// CRITICAL: Check that LAPI is running before attempting enrollment
|
||||
// Console enrollment requires an active LAPI connection to register with crowdsec.net
|
||||
if err := s.checkLAPIAvailable(ctx); err != nil {
|
||||
return ConsoleEnrollmentStatus{}, err
|
||||
}
|
||||
|
||||
if err := s.ensureCAPIRegistered(ctx); err != nil {
|
||||
return ConsoleEnrollmentStatus{}, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
rec, err := s.load(ctx)
|
||||
if err != nil {
|
||||
return ConsoleEnrollmentStatus{}, err
|
||||
}
|
||||
|
||||
if rec.Status == consoleStatusEnrolling {
|
||||
return s.statusFromModel(rec), fmt.Errorf("enrollment already in progress")
|
||||
}
|
||||
// If already enrolled or pending acceptance, skip unless Force is set
|
||||
if (rec.Status == consoleStatusEnrolled || rec.Status == consoleStatusPendingAcceptance) && !req.Force {
|
||||
logger.Log().WithFields(map[string]any{
|
||||
"status": rec.Status,
|
||||
"agent_name": rec.AgentName,
|
||||
"tenant": rec.Tenant,
|
||||
}).Info("console enrollment skipped: already enrolled or pending acceptance - use force=true to re-enroll")
|
||||
return s.statusFromModel(rec), nil
|
||||
}
|
||||
|
||||
now := s.nowFn().UTC()
|
||||
rec.Status = consoleStatusEnrolling
|
||||
rec.AgentName = agent
|
||||
rec.Tenant = tenant
|
||||
rec.LastAttemptAt = &now
|
||||
rec.LastError = ""
|
||||
rec.LastCorrelationID = uuid.NewString()
|
||||
|
||||
encryptedKey, err := s.encrypt(token)
|
||||
if err != nil {
|
||||
return ConsoleEnrollmentStatus{}, fmt.Errorf("protect secret: %w", err)
|
||||
}
|
||||
rec.EncryptedEnrollKey = encryptedKey
|
||||
|
||||
if err := s.db.WithContext(ctx).Save(rec).Error; err != nil {
|
||||
return ConsoleEnrollmentStatus{}, err
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, s.timeout)
|
||||
defer cancel()
|
||||
|
||||
args := []string{"console", "enroll", "--name", agent}
|
||||
|
||||
// Add tenant as a tag if provided
|
||||
if tenant != "" {
|
||||
args = append(args, "--tags", fmt.Sprintf("tenant:%s", tenant))
|
||||
}
|
||||
|
||||
// Add overwrite flag if force is requested
|
||||
if req.Force {
|
||||
args = append(args, "--overwrite")
|
||||
}
|
||||
|
||||
// Add config path
|
||||
configPath := s.findConfigPath()
|
||||
if configPath != "" {
|
||||
args = append([]string{"-c", configPath}, args...)
|
||||
}
|
||||
|
||||
// Token is the last positional argument
|
||||
args = append(args, token)
|
||||
|
||||
logger.Log().WithField("tenant", tenant).WithField("agent", agent).WithField("force", req.Force).WithField("correlation_id", rec.LastCorrelationID).WithField("config", configPath).Info("starting crowdsec console enrollment")
|
||||
out, cmdErr := s.exec.ExecuteWithEnv(cmdCtx, "cscli", args, nil)
|
||||
|
||||
// Log command output for debugging (redacting the token)
|
||||
redactedOut := redactSecret(string(out), token)
|
||||
if cmdErr != nil {
|
||||
rec.Status = consoleStatusFailed
|
||||
// Redact token from both output and error message
|
||||
redactedErr := redactSecret(cmdErr.Error(), token)
|
||||
// Extract the meaningful error message from cscli output
|
||||
userMessage := extractCscliErrorMessage(redactedOut)
|
||||
if userMessage == "" {
|
||||
userMessage = redactedOut
|
||||
}
|
||||
rec.LastError = userMessage
|
||||
_ = s.db.WithContext(ctx).Save(rec)
|
||||
logger.Log().WithField("error", redactedErr).WithField("correlation_id", rec.LastCorrelationID).WithField("tenant", tenant).WithField("output", redactedOut).Warn("crowdsec console enrollment failed")
|
||||
return s.statusFromModel(rec), fmt.Errorf("%s", userMessage)
|
||||
}
|
||||
|
||||
logger.Log().WithField("correlation_id", rec.LastCorrelationID).WithField("output", redactedOut).Debug("cscli console enroll command output")
|
||||
|
||||
// Enrollment request was sent successfully, but user must still accept it on crowdsec.net.
|
||||
// cscli console enroll returns exit code 0 when the request is sent, NOT when enrollment is complete.
|
||||
// The CrowdSec help explicitly states: "After running this command your will need to validate the enrollment in the webapp."
|
||||
complete := s.nowFn().UTC()
|
||||
rec.Status = consoleStatusPendingAcceptance
|
||||
rec.LastAttemptAt = &complete
|
||||
rec.LastError = ""
|
||||
if err := s.db.WithContext(ctx).Save(rec).Error; err != nil {
|
||||
return ConsoleEnrollmentStatus{}, err
|
||||
}
|
||||
|
||||
logger.Log().WithField("tenant", tenant).WithField("agent", agent).WithField("correlation_id", rec.LastCorrelationID).Info("crowdsec console enrollment request sent - pending acceptance on crowdsec.net")
|
||||
return s.statusFromModel(rec), nil
|
||||
}
|
||||
|
||||
// checkLAPIAvailable verifies that CrowdSec Local API is running and reachable.
|
||||
// This is critical for console enrollment as the enrollment process requires LAPI.
|
||||
// It retries up to 3 times with 2-second delays to handle LAPI initialization timing.
|
||||
func (s *ConsoleEnrollmentService) checkLAPIAvailable(ctx context.Context) error {
|
||||
maxRetries := 3
|
||||
retryDelay := 2 * time.Second
|
||||
|
||||
var lastErr error
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
args := []string{"lapi", "status"}
|
||||
configPath := s.findConfigPath()
|
||||
if configPath != "" {
|
||||
args = append([]string{"-c", configPath}, args...)
|
||||
}
|
||||
|
||||
checkCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
out, err := s.exec.ExecuteWithEnv(checkCtx, "cscli", args, nil)
|
||||
cancel()
|
||||
|
||||
if err == nil {
|
||||
logger.Log().WithField("config", configPath).Debug("LAPI check succeeded")
|
||||
return nil // LAPI is available
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
if i < maxRetries-1 {
|
||||
logger.Log().WithError(err).WithField("attempt", i+1).WithField("output", string(out)).Debug("LAPI not ready, retrying")
|
||||
time.Sleep(retryDelay)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("CrowdSec Local API is not running after %d attempts - please wait for LAPI to initialize (typically 5-10 seconds after enabling CrowdSec): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func (s *ConsoleEnrollmentService) ensureCAPIRegistered(ctx context.Context) error {
|
||||
// Check for credentials in config subdirectory first (standard layout),
|
||||
// then fall back to dataDir root for backward compatibility
|
||||
credsPath := filepath.Join(s.dataDir, "config", "online_api_credentials.yaml")
|
||||
if _, err := os.Stat(credsPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
credsPath = filepath.Join(s.dataDir, "online_api_credentials.yaml")
|
||||
if _, err := os.Stat(credsPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Log().Info("registering with crowdsec capi")
|
||||
args := []string{"capi", "register"}
|
||||
configPath := s.findConfigPath()
|
||||
if configPath != "" {
|
||||
args = append([]string{"-c", configPath}, args...)
|
||||
}
|
||||
|
||||
out, err := s.exec.ExecuteWithEnv(ctx, "cscli", args, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("capi register: %s: %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findConfigPath returns the path to the CrowdSec config file, checking
|
||||
// config subdirectory first (standard layout), then dataDir root.
|
||||
// Returns empty string if no config file is found.
|
||||
func (s *ConsoleEnrollmentService) findConfigPath() string {
|
||||
configPath := filepath.Join(s.dataDir, "config", "config.yaml")
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
return configPath
|
||||
}
|
||||
configPath = filepath.Join(s.dataDir, "config.yaml")
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
return configPath
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *ConsoleEnrollmentService) load(ctx context.Context) (*models.CrowdsecConsoleEnrollment, error) {
|
||||
var rec models.CrowdsecConsoleEnrollment
|
||||
err := s.db.WithContext(ctx).First(&rec).Error
|
||||
if err == nil {
|
||||
return &rec, nil
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
now := s.nowFn().UTC()
|
||||
rec = models.CrowdsecConsoleEnrollment{
|
||||
UUID: uuid.NewString(),
|
||||
Status: consoleStatusNotEnrolled,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := s.db.WithContext(ctx).Create(&rec).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rec, nil
|
||||
}
|
||||
|
||||
// ClearEnrollment resets the enrollment state to allow fresh enrollment.
|
||||
// This does NOT unenroll from crowdsec.net - that must be done manually on the console.
|
||||
func (s *ConsoleEnrollmentService) ClearEnrollment(ctx context.Context) error {
|
||||
if s.db == nil {
|
||||
return fmt.Errorf("database not initialized")
|
||||
}
|
||||
|
||||
var rec models.CrowdsecConsoleEnrollment
|
||||
if err := s.db.WithContext(ctx).First(&rec).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil // Already cleared
|
||||
}
|
||||
return fmt.Errorf("failed to find enrollment record: %w", err)
|
||||
}
|
||||
|
||||
logger.Log().WithField("previous_status", rec.Status).Info("clearing console enrollment state")
|
||||
|
||||
// Delete the record
|
||||
if err := s.db.WithContext(ctx).Delete(&rec).Error; err != nil {
|
||||
return fmt.Errorf("failed to delete enrollment record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ConsoleEnrollmentService) statusFromModel(rec *models.CrowdsecConsoleEnrollment) ConsoleEnrollmentStatus {
|
||||
if rec == nil {
|
||||
return ConsoleEnrollmentStatus{Status: consoleStatusNotEnrolled}
|
||||
}
|
||||
return ConsoleEnrollmentStatus{
|
||||
Status: firstNonEmpty(rec.Status, consoleStatusNotEnrolled),
|
||||
Tenant: rec.Tenant,
|
||||
AgentName: rec.AgentName,
|
||||
LastError: rec.LastError,
|
||||
LastAttemptAt: rec.LastAttemptAt,
|
||||
EnrolledAt: rec.EnrolledAt,
|
||||
LastHeartbeatAt: rec.LastHeartbeatAt,
|
||||
KeyPresent: rec.EncryptedEnrollKey != "",
|
||||
CorrelationID: rec.LastCorrelationID,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConsoleEnrollmentService) encrypt(value string) (string, error) {
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
block, err := aes.NewCipher(s.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
sealed := gcm.Seal(nonce, nonce, []byte(value), nil)
|
||||
return base64.StdEncoding.EncodeToString(sealed), nil
|
||||
}
|
||||
|
||||
func deriveKey(secret string) []byte {
|
||||
if secret == "" {
|
||||
secret = "charon-console-enroll-default"
|
||||
}
|
||||
sum := sha256.Sum256([]byte(secret))
|
||||
return sum[:]
|
||||
}
|
||||
|
||||
func redactSecret(msg, secret string) string {
|
||||
if secret == "" {
|
||||
return msg
|
||||
}
|
||||
return strings.ReplaceAll(msg, secret, "<redacted>")
|
||||
}
|
||||
|
||||
// extractCscliErrorMessage extracts the meaningful error message from cscli output.
|
||||
// CrowdSec outputs error messages in formats like:
|
||||
// - "level=error msg=\"...\""
|
||||
// - "ERRO[...] ..."
|
||||
// - Plain error text
|
||||
func extractCscliErrorMessage(output string) string {
|
||||
output = strings.TrimSpace(output)
|
||||
if output == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to extract from level=error msg="..." format
|
||||
msgPattern := regexp.MustCompile(`msg="([^"]+)"`)
|
||||
if matches := msgPattern.FindStringSubmatch(output); len(matches) > 1 {
|
||||
return matches[1]
|
||||
}
|
||||
|
||||
// Try to extract from ERRO[...] format - get text after the timestamp bracket
|
||||
erroPattern := regexp.MustCompile(`ERRO\[[^\]]*\]\s*(.+)`)
|
||||
if matches := erroPattern.FindStringSubmatch(output); len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
// Try to find any line containing "error" or "failed" (case-insensitive)
|
||||
lines := strings.Split(output, "\n")
|
||||
for _, line := range lines {
|
||||
lower := strings.ToLower(line)
|
||||
if strings.Contains(lower, "error") || strings.Contains(lower, "failed") || strings.Contains(lower, "invalid") {
|
||||
return strings.TrimSpace(line)
|
||||
}
|
||||
}
|
||||
|
||||
// If no pattern matched, return the first non-empty line (often the most relevant)
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
func normalizeEnrollmentKey(raw string) (string, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", fmt.Errorf("enrollment_key required")
|
||||
}
|
||||
if enrollmentTokenPattern.MatchString(trimmed) {
|
||||
return trimmed, nil
|
||||
}
|
||||
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) == 0 {
|
||||
return "", fmt.Errorf("invalid enrollment key")
|
||||
}
|
||||
if strings.EqualFold(parts[0], "sudo") {
|
||||
parts = parts[1:]
|
||||
}
|
||||
|
||||
if len(parts) == 4 && parts[0] == "cscli" && parts[1] == "console" && parts[2] == "enroll" {
|
||||
token := parts[3]
|
||||
if enrollmentTokenPattern.MatchString(token) {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid enrollment key")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,111 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestApplyWithOpenFileHandles simulates the "device or resource busy" scenario
|
||||
// where the data directory has open file handles (e.g., from cache operations)
|
||||
func TestApplyWithOpenFileHandles(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
dataDir := filepath.Join(t.TempDir(), "crowdsec")
|
||||
require.NoError(t, os.MkdirAll(dataDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dataDir, "config.txt"), []byte("original"), 0o644))
|
||||
|
||||
// Create a subdirectory with nested files (similar to hub_cache)
|
||||
subDir := filepath.Join(dataDir, "hub_cache")
|
||||
require.NoError(t, os.MkdirAll(subDir, 0o755))
|
||||
cacheFile := filepath.Join(subDir, "cache.json")
|
||||
require.NoError(t, os.WriteFile(cacheFile, []byte(`{"test": "data"}`), 0o644))
|
||||
|
||||
// Open a file handle to simulate an in-use directory
|
||||
// This would cause os.Rename to fail with "device or resource busy" on some systems
|
||||
f, err := os.Open(cacheFile)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
// Create and cache a preset
|
||||
archive := makeTarGz(t, map[string]string{"new/preset.yaml": "new: preset"})
|
||||
_, err = cache.Store(context.Background(), "test/preset", "etag1", "hub", "preview", archive)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := NewHubService(nil, cache, dataDir)
|
||||
|
||||
// Apply should succeed using copy-based backup even with open file handles
|
||||
res, err := svc.Apply(context.Background(), "test/preset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "applied", res.Status)
|
||||
require.NotEmpty(t, res.BackupPath, "BackupPath should be set on success")
|
||||
|
||||
// Verify backup was created and contains the original files
|
||||
backupConfigPath := filepath.Join(res.BackupPath, "config.txt")
|
||||
backupCachePath := filepath.Join(res.BackupPath, "hub_cache", "cache.json")
|
||||
|
||||
// The backup should exist
|
||||
require.FileExists(t, backupConfigPath)
|
||||
require.FileExists(t, backupCachePath)
|
||||
|
||||
// Verify original content was preserved in backup
|
||||
content, err := os.ReadFile(backupConfigPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original", string(content))
|
||||
|
||||
cacheContent, err := os.ReadFile(backupCachePath)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(cacheContent), "test")
|
||||
|
||||
// Verify new preset was applied
|
||||
newPresetPath := filepath.Join(dataDir, "new", "preset.yaml")
|
||||
require.FileExists(t, newPresetPath)
|
||||
newContent, err := os.ReadFile(newPresetPath)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(newContent), "new: preset")
|
||||
}
|
||||
|
||||
// TestBackupPathOnlySetAfterSuccessfulBackup ensures that BackupPath is only
|
||||
// set in the result after a successful backup, not before attempting it.
|
||||
// This prevents misleading error messages that reference non-existent backups.
|
||||
func TestBackupPathOnlySetAfterSuccessfulBackup(t *testing.T) {
|
||||
t.Run("backup path not set when cache missing", func(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
dataDir := filepath.Join(t.TempDir(), "crowdsec")
|
||||
require.NoError(t, os.MkdirAll(dataDir, 0o755))
|
||||
|
||||
svc := NewHubService(nil, cache, dataDir)
|
||||
|
||||
// Try to apply a preset that doesn't exist in cache (no cscli available)
|
||||
res, err := svc.Apply(context.Background(), "nonexistent/preset")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, res.BackupPath, "BackupPath should be set when backup attempt is performed for rollback")
|
||||
})
|
||||
|
||||
t.Run("backup path set only after successful backup", func(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
dataDir := filepath.Join(t.TempDir(), "crowdsec")
|
||||
require.NoError(t, os.MkdirAll(dataDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dataDir, "file.txt"), []byte("data"), 0o644))
|
||||
|
||||
archive := makeTarGz(t, map[string]string{"new.yaml": "new: config"})
|
||||
_, err = cache.Store(context.Background(), "test/preset", "etag1", "hub", "preview", archive)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := NewHubService(nil, cache, dataDir)
|
||||
|
||||
res, err := svc.Apply(context.Background(), "test/preset")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.BackupPath, "BackupPath should be set after successful backup")
|
||||
require.FileExists(t, filepath.Join(res.BackupPath, "file.txt"), "Backup should contain original files")
|
||||
})
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
// Package crowdsec provides integration with CrowdSec for security decisions and remediation.
|
||||
package crowdsec
|
||||
@@ -1,265 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/util"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
ErrCacheExpired = errors.New("cache expired")
|
||||
)
|
||||
|
||||
// CachedPreset captures metadata about a pulled preset bundle.
|
||||
type CachedPreset struct {
|
||||
Slug string `json:"slug"`
|
||||
CacheKey string `json:"cache_key"`
|
||||
Etag string `json:"etag"`
|
||||
Source string `json:"source"`
|
||||
RetrievedAt time.Time `json:"retrieved_at"`
|
||||
PreviewPath string `json:"preview_path"`
|
||||
ArchivePath string `json:"archive_path"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
}
|
||||
|
||||
// HubCache persists pulled bundles on disk with TTL-based eviction.
|
||||
type HubCache struct {
|
||||
baseDir string
|
||||
ttl time.Duration
|
||||
nowFn func() time.Time
|
||||
}
|
||||
|
||||
var slugPattern = regexp.MustCompile(`^[A-Za-z0-9./_-]+$`)
|
||||
|
||||
// NewHubCache constructs a cache rooted at baseDir with the provided TTL.
|
||||
func NewHubCache(baseDir string, ttl time.Duration) (*HubCache, error) {
|
||||
if baseDir == "" {
|
||||
return nil, fmt.Errorf("baseDir required")
|
||||
}
|
||||
if err := os.MkdirAll(baseDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create cache dir: %w", err)
|
||||
}
|
||||
return &HubCache{baseDir: baseDir, ttl: ttl, nowFn: time.Now}, nil
|
||||
}
|
||||
|
||||
// TTL returns the configured time-to-live for cached entries.
|
||||
func (c *HubCache) TTL() time.Duration {
|
||||
return c.ttl
|
||||
}
|
||||
|
||||
// Store writes the bundle archive and preview to disk and returns the cache metadata.
|
||||
func (c *HubCache) Store(ctx context.Context, slug, etag, source, preview string, archive []byte) (CachedPreset, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return CachedPreset{}, err
|
||||
}
|
||||
cleanSlug := sanitizeSlug(slug)
|
||||
if cleanSlug == "" {
|
||||
return CachedPreset{}, fmt.Errorf("invalid slug")
|
||||
}
|
||||
dir := filepath.Join(c.baseDir, cleanSlug)
|
||||
logger.Log().WithField("slug", util.SanitizeForLog(cleanSlug)).WithField("cache_dir", util.SanitizeForLog(dir)).WithField("archive_size", len(archive)).Debug("storing preset in cache")
|
||||
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
logger.Log().WithError(err).WithField("dir", util.SanitizeForLog(dir)).Error("failed to create cache directory")
|
||||
return CachedPreset{}, fmt.Errorf("create slug dir: %w", err)
|
||||
}
|
||||
|
||||
ts := c.nowFn().UTC()
|
||||
cacheKey := fmt.Sprintf("%s-%d", cleanSlug, ts.Unix())
|
||||
|
||||
archivePath := filepath.Join(dir, "bundle.tgz")
|
||||
if err := os.WriteFile(archivePath, archive, 0o640); err != nil {
|
||||
return CachedPreset{}, fmt.Errorf("write archive: %w", err)
|
||||
}
|
||||
previewPath := filepath.Join(dir, "preview.yaml")
|
||||
if err := os.WriteFile(previewPath, []byte(preview), 0o640); err != nil {
|
||||
return CachedPreset{}, fmt.Errorf("write preview: %w", err)
|
||||
}
|
||||
|
||||
meta := CachedPreset{
|
||||
Slug: cleanSlug,
|
||||
CacheKey: cacheKey,
|
||||
Etag: etag,
|
||||
Source: source,
|
||||
RetrievedAt: ts,
|
||||
PreviewPath: previewPath,
|
||||
ArchivePath: archivePath,
|
||||
SizeBytes: int64(len(archive)),
|
||||
}
|
||||
metaPath := filepath.Join(dir, "metadata.json")
|
||||
raw, err := json.Marshal(meta)
|
||||
if err != nil {
|
||||
return CachedPreset{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(metaPath, raw, 0o640); err != nil {
|
||||
logger.Log().WithError(err).WithField("meta_path", util.SanitizeForLog(metaPath)).Error("failed to write metadata file")
|
||||
return CachedPreset{}, fmt.Errorf("write metadata: %w", err)
|
||||
}
|
||||
|
||||
logger.Log().WithField("slug", util.SanitizeForLog(cleanSlug)).WithField("cache_key", cacheKey).WithField("archive_path", util.SanitizeForLog(archivePath)).WithField("preview_path", util.SanitizeForLog(previewPath)).WithField("meta_path", util.SanitizeForLog(metaPath)).Info("preset successfully stored in cache")
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// Load returns cached preset metadata, enforcing TTL.
|
||||
func (c *HubCache) Load(ctx context.Context, slug string) (CachedPreset, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return CachedPreset{}, err
|
||||
}
|
||||
cleanSlug := sanitizeSlug(slug)
|
||||
if cleanSlug == "" {
|
||||
return CachedPreset{}, fmt.Errorf("invalid slug")
|
||||
}
|
||||
metaPath := filepath.Join(c.baseDir, cleanSlug, "metadata.json")
|
||||
logger.Log().WithField("slug", util.SanitizeForLog(cleanSlug)).WithField("meta_path", util.SanitizeForLog(metaPath)).Debug("attempting to load cached preset")
|
||||
|
||||
data, err := os.ReadFile(metaPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
logger.Log().WithField("slug", util.SanitizeForLog(cleanSlug)).WithField("meta_path", util.SanitizeForLog(metaPath)).Debug("preset not found in cache (cache miss)")
|
||||
return CachedPreset{}, ErrCacheMiss
|
||||
}
|
||||
logger.Log().WithError(err).WithField("slug", util.SanitizeForLog(cleanSlug)).WithField("meta_path", util.SanitizeForLog(metaPath)).Error("failed to read cached preset metadata")
|
||||
return CachedPreset{}, err
|
||||
}
|
||||
var meta CachedPreset
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
logger.Log().WithError(err).WithField("slug", util.SanitizeForLog(cleanSlug)).Error("failed to unmarshal cached preset metadata")
|
||||
return CachedPreset{}, fmt.Errorf("unmarshal metadata: %w", err)
|
||||
}
|
||||
|
||||
if c.ttl > 0 && c.nowFn().After(meta.RetrievedAt.Add(c.ttl)) {
|
||||
logger.Log().WithField("slug", util.SanitizeForLog(cleanSlug)).WithField("retrieved_at", meta.RetrievedAt).WithField("ttl", c.ttl).Debug("cached preset expired")
|
||||
return CachedPreset{}, ErrCacheExpired
|
||||
}
|
||||
|
||||
logger.Log().WithField("slug", util.SanitizeForLog(meta.Slug)).WithField("cache_key", meta.CacheKey).WithField("archive_path", util.SanitizeForLog(meta.ArchivePath)).Debug("successfully loaded cached preset")
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// LoadPreview returns the preview contents for a cached preset.
|
||||
func (c *HubCache) LoadPreview(ctx context.Context, slug string) (string, error) {
|
||||
meta, err := c.Load(ctx, slug)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := os.ReadFile(meta.PreviewPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// List returns cached presets that have not expired.
|
||||
func (c *HubCache) List(ctx context.Context) ([]CachedPreset, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results := make([]CachedPreset, 0)
|
||||
err := filepath.WalkDir(c.baseDir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() || d.Name() != "metadata.json" {
|
||||
return nil
|
||||
}
|
||||
rel, err := filepath.Rel(c.baseDir, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
slug := filepath.Dir(rel)
|
||||
meta, err := c.Load(ctx, slug)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
results = append(results, meta)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Evict removes cached data for the given slug.
|
||||
func (c *HubCache) Evict(ctx context.Context, slug string) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
cleanSlug := sanitizeSlug(slug)
|
||||
if cleanSlug == "" {
|
||||
return fmt.Errorf("invalid slug")
|
||||
}
|
||||
return os.RemoveAll(filepath.Join(c.baseDir, cleanSlug))
|
||||
}
|
||||
|
||||
// sanitizeSlug guards against traversal and unsupported characters.
|
||||
func sanitizeSlug(slug string) string {
|
||||
trimmed := strings.TrimSpace(slug)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
cleaned := filepath.Clean(trimmed)
|
||||
cleaned = strings.ReplaceAll(cleaned, "\\", "/")
|
||||
if strings.HasPrefix(cleaned, "..") || strings.Contains(cleaned, string(os.PathSeparator)+"..") || strings.HasPrefix(cleaned, string(os.PathSeparator)) {
|
||||
return ""
|
||||
}
|
||||
if !slugPattern.MatchString(cleaned) {
|
||||
return ""
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// Exists returns true when a non-expired cache entry is present.
|
||||
func (c *HubCache) Exists(ctx context.Context, slug string) bool {
|
||||
if _, err := c.Load(ctx, slug); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Touch updates the timestamp to extend TTL; noop when missing.
|
||||
func (c *HubCache) Touch(ctx context.Context, slug string) error {
|
||||
meta, err := c.Load(ctx, slug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
meta.RetrievedAt = c.nowFn().UTC()
|
||||
raw, err := json.Marshal(meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metaPath := filepath.Join(c.baseDir, meta.Slug, "metadata.json")
|
||||
return os.WriteFile(metaPath, raw, 0o640)
|
||||
}
|
||||
|
||||
// Size returns aggregated size of cached archives (best effort).
|
||||
func (c *HubCache) Size(ctx context.Context) int64 {
|
||||
var total int64
|
||||
_ = filepath.WalkDir(c.baseDir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
total += info.Size()
|
||||
return nil
|
||||
})
|
||||
return total
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHubCacheStoreLoadAndExpire(t *testing.T) {
|
||||
t.Parallel()
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
meta, err := cache.Store(ctx, "crowdsecurity/demo", "etag1", "hub", "preview-text", []byte("archive-bytes"))
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, meta.CacheKey)
|
||||
|
||||
loaded, err := cache.Load(ctx, "crowdsecurity/demo")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, meta.CacheKey, loaded.CacheKey)
|
||||
require.Equal(t, "etag1", loaded.Etag)
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(2 * time.Minute) }
|
||||
_, err = cache.Load(ctx, "crowdsecurity/demo")
|
||||
require.ErrorIs(t, err, ErrCacheExpired)
|
||||
}
|
||||
|
||||
func TestHubCacheRejectsBadSlug(t *testing.T) {
|
||||
t.Parallel()
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.Store(context.Background(), "../bad", "etag", "hub", "preview", []byte("data"))
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = cache.Store(context.Background(), "..\\bad", "etag", "hub", "preview", []byte("data"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheListAndEvict(t *testing.T) {
|
||||
t.Parallel()
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = cache.Store(ctx, "crowdsecurity/demo", "etag1", "hub", "preview", []byte("data1"))
|
||||
require.NoError(t, err)
|
||||
_, err = cache.Store(ctx, "crowdsecurity/other", "etag2", "hub", "preview", []byte("data2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
entries, err := cache.List(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
|
||||
require.NoError(t, cache.Evict(ctx, "crowdsecurity/demo"))
|
||||
entries, err = cache.List(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
require.Equal(t, "crowdsecurity/other", entries[0].Slug)
|
||||
}
|
||||
|
||||
func TestHubCacheTouchUpdatesTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
meta, err := cache.Store(ctx, "crowdsecurity/demo", "etag1", "hub", "preview", []byte("data1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(30 * time.Second) }
|
||||
require.NoError(t, cache.Touch(ctx, "crowdsecurity/demo"))
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(2 * time.Minute) }
|
||||
_, err = cache.Load(ctx, "crowdsecurity/demo")
|
||||
require.ErrorIs(t, err, ErrCacheExpired)
|
||||
}
|
||||
|
||||
func TestHubCachePreviewExistsAndSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
archive := []byte("archive-bytes-here")
|
||||
_, err = cache.Store(ctx, "crowdsecurity/demo", "etag123", "hub", "preview-content", archive)
|
||||
require.NoError(t, err)
|
||||
|
||||
preview, err := cache.LoadPreview(ctx, "crowdsecurity/demo")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "preview-content", preview)
|
||||
require.True(t, cache.Exists(ctx, "crowdsecurity/demo"))
|
||||
require.GreaterOrEqual(t, cache.Size(ctx), int64(len(archive)))
|
||||
}
|
||||
|
||||
func TestHubCacheExistsHonorsTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
meta, err := cache.Store(ctx, "crowdsecurity/demo", "etag123", "hub", "preview", []byte("data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(3 * time.Second) }
|
||||
require.False(t, cache.Exists(ctx, "crowdsecurity/demo"))
|
||||
}
|
||||
|
||||
func TestSanitizeSlugCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "demo/preset", sanitizeSlug(" demo/preset "))
|
||||
require.Equal(t, "", sanitizeSlug("../traverse"))
|
||||
require.Equal(t, "", sanitizeSlug("/abs/path"))
|
||||
require.Equal(t, "", sanitizeSlug("\\windows\\bad"))
|
||||
require.Equal(t, "", sanitizeSlug("bad spaces %"))
|
||||
}
|
||||
|
||||
func TestNewHubCacheRequiresBaseDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := NewHubCache("", time.Hour)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheTouchMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Touch(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrCacheMiss)
|
||||
}
|
||||
|
||||
func TestHubCacheTouchInvalidSlug(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Touch(context.Background(), "../bad")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheStoreContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err = cache.Store(ctx, "demo", "etag", "hub", "preview", []byte("data"))
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
func TestHubCacheLoadInvalidSlug(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.Load(context.Background(), "../bad")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheExistsContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
require.False(t, cache.Exists(ctx, "demo"))
|
||||
}
|
||||
|
||||
func TestHubCacheListSkipsExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Second)
|
||||
require.NoError(t, err)
|
||||
ctx := context.Background()
|
||||
fixed := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
cache.nowFn = func() time.Time { return fixed }
|
||||
_, err = cache.Store(ctx, "crowdsecurity/demo", "etag", "hub", "preview", []byte("data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.nowFn = func() time.Time { return fixed.Add(3 * time.Second) }
|
||||
entries, err := cache.List(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 0)
|
||||
}
|
||||
|
||||
func TestHubCacheEvictInvalidSlug(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
err = cache.Evict(context.Background(), "../bad")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheListContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = cache.List(ctx)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// TTL Tests
|
||||
// ============================================
|
||||
|
||||
func TestHubCacheTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("returns configured TTL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), 2*time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2*time.Hour, cache.TTL())
|
||||
})
|
||||
|
||||
t.Run("returns minute TTL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), time.Minute)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Minute, cache.TTL())
|
||||
})
|
||||
|
||||
t.Run("returns zero TTL if configured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache, err := NewHubCache(t.TempDir(), 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Duration(0), cache.TTL())
|
||||
})
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHubCacheStoreLoadAndExpire(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
meta, err := cache.Store(ctx, "crowdsecurity/demo", "etag1", "hub", "preview-text", []byte("archive-bytes"))
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, meta.CacheKey)
|
||||
|
||||
loaded, err := cache.Load(ctx, "crowdsecurity/demo")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, meta.CacheKey, loaded.CacheKey)
|
||||
require.Equal(t, "etag1", loaded.Etag)
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(2 * time.Minute) }
|
||||
_, err = cache.Load(ctx, "crowdsecurity/demo")
|
||||
require.ErrorIs(t, err, ErrCacheExpired)
|
||||
}
|
||||
|
||||
func TestHubCacheRejectsBadSlug(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.Store(context.Background(), "../bad", "etag", "hub", "preview", []byte("data"))
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = cache.Store(context.Background(), "..\\bad", "etag", "hub", "preview", []byte("data"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheListAndEvict(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = cache.Store(ctx, "crowdsecurity/demo", "etag1", "hub", "preview", []byte("data1"))
|
||||
require.NoError(t, err)
|
||||
_, err = cache.Store(ctx, "crowdsecurity/other", "etag2", "hub", "preview", []byte("data2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
entries, err := cache.List(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
|
||||
require.NoError(t, cache.Evict(ctx, "crowdsecurity/demo"))
|
||||
entries, err = cache.List(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
require.Equal(t, "crowdsecurity/other", entries[0].Slug)
|
||||
}
|
||||
|
||||
func TestHubCacheTouchUpdatesTTL(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
meta, err := cache.Store(ctx, "crowdsecurity/demo", "etag1", "hub", "preview", []byte("data1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(30 * time.Second) }
|
||||
require.NoError(t, cache.Touch(ctx, "crowdsecurity/demo"))
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(2 * time.Minute) }
|
||||
_, err = cache.Load(ctx, "crowdsecurity/demo")
|
||||
require.ErrorIs(t, err, ErrCacheExpired)
|
||||
}
|
||||
|
||||
func TestHubCachePreviewExistsAndSize(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
archive := []byte("archive-bytes-here")
|
||||
_, err = cache.Store(ctx, "crowdsecurity/demo", "etag123", "hub", "preview-content", archive)
|
||||
require.NoError(t, err)
|
||||
|
||||
preview, err := cache.LoadPreview(ctx, "crowdsecurity/demo")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "preview-content", preview)
|
||||
require.True(t, cache.Exists(ctx, "crowdsecurity/demo"))
|
||||
require.GreaterOrEqual(t, cache.Size(ctx), int64(len(archive)))
|
||||
}
|
||||
|
||||
func TestHubCacheExistsHonorsTTL(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
meta, err := cache.Store(ctx, "crowdsecurity/demo", "etag123", "hub", "preview", []byte("data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.nowFn = func() time.Time { return meta.RetrievedAt.Add(3 * time.Second) }
|
||||
require.False(t, cache.Exists(ctx, "crowdsecurity/demo"))
|
||||
}
|
||||
|
||||
func TestSanitizeSlugCases(t *testing.T) {
|
||||
require.Equal(t, "demo/preset", sanitizeSlug(" demo/preset "))
|
||||
require.Equal(t, "", sanitizeSlug("../traverse"))
|
||||
require.Equal(t, "", sanitizeSlug("/abs/path"))
|
||||
require.Equal(t, "", sanitizeSlug("\\windows\\bad"))
|
||||
require.Equal(t, "", sanitizeSlug("bad spaces %"))
|
||||
}
|
||||
|
||||
func TestNewHubCacheRequiresBaseDir(t *testing.T) {
|
||||
_, err := NewHubCache("", time.Hour)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheTouchMissing(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Touch(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrCacheMiss)
|
||||
}
|
||||
|
||||
func TestHubCacheTouchInvalidSlug(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Touch(context.Background(), "../bad")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheStoreContextCanceled(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err = cache.Store(ctx, "demo", "etag", "hub", "preview", []byte("data"))
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
func TestHubCacheLoadInvalidSlug(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.Load(context.Background(), "../bad")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheExistsContextCanceled(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
require.False(t, cache.Exists(ctx, "demo"))
|
||||
}
|
||||
|
||||
func TestHubCacheListSkipsExpired(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Second)
|
||||
require.NoError(t, err)
|
||||
ctx := context.Background()
|
||||
fixed := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
cache.nowFn = func() time.Time { return fixed }
|
||||
_, err = cache.Store(ctx, "crowdsecurity/demo", "etag", "hub", "preview", []byte("data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.nowFn = func() time.Time { return fixed.Add(3 * time.Second) }
|
||||
entries, err := cache.List(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 0)
|
||||
}
|
||||
|
||||
func TestHubCacheEvictInvalidSlug(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
err = cache.Evict(context.Background(), "../bad")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHubCacheListContextCanceled(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Hour)
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err = cache.List(ctx)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// TTL Tests
|
||||
// ============================================
|
||||
|
||||
func TestHubCacheTTL(t *testing.T) {
|
||||
t.Run("returns configured TTL", func(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), 2*time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2*time.Hour, cache.TTL())
|
||||
})
|
||||
|
||||
t.Run("returns minute TTL", func(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), time.Minute)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Minute, cache.TTL())
|
||||
})
|
||||
|
||||
t.Run("returns zero TTL if configured", func(t *testing.T) {
|
||||
cache, err := NewHubCache(t.TempDir(), 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Duration(0), cache.TTL())
|
||||
})
|
||||
}
|
||||
@@ -1,485 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubExec struct {
|
||||
responses map[string]error
|
||||
calls []string
|
||||
}
|
||||
|
||||
func (s *stubExec) Execute(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
cmd := strings.Join(append([]string{name}, args...), " ")
|
||||
s.calls = append(s.calls, cmd)
|
||||
for key, err := range s.responses {
|
||||
if strings.Contains(cmd, key) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return []byte("ok"), nil
|
||||
}
|
||||
|
||||
// TestPullThenApplyFlow verifies that pulling a preset and then applying it works correctly.
|
||||
func TestPullThenApplyFlow(t *testing.T) {
|
||||
// Create temp directories for cache and data
|
||||
cacheDir := t.TempDir()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// Create cache with 1 hour TTL
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a test archive
|
||||
archive := makeTestArchive(t, map[string]string{
|
||||
"config.yaml": "test: config\nvalue: 123",
|
||||
"profiles.yaml": "name: test",
|
||||
})
|
||||
|
||||
// Create hub service with mock HTTP client
|
||||
hub := NewHubService(nil, cache, dataDir)
|
||||
hub.HubBaseURL = "http://test.example.com"
|
||||
hub.HTTPClient = &http.Client{
|
||||
Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
switch req.URL.String() {
|
||||
case "http://test.example.com/api/index.json":
|
||||
body := `{"items":[{"name":"test/preset","title":"Test Preset","description":"Test","etag":"etag123","download_url":"http://test.example.com/test.tgz","preview_url":"http://test.example.com/test.yaml"}]}`
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
case "http://test.example.com/test.yaml":
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("test: preview\nkey: value")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
case "http://test.example.com/test.tgz":
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(archive)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
default:
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusNotFound,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Step 1: Pull the preset
|
||||
t.Log("Step 1: Pulling preset")
|
||||
pullResult, err := hub.Pull(ctx, "test/preset")
|
||||
require.NoError(t, err, "Pull should succeed")
|
||||
require.Equal(t, "test/preset", pullResult.Meta.Slug)
|
||||
require.NotEmpty(t, pullResult.Meta.CacheKey)
|
||||
require.NotEmpty(t, pullResult.Preview)
|
||||
|
||||
// Verify cache files exist
|
||||
require.FileExists(t, pullResult.Meta.ArchivePath, "Archive should be cached")
|
||||
require.FileExists(t, pullResult.Meta.PreviewPath, "Preview should be cached")
|
||||
|
||||
// Read the cached files to verify content
|
||||
cachedArchive, err := os.ReadFile(pullResult.Meta.ArchivePath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, archive, cachedArchive, "Cached archive should match original")
|
||||
|
||||
cachedPreview, err := os.ReadFile(pullResult.Meta.PreviewPath)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(cachedPreview), "preview", "Cached preview should contain expected content")
|
||||
|
||||
t.Log("Step 2: Verifying cache can be loaded")
|
||||
// Verify we can load from cache
|
||||
loaded, err := cache.Load(ctx, "test/preset")
|
||||
require.NoError(t, err, "Should be able to load cached preset")
|
||||
require.Equal(t, pullResult.Meta.Slug, loaded.Slug)
|
||||
require.Equal(t, pullResult.Meta.CacheKey, loaded.CacheKey)
|
||||
|
||||
t.Log("Step 3: Applying preset from cache")
|
||||
// Step 2: Apply the preset (should use cached version)
|
||||
applyResult, err := hub.Apply(ctx, "test/preset")
|
||||
require.NoError(t, err, "Apply should succeed after pull")
|
||||
require.Equal(t, "applied", applyResult.Status)
|
||||
require.NotEmpty(t, applyResult.BackupPath)
|
||||
require.Equal(t, "test/preset", applyResult.AppliedPreset)
|
||||
|
||||
// Verify files were extracted to dataDir
|
||||
extractedConfig := filepath.Join(dataDir, "config.yaml")
|
||||
require.FileExists(t, extractedConfig, "Config should be extracted")
|
||||
content, err := os.ReadFile(extractedConfig)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(content), "test: config")
|
||||
}
|
||||
|
||||
func TestApplyRepullsOnCacheMissAfterCSCLIFailure(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
dataDir := filepath.Join(t.TempDir(), "data")
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
archive := makeTestArchive(t, map[string]string{"config.yaml": "test: repull"})
|
||||
|
||||
exec := &stubExec{responses: map[string]error{"install": fmt.Errorf("install failed")}}
|
||||
hub := NewHubService(exec, cache, dataDir)
|
||||
hub.HubBaseURL = "http://test.example.com"
|
||||
hub.HTTPClient = &http.Client{Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
switch req.URL.String() {
|
||||
case "http://test.example.com/api/index.json":
|
||||
body := `{"items":[{"name":"test/preset","title":"Test","etag":"e1"}]}`
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/test/preset.yaml":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("preview")), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/test/preset.tgz":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(archive)), Header: make(http.Header)}, nil
|
||||
default:
|
||||
return &http.Response{StatusCode: http.StatusNotFound, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil
|
||||
}
|
||||
})}
|
||||
|
||||
ctx := context.Background()
|
||||
res, err := hub.Apply(ctx, "test/preset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "applied", res.Status)
|
||||
require.False(t, res.UsedCSCLI)
|
||||
require.NotEmpty(t, res.CacheKey)
|
||||
|
||||
meta, loadErr := cache.Load(ctx, "test/preset")
|
||||
require.NoError(t, loadErr)
|
||||
require.Equal(t, res.CacheKey, meta.CacheKey)
|
||||
require.FileExists(t, filepath.Join(dataDir, "config.yaml"))
|
||||
}
|
||||
|
||||
func TestApplyRepullsOnCacheExpired(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
dataDir := filepath.Join(t.TempDir(), "data")
|
||||
cache, err := NewHubCache(cacheDir, 5*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
archive := makeTestArchive(t, map[string]string{"config.yaml": "test: expired"})
|
||||
ctx := context.Background()
|
||||
_, err = cache.Store(ctx, "expired/preset", "etag-old", "hub", "old", archive)
|
||||
require.NoError(t, err)
|
||||
|
||||
// wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
hub := NewHubService(nil, cache, dataDir)
|
||||
hub.HubBaseURL = "http://test.example.com"
|
||||
hub.HTTPClient = &http.Client{Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
switch req.URL.String() {
|
||||
case "http://test.example.com/api/index.json":
|
||||
body := `{"items":[{"name":"expired/preset","title":"Expired","etag":"e2"}]}`
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/expired/preset.yaml":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("preview new")), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/expired/preset.tgz":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(archive)), Header: make(http.Header)}, nil
|
||||
default:
|
||||
return &http.Response{StatusCode: http.StatusNotFound, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil
|
||||
}
|
||||
})}
|
||||
|
||||
res, err := hub.Apply(ctx, "expired/preset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "applied", res.Status)
|
||||
require.False(t, res.UsedCSCLI)
|
||||
|
||||
meta, loadErr := cache.Load(ctx, "expired/preset")
|
||||
require.NoError(t, loadErr)
|
||||
require.Equal(t, "e2", meta.Etag)
|
||||
require.FileExists(t, filepath.Join(dataDir, "config.yaml"))
|
||||
}
|
||||
|
||||
func TestPullAcceptsNamespacedIndexEntry(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
dataDir := filepath.Join(t.TempDir(), "data")
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
archive := makeTestArchive(t, map[string]string{"config.yaml": "test: namespaced"})
|
||||
|
||||
hub := NewHubService(nil, cache, dataDir)
|
||||
hub.HubBaseURL = "http://test.example.com"
|
||||
hub.HTTPClient = &http.Client{Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
switch req.URL.String() {
|
||||
case "http://test.example.com/api/index.json":
|
||||
body := `{"items":[{"name":"crowdsecurity/bot-mitigation-essentials","title":"Bot Mitigation Essentials","etag":"etag-bme"}]}`
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/crowdsecurity/bot-mitigation-essentials.yaml":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("namespaced preview")), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/crowdsecurity/bot-mitigation-essentials.tgz":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(archive)), Header: make(http.Header)}, nil
|
||||
default:
|
||||
return &http.Response{StatusCode: http.StatusNotFound, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil
|
||||
}
|
||||
})}
|
||||
|
||||
ctx := context.Background()
|
||||
res, err := hub.Pull(ctx, "bot-mitigation-essentials")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "bot-mitigation-essentials", res.Meta.Slug)
|
||||
require.Equal(t, "etag-bme", res.Meta.Etag)
|
||||
require.Contains(t, res.Preview, "namespaced preview")
|
||||
}
|
||||
|
||||
func TestHubFallbackToMirrorOnForbidden(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
dataDir := t.TempDir()
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
archive := makeTestArchive(t, map[string]string{"config.yaml": "mirror"})
|
||||
|
||||
hub := NewHubService(nil, cache, dataDir)
|
||||
hub.HubBaseURL = "http://primary.example.com"
|
||||
hub.MirrorBaseURL = "http://mirror.example.com"
|
||||
hub.HTTPClient = &http.Client{Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
switch req.URL.String() {
|
||||
case "http://primary.example.com/api/index.json":
|
||||
return &http.Response{StatusCode: http.StatusForbidden, Body: io.NopCloser(strings.NewReader("blocked")), Header: make(http.Header)}, nil
|
||||
case "http://mirror.example.com/api/index.json":
|
||||
body := `{"items":[{"name":"fallback/preset","title":"Fallback","etag":"etag-mirror"}]}`
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header)}, nil
|
||||
case "http://primary.example.com/fallback/preset.yaml":
|
||||
return &http.Response{StatusCode: http.StatusForbidden, Body: io.NopCloser(strings.NewReader("blocked")), Header: make(http.Header)}, nil
|
||||
case "http://mirror.example.com/fallback/preset.yaml":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("mirror preview")), Header: make(http.Header)}, nil
|
||||
case "http://primary.example.com/fallback/preset.tgz":
|
||||
return &http.Response{StatusCode: http.StatusForbidden, Body: io.NopCloser(strings.NewReader("blocked")), Header: make(http.Header)}, nil
|
||||
case "http://mirror.example.com/fallback/preset.tgz":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(archive)), Header: make(http.Header)}, nil
|
||||
default:
|
||||
return &http.Response{StatusCode: http.StatusNotFound, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil
|
||||
}
|
||||
})}
|
||||
|
||||
ctx := context.Background()
|
||||
res, err := hub.Pull(ctx, "fallback/preset")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "etag-mirror", res.Meta.Etag)
|
||||
require.Contains(t, res.Preview, "mirror preview")
|
||||
}
|
||||
|
||||
// TestApplyWithoutPullFails verifies that applying without pulling first fails with proper error.
|
||||
func TestApplyWithoutPullFails(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create hub service without cscli (nil executor) and empty cache
|
||||
hub := NewHubService(nil, cache, dataDir)
|
||||
hub.HubBaseURL = "http://test.example.com"
|
||||
hub.HTTPClient = &http.Client{Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: http.StatusInternalServerError, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil
|
||||
})}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to apply without pulling first
|
||||
_, err = hub.Apply(ctx, "nonexistent/preset")
|
||||
require.Error(t, err, "Apply should fail without cache and without cscli")
|
||||
require.ErrorIs(t, err, ErrCacheMiss, "Error should expose cache miss for guidance")
|
||||
require.Contains(t, err.Error(), "refresh cache", "Error should surface repull failure context")
|
||||
}
|
||||
|
||||
// TestCacheExpiration verifies that expired cache is not used.
|
||||
func TestCacheExpiration(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
|
||||
// Create cache with very short TTL
|
||||
cache, err := NewHubCache(cacheDir, 1*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store a preset
|
||||
archive := makeTestArchive(t, map[string]string{"test.yaml": "content"})
|
||||
ctx := context.Background()
|
||||
cached, err := cache.Store(ctx, "test/preset", "etag1", "hub", "preview", archive)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Try to load - should get ErrCacheExpired
|
||||
_, err = cache.Load(ctx, "test/preset")
|
||||
require.ErrorIs(t, err, ErrCacheExpired, "Should get cache expired error")
|
||||
|
||||
// Verify the cache files still exist on disk (not deleted)
|
||||
require.FileExists(t, cached.ArchivePath, "Archive file should still exist")
|
||||
require.FileExists(t, cached.PreviewPath, "Preview file should still exist")
|
||||
}
|
||||
|
||||
// TestCacheListAfterPull verifies that pulled presets appear in cache list.
|
||||
func TestCacheListAfterPull(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
archive := makeTestArchive(t, map[string]string{"test.yaml": "content"})
|
||||
|
||||
hub := NewHubService(nil, cache, dataDir)
|
||||
hub.HubBaseURL = "http://test.example.com"
|
||||
hub.HTTPClient = &http.Client{
|
||||
Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
switch req.URL.String() {
|
||||
case "http://test.example.com/api/index.json":
|
||||
body := `{"items":[{"name":"preset1","title":"Preset 1","etag":"e1"}]}`
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/preset1.yaml":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("preview1")), Header: make(http.Header)}, nil
|
||||
case "http://test.example.com/preset1.tgz":
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(archive)), Header: make(http.Header)}, nil
|
||||
default:
|
||||
return &http.Response{StatusCode: http.StatusNotFound, Body: io.NopCloser(strings.NewReader("")), Header: make(http.Header)}, nil
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pull preset
|
||||
_, err = hub.Pull(ctx, "preset1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// List cache contents
|
||||
cached, err := cache.List(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cached, 1, "Should have one cached preset")
|
||||
require.Equal(t, "preset1", cached[0].Slug)
|
||||
}
|
||||
|
||||
// makeTestArchive creates a test tar.gz archive.
|
||||
func makeTestArchive(t *testing.T, files map[string]string) []byte {
|
||||
t.Helper()
|
||||
buf := &bytes.Buffer{}
|
||||
gw := gzip.NewWriter(buf)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
for name, content := range files {
|
||||
hdr := &tar.Header{
|
||||
Name: name,
|
||||
Mode: 0o644,
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
require.NoError(t, tw.WriteHeader(hdr))
|
||||
_, err := tw.Write([]byte(content))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, tw.Close())
|
||||
require.NoError(t, gw.Close())
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// mockTransport is a mock http.RoundTripper for testing.
|
||||
type mockTransport func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (m mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return m(req)
|
||||
}
|
||||
|
||||
// TestApplyReadsArchiveBeforeBackup verifies the fix for the bug where Apply() would:
|
||||
// 1. Load cache metadata (getting archive path inside DataDir/hub_cache)
|
||||
// 2. Backup DataDir (moving the cache including the archive!)
|
||||
// 3. Try to read archive from original path (FAIL - file no longer exists!)
|
||||
//
|
||||
// The fix reads the archive into memory BEFORE creating the backup, so the
|
||||
// archive data is preserved even after the backup operation moves the files.
|
||||
func TestApplyReadsArchiveBeforeBackup(t *testing.T) {
|
||||
// Create base directory structure that mirrors production:
|
||||
// baseDir/
|
||||
// └── crowdsec/ <- DataDir (gets backed up)
|
||||
// └── hub_cache/ <- Cache lives INSIDE DataDir (the bug!)
|
||||
// └── test/preset/
|
||||
// ├── bundle.tgz
|
||||
// ├── preview.yaml
|
||||
// └── metadata.json
|
||||
baseDir := t.TempDir()
|
||||
dataDir := filepath.Join(baseDir, "crowdsec")
|
||||
cacheDir := filepath.Join(dataDir, "hub_cache") // Cache INSIDE DataDir - this is key!
|
||||
|
||||
// Create DataDir with some existing config to make backup realistic
|
||||
require.NoError(t, os.MkdirAll(dataDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dataDir, "config.yaml"), []byte("existing: config"), 0o644))
|
||||
|
||||
// Create cache inside DataDir
|
||||
cache, err := NewHubCache(cacheDir, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test archive
|
||||
archive := makeTestArchive(t, map[string]string{
|
||||
"config.yaml": "test: applied_config\nvalue: 123",
|
||||
"profiles.yaml": "name: test_profile",
|
||||
})
|
||||
|
||||
// Pre-populate cache (simulating a prior Pull operation)
|
||||
ctx := context.Background()
|
||||
cachedMeta, err := cache.Store(ctx, "test/preset", "etag-pre", "hub", "preview: content", archive)
|
||||
require.NoError(t, err)
|
||||
require.FileExists(t, cachedMeta.ArchivePath, "Archive should exist in cache")
|
||||
|
||||
// Verify cache is inside DataDir (the scenario that triggers the bug)
|
||||
require.True(t, strings.HasPrefix(cachedMeta.ArchivePath, dataDir),
|
||||
"Cache archive must be inside DataDir for this test to be valid")
|
||||
|
||||
// Create hub service WITHOUT cscli (nil executor) to force cache fallback path
|
||||
hub := NewHubService(nil, cache, dataDir)
|
||||
hub.HubBaseURL = "http://test.example.com"
|
||||
// HTTP client that fails everything - we don't want to hit network
|
||||
hub.HTTPClient = &http.Client{
|
||||
Transport: mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader("intentionally failing")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
// Apply - this SHOULD succeed because:
|
||||
// 1. Archive is read into memory BEFORE backup
|
||||
// 2. Backup moves DataDir (including cache) but we already have archive bytes
|
||||
// 3. Extract uses the in-memory archive bytes
|
||||
//
|
||||
// BEFORE THE FIX, this would fail with:
|
||||
// "read archive: open /tmp/.../crowdsec/hub_cache/.../bundle.tgz: no such file or directory"
|
||||
result, err := hub.Apply(ctx, "test/preset")
|
||||
require.NoError(t, err, "Apply should succeed - archive must be read before backup")
|
||||
require.Equal(t, "applied", result.Status)
|
||||
require.NotEmpty(t, result.BackupPath, "Backup should have been created")
|
||||
require.False(t, result.UsedCSCLI, "Should have used cache fallback, not cscli")
|
||||
|
||||
// Verify backup was created
|
||||
_, statErr := os.Stat(result.BackupPath)
|
||||
require.NoError(t, statErr, "Backup directory should exist")
|
||||
|
||||
// Verify files were extracted to DataDir
|
||||
extractedConfig := filepath.Join(dataDir, "config.yaml")
|
||||
require.FileExists(t, extractedConfig, "Config should be extracted")
|
||||
content, err := os.ReadFile(extractedConfig)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(content), "test: applied_config",
|
||||
"Extracted config should contain content from archive, not original")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,65 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFetchIndexParsesRawIndexFormat(t *testing.T) {
|
||||
svc := NewHubService(nil, nil, t.TempDir())
|
||||
svc.HubBaseURL = "http://example.com"
|
||||
|
||||
// This JSON represents the "raw" index format (map of maps) which has no "items" field.
|
||||
// json.Unmarshal into HubIndex will succeed but result in empty Items.
|
||||
// The fix should detect this and fall back to parseRawIndex.
|
||||
rawIndexBody := `{
|
||||
"collections": {
|
||||
"crowdsecurity/base-http-scenarios": {
|
||||
"path": "collections/crowdsecurity/base-http-scenarios.yaml",
|
||||
"version": "0.1",
|
||||
"description": "Base HTTP scenarios"
|
||||
}
|
||||
},
|
||||
"parsers": {
|
||||
"crowdsecurity/nginx-logs": {
|
||||
"path": "parsers/s01-parse/crowdsecurity/nginx-logs.yaml",
|
||||
"version": "1.0",
|
||||
"description": "Parse Nginx logs"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
svc.HTTPClient = &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.String() == "http://example.com"+defaultHubIndexPath {
|
||||
resp := newResponse(http.StatusOK, rawIndexBody)
|
||||
resp.Header.Set("Content-Type", "application/json")
|
||||
return resp, nil
|
||||
}
|
||||
return newResponse(http.StatusNotFound, ""), nil
|
||||
})}
|
||||
|
||||
idx, err := svc.FetchIndex(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, idx.Items)
|
||||
|
||||
// Verify we found the items
|
||||
foundCollection := false
|
||||
foundParser := false
|
||||
for _, item := range idx.Items {
|
||||
if item.Name == "crowdsecurity/base-http-scenarios" {
|
||||
foundCollection = true
|
||||
require.Equal(t, "collections", item.Type)
|
||||
require.Equal(t, "0.1", item.Version)
|
||||
}
|
||||
if item.Name == "crowdsecurity/nginx-logs" {
|
||||
foundParser = true
|
||||
require.Equal(t, "parsers", item.Type)
|
||||
require.Equal(t, "1.0", item.Version)
|
||||
}
|
||||
}
|
||||
require.True(t, foundCollection, "should find collection from raw index")
|
||||
require.True(t, foundParser, "should find parser from raw index")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,55 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
// Preset represents a curated CrowdSec preset offered by Charon.
|
||||
type Preset struct {
|
||||
Slug string `json:"slug"`
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
Source string `json:"source"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
RequiresHub bool `json:"requires_hub"`
|
||||
}
|
||||
|
||||
var curatedPresets = []Preset{
|
||||
{
|
||||
Slug: "honeypot-friendly-defaults",
|
||||
Title: "Honeypot Friendly Defaults",
|
||||
Summary: "Lightweight parser and collection set tuned to reduce noise for tarpits and honeypots.",
|
||||
Source: "charon-curated",
|
||||
Tags: []string{"low-noise", "ssh", "http"},
|
||||
RequiresHub: false,
|
||||
},
|
||||
{
|
||||
Slug: "crowdsecurity/base-http-scenarios",
|
||||
Title: "Bot Mitigation Essentials",
|
||||
Summary: "Core scenarios for bad bots and credential stuffing with minimal false positives (maps to base-http-scenarios).",
|
||||
Source: "hub",
|
||||
Tags: []string{"bots", "auth", "web"},
|
||||
RequiresHub: true,
|
||||
},
|
||||
{
|
||||
Slug: "geolocation-aware",
|
||||
Title: "Geolocation Aware",
|
||||
Summary: "Adds geo-aware decisions to tighten access by region; best paired with existing ACLs.",
|
||||
Source: "charon-curated",
|
||||
Tags: []string{"geo", "access-control"},
|
||||
RequiresHub: false,
|
||||
},
|
||||
}
|
||||
|
||||
// ListCuratedPresets returns a copy of curated presets to avoid external mutation.
|
||||
func ListCuratedPresets() []Preset {
|
||||
out := make([]Preset, len(curatedPresets))
|
||||
copy(out, curatedPresets)
|
||||
return out
|
||||
}
|
||||
|
||||
// FindPreset returns a preset by slug.
|
||||
func FindPreset(slug string) (Preset, bool) {
|
||||
for _, p := range curatedPresets {
|
||||
if p.Slug == slug {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return Preset{}, false
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestListCuratedPresetsReturnsCopy(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := ListCuratedPresets()
|
||||
if len(got) == 0 {
|
||||
t.Fatalf("expected curated presets, got none")
|
||||
}
|
||||
|
||||
// mutate the copy and ensure originals stay intact on subsequent calls
|
||||
got[0].Title = "mutated"
|
||||
again := ListCuratedPresets()
|
||||
if again[0].Title == "mutated" {
|
||||
t.Fatalf("expected curated presets to be returned as copy, but mutation leaked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPreset(t *testing.T) {
|
||||
t.Parallel()
|
||||
preset, ok := FindPreset("honeypot-friendly-defaults")
|
||||
if !ok {
|
||||
t.Fatalf("expected to find curated preset")
|
||||
}
|
||||
if preset.Slug != "honeypot-friendly-defaults" {
|
||||
t.Fatalf("unexpected preset slug %s", preset.Slug)
|
||||
}
|
||||
if preset.Title == "" {
|
||||
t.Fatalf("expected preset to have a title")
|
||||
}
|
||||
if preset.Summary == "" {
|
||||
t.Fatalf("expected preset to have a summary")
|
||||
}
|
||||
|
||||
if _, ok := FindPreset("missing"); ok {
|
||||
t.Fatalf("expected missing preset to return ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPresetCaseVariants(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
slug string
|
||||
found bool
|
||||
}{
|
||||
{"exact match", "crowdsecurity/base-http-scenarios", true},
|
||||
{"another preset", "geolocation-aware", true},
|
||||
{"case sensitive miss", "BOT-MITIGATION-ESSENTIALS", false},
|
||||
{"partial match miss", "bot-mitigation", false},
|
||||
{"empty slug", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ok := FindPreset(tt.slug)
|
||||
if ok != tt.found {
|
||||
t.Errorf("FindPreset(%q) found=%v, want %v", tt.slug, ok, tt.found)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCuratedPresetsReturnsDifferentCopy(t *testing.T) {
|
||||
t.Parallel()
|
||||
list1 := ListCuratedPresets()
|
||||
list2 := ListCuratedPresets()
|
||||
|
||||
if len(list1) == 0 {
|
||||
t.Fatalf("expected non-empty preset list")
|
||||
}
|
||||
|
||||
// Verify mutating one copy doesn't affect the other
|
||||
list1[0].Title = "MODIFIED"
|
||||
if list2[0].Title == "MODIFIED" {
|
||||
t.Fatalf("expected independent copies but mutation leaked")
|
||||
}
|
||||
|
||||
// Verify subsequent calls return fresh copies
|
||||
list3 := ListCuratedPresets()
|
||||
if list3[0].Title == "MODIFIED" {
|
||||
t.Fatalf("mutation leaked to fresh copy")
|
||||
}
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestListCuratedPresetsReturnsCopy(t *testing.T) {
|
||||
got := ListCuratedPresets()
|
||||
if len(got) == 0 {
|
||||
t.Fatalf("expected curated presets, got none")
|
||||
}
|
||||
|
||||
// mutate the copy and ensure originals stay intact on subsequent calls
|
||||
got[0].Title = "mutated"
|
||||
again := ListCuratedPresets()
|
||||
if again[0].Title == "mutated" {
|
||||
t.Fatalf("expected curated presets to be returned as copy, but mutation leaked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPreset(t *testing.T) {
|
||||
preset, ok := FindPreset("honeypot-friendly-defaults")
|
||||
if !ok {
|
||||
t.Fatalf("expected to find curated preset")
|
||||
}
|
||||
if preset.Slug != "honeypot-friendly-defaults" {
|
||||
t.Fatalf("unexpected preset slug %s", preset.Slug)
|
||||
}
|
||||
if preset.Title == "" {
|
||||
t.Fatalf("expected preset to have a title")
|
||||
}
|
||||
if preset.Summary == "" {
|
||||
t.Fatalf("expected preset to have a summary")
|
||||
}
|
||||
|
||||
if _, ok := FindPreset("missing"); ok {
|
||||
t.Fatalf("expected missing preset to return ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPresetCaseVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slug string
|
||||
found bool
|
||||
}{
|
||||
{"exact match", "crowdsecurity/base-http-scenarios", true},
|
||||
{"another preset", "geolocation-aware", true},
|
||||
{"case sensitive miss", "BOT-MITIGATION-ESSENTIALS", false},
|
||||
{"partial match miss", "bot-mitigation", false},
|
||||
{"empty slug", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, ok := FindPreset(tt.slug)
|
||||
if ok != tt.found {
|
||||
t.Errorf("FindPreset(%q) found=%v, want %v", tt.slug, ok, tt.found)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCuratedPresetsReturnsDifferentCopy(t *testing.T) {
|
||||
list1 := ListCuratedPresets()
|
||||
list2 := ListCuratedPresets()
|
||||
|
||||
if len(list1) == 0 {
|
||||
t.Fatalf("expected non-empty preset list")
|
||||
}
|
||||
|
||||
// Verify mutating one copy doesn't affect the other
|
||||
list1[0].Title = "MODIFIED"
|
||||
if list2[0].Title == "MODIFIED" {
|
||||
t.Fatalf("expected independent copies but mutation leaked")
|
||||
}
|
||||
|
||||
// Verify subsequent calls return fresh copies
|
||||
list3 := ListCuratedPresets()
|
||||
if list3[0].Title == "MODIFIED" {
|
||||
t.Fatalf("mutation leaked to fresh copy")
|
||||
}
|
||||
}
|
||||
@@ -1,335 +0,0 @@
|
||||
// Package crowdsec provides integration with CrowdSec for security decisions and remediation.
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
neturl "net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/network"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultLAPIURL is the default CrowdSec LAPI URL.
|
||||
// Port 8085 is used to avoid conflict with Charon management API on port 8080.
|
||||
defaultLAPIURL = "http://127.0.0.1:8085"
|
||||
defaultHealthTimeout = 5 * time.Second
|
||||
defaultRegistrationName = "caddy-bouncer"
|
||||
)
|
||||
|
||||
// BouncerRegistration holds information about a registered bouncer.
|
||||
type BouncerRegistration struct {
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"api_key"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Valid bool `json:"valid"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
}
|
||||
|
||||
// LAPIHealthResponse represents the health check response from CrowdSec LAPI.
|
||||
type LAPIHealthResponse struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
// validateLAPIURL validates a CrowdSec LAPI URL for security (SSRF protection - MEDIUM-001).
|
||||
// CrowdSec LAPI typically runs on localhost or within an internal network.
|
||||
// This function ensures the URL:
|
||||
// 1. Uses only http/https schemes
|
||||
// 2. Points to localhost OR is explicitly within allowed private networks
|
||||
// 3. Does not point to arbitrary external URLs
|
||||
//
|
||||
// Returns: error if URL is invalid or suspicious
|
||||
func validateLAPIURL(lapiURL string) error {
|
||||
// Empty URL defaults to localhost, which is safe
|
||||
if lapiURL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parsed, err := neturl.Parse(lapiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LAPI URL format: %w", err)
|
||||
}
|
||||
|
||||
// Only allow http/https
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return fmt.Errorf("LAPI URL must use http or https scheme (got: %s)", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
if host == "" {
|
||||
return fmt.Errorf("missing hostname in LAPI URL")
|
||||
}
|
||||
|
||||
// Allow localhost addresses (CrowdSec typically runs locally)
|
||||
if host == "localhost" || host == "127.0.0.1" || host == "::1" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For non-localhost, the LAPI URL should be explicitly configured
|
||||
// and point to an internal service. We accept RFC 1918 private IPs
|
||||
// but log a warning for operational visibility.
|
||||
// This prevents accidental/malicious configuration to external URLs.
|
||||
|
||||
// Parse IP to check if it's in private range
|
||||
// If not an IP, it's a hostname - for security, we only allow
|
||||
// localhost hostnames or IPs. Custom hostnames could resolve to
|
||||
// arbitrary locations via DNS.
|
||||
|
||||
// Note: This is a conservative approach. If you need to allow
|
||||
// specific internal hostnames, add them to an allowlist.
|
||||
|
||||
return fmt.Errorf("LAPI URL must be localhost for security (got: %s). For remote LAPI, ensure it's on a trusted internal network", host)
|
||||
}
|
||||
|
||||
// EnsureBouncerRegistered checks if a caddy bouncer is registered with CrowdSec LAPI.
|
||||
// If not registered and cscli is available, it will attempt to register one.
|
||||
// Returns the API key for the bouncer (from env var or newly registered).
|
||||
func EnsureBouncerRegistered(ctx context.Context, lapiURL string) (string, error) {
|
||||
// CRITICAL FIX: Validate LAPI URL before making requests (MEDIUM-001)
|
||||
if err := validateLAPIURL(lapiURL); err != nil {
|
||||
return "", fmt.Errorf("LAPI URL validation failed: %w", err)
|
||||
}
|
||||
|
||||
// First check if API key is provided via environment
|
||||
apiKey := getBouncerAPIKey()
|
||||
if apiKey != "" {
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Check if cscli is available
|
||||
if !hasCSCLI() {
|
||||
return "", fmt.Errorf("no API key provided and cscli not available for bouncer registration")
|
||||
}
|
||||
|
||||
// Check if bouncer already exists
|
||||
existing, err := getExistingBouncer(ctx, defaultRegistrationName)
|
||||
if err == nil && existing.APIKey != "" {
|
||||
return existing.APIKey, nil
|
||||
}
|
||||
|
||||
// Register new bouncer using cscli
|
||||
return registerBouncer(ctx, defaultRegistrationName)
|
||||
}
|
||||
|
||||
// CheckLAPIHealth verifies CrowdSec LAPI is responding.
|
||||
func CheckLAPIHealth(lapiURL string) bool {
|
||||
if lapiURL == "" {
|
||||
lapiURL = defaultLAPIURL
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultHealthTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Try the /health endpoint first (standard LAPI health check)
|
||||
healthURL := strings.TrimRight(lapiURL, "/") + "/health"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, http.NoBody)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use SSRF-safe HTTP client with localhost allowed (LAPI is localhost-only)
|
||||
client := network.NewSafeHTTPClient(
|
||||
network.WithTimeout(defaultHealthTimeout),
|
||||
network.WithAllowLocalhost(), // LAPI validated to be localhost only
|
||||
)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// Fallback: try the /v1/decisions endpoint with a HEAD request
|
||||
return checkDecisionsEndpoint(ctx, lapiURL)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
logger.Log().WithError(err).Warn("Failed to close response body")
|
||||
}
|
||||
}()
|
||||
|
||||
// Check content-type to ensure we're getting JSON from actual LAPI (not HTML from frontend)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "" && !strings.Contains(contentType, "application/json") {
|
||||
// Not JSON response, likely hitting a frontend/proxy
|
||||
return false
|
||||
}
|
||||
|
||||
// LAPI returns 200 OK for healthy status
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return true
|
||||
}
|
||||
|
||||
// If health endpoint returned non-OK, try decisions endpoint fallback
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return checkDecisionsEndpoint(ctx, lapiURL)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetLAPIVersion retrieves the CrowdSec LAPI version.
|
||||
func GetLAPIVersion(ctx context.Context, lapiURL string) (string, error) {
|
||||
if lapiURL == "" {
|
||||
lapiURL = defaultLAPIURL
|
||||
}
|
||||
|
||||
versionURL := strings.TrimRight(lapiURL, "/") + "/v1/version"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL, http.NoBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create version request: %w", err)
|
||||
}
|
||||
|
||||
// Use SSRF-safe HTTP client with localhost allowed (LAPI is localhost-only)
|
||||
client := network.NewSafeHTTPClient(
|
||||
network.WithTimeout(defaultHealthTimeout),
|
||||
network.WithAllowLocalhost(), // LAPI validated to be localhost only
|
||||
)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("version request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
logger.Log().WithError(err).Warn("Failed to close response body")
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("version request returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read version response: %w", err)
|
||||
}
|
||||
|
||||
var versionResp struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &versionResp); err != nil {
|
||||
// Some versions return plain text
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
}
|
||||
|
||||
return versionResp.Version, nil
|
||||
}
|
||||
|
||||
// checkDecisionsEndpoint is a fallback health check using the decisions endpoint.
|
||||
func checkDecisionsEndpoint(ctx context.Context, lapiURL string) bool {
|
||||
decisionsURL := strings.TrimRight(lapiURL, "/") + "/v1/decisions"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, decisionsURL, http.NoBody)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use SSRF-safe HTTP client with localhost allowed (LAPI is localhost-only)
|
||||
client := network.NewSafeHTTPClient(
|
||||
network.WithTimeout(defaultHealthTimeout),
|
||||
network.WithAllowLocalhost(), // LAPI validated to be localhost only
|
||||
)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
logger.Log().WithError(err).Warn("Failed to close response body")
|
||||
}
|
||||
}()
|
||||
|
||||
// Check content-type to avoid false positives from HTML responses
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "" && !strings.Contains(contentType, "application/json") {
|
||||
// Not JSON response, likely hitting a frontend/proxy
|
||||
return false
|
||||
}
|
||||
|
||||
// 401 is expected without auth, but indicates LAPI is running
|
||||
// 200 with empty array is also valid (no decisions)
|
||||
return resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// getBouncerAPIKey returns the bouncer API key from environment variables.
|
||||
func getBouncerAPIKey() string {
|
||||
// Check multiple possible env var names for the API key
|
||||
envVars := []string{
|
||||
"CROWDSEC_API_KEY",
|
||||
"CROWDSEC_BOUNCER_API_KEY",
|
||||
"CERBERUS_SECURITY_CROWDSEC_API_KEY",
|
||||
"CHARON_SECURITY_CROWDSEC_API_KEY",
|
||||
"CPM_SECURITY_CROWDSEC_API_KEY",
|
||||
}
|
||||
|
||||
for _, key := range envVars {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// hasCSCLI checks if cscli command is available.
|
||||
func hasCSCLI() bool {
|
||||
_, err := exec.LookPath("cscli")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// getExistingBouncer retrieves an existing bouncer registration by name.
|
||||
func getExistingBouncer(ctx context.Context, name string) (BouncerRegistration, error) {
|
||||
cmd := exec.CommandContext(ctx, "cscli", "bouncers", "list", "-o", "json")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return BouncerRegistration{}, fmt.Errorf("list bouncers: %w", err)
|
||||
}
|
||||
|
||||
var bouncers []struct {
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"api_key"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Valid bool `json:"valid"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(output, &bouncers); err != nil {
|
||||
return BouncerRegistration{}, fmt.Errorf("parse bouncers: %w", err)
|
||||
}
|
||||
|
||||
for _, b := range bouncers {
|
||||
if b.Name == name {
|
||||
var createdAt time.Time
|
||||
if b.CreatedAt != "" {
|
||||
createdAt, _ = time.Parse(time.RFC3339, b.CreatedAt)
|
||||
}
|
||||
return BouncerRegistration{
|
||||
Name: b.Name,
|
||||
APIKey: b.APIKey,
|
||||
IPAddress: b.IPAddress,
|
||||
Valid: b.Valid,
|
||||
CreatedAt: createdAt,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return BouncerRegistration{}, fmt.Errorf("bouncer %q not found", name)
|
||||
}
|
||||
|
||||
// registerBouncer registers a new bouncer with CrowdSec using cscli.
|
||||
func registerBouncer(ctx context.Context, name string) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, "cscli", "bouncers", "add", name, "-o", "raw")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("register bouncer: %w", err)
|
||||
}
|
||||
|
||||
apiKey := strings.TrimSpace(string(output))
|
||||
if apiKey == "" {
|
||||
return "", fmt.Errorf("empty API key returned from bouncer registration")
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
@@ -1,414 +0,0 @@
|
||||
package crowdsec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func writeFakeCSCLI(t *testing.T, script string) (binDir string) {
|
||||
t.Helper()
|
||||
|
||||
binDir = t.TempDir()
|
||||
path := filepath.Join(binDir, "cscli")
|
||||
requireMode := fs.FileMode(0o755)
|
||||
|
||||
err := os.WriteFile(path, []byte(script), requireMode)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write fake cscli: %v", err)
|
||||
}
|
||||
return binDir
|
||||
}
|
||||
|
||||
func withEnv(t *testing.T, key, value string, fn func()) {
|
||||
t.Helper()
|
||||
old, had := os.LookupEnv(key)
|
||||
if value == "" {
|
||||
_ = os.Unsetenv(key)
|
||||
} else {
|
||||
_ = os.Setenv(key, value)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if had {
|
||||
_ = os.Setenv(key, old)
|
||||
} else {
|
||||
_ = os.Unsetenv(key)
|
||||
}
|
||||
})
|
||||
fn()
|
||||
}
|
||||
|
||||
func withPath(t *testing.T, newPath string, fn func()) {
|
||||
t.Helper()
|
||||
old, had := os.LookupEnv("PATH")
|
||||
_ = os.Setenv("PATH", newPath)
|
||||
t.Cleanup(func() {
|
||||
if had {
|
||||
_ = os.Setenv("PATH", old)
|
||||
} else {
|
||||
_ = os.Unsetenv("PATH")
|
||||
}
|
||||
})
|
||||
fn()
|
||||
}
|
||||
|
||||
func TestCheckLAPIHealth_Healthy(t *testing.T) {
|
||||
// Create a mock LAPI server that returns 200 OK with JSON content-type
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/health" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
healthy := CheckLAPIHealth(server.URL)
|
||||
assert.True(t, healthy, "LAPI should be healthy")
|
||||
}
|
||||
|
||||
func TestCheckLAPIHealth_Unhealthy(t *testing.T) {
|
||||
// Create a mock LAPI server that returns 500
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
healthy := CheckLAPIHealth(server.URL)
|
||||
assert.False(t, healthy, "LAPI should be unhealthy")
|
||||
}
|
||||
|
||||
func TestCheckLAPIHealth_Unreachable(t *testing.T) {
|
||||
// Use an invalid URL that won't connect
|
||||
healthy := CheckLAPIHealth("http://127.0.0.1:19999")
|
||||
assert.False(t, healthy, "LAPI should be unreachable")
|
||||
}
|
||||
|
||||
func TestCheckLAPIHealth_FallbackToDecisions(t *testing.T) {
|
||||
// Create a mock LAPI server where /health fails but /v1/decisions returns 401
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/health" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/v1/decisions" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized) // Expected without auth
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
healthy := CheckLAPIHealth(server.URL)
|
||||
// Should fallback to decisions endpoint check which returns 401 (indicates running)
|
||||
assert.True(t, healthy, "LAPI should be healthy via decisions fallback")
|
||||
}
|
||||
|
||||
func TestCheckLAPIHealth_DefaultURL(t *testing.T) {
|
||||
// With empty URL, should use default (which won't be running in test)
|
||||
healthy := CheckLAPIHealth("")
|
||||
assert.False(t, healthy, "Default LAPI should not be running in test environment")
|
||||
}
|
||||
|
||||
func TestGetBouncerAPIKey_FromEnv(t *testing.T) {
|
||||
// Save and restore original env
|
||||
original := os.Getenv("CROWDSEC_API_KEY")
|
||||
defer func() {
|
||||
if original != "" {
|
||||
_ = os.Setenv("CROWDSEC_API_KEY", original)
|
||||
} else {
|
||||
_ = os.Unsetenv("CROWDSEC_API_KEY")
|
||||
}
|
||||
}()
|
||||
|
||||
// Set test value
|
||||
_ = os.Setenv("CROWDSEC_API_KEY", "test-api-key-123")
|
||||
|
||||
key := getBouncerAPIKey()
|
||||
assert.Equal(t, "test-api-key-123", key)
|
||||
}
|
||||
|
||||
func TestGetBouncerAPIKey_Empty(t *testing.T) {
|
||||
// Save and restore original env vars
|
||||
envVars := []string{
|
||||
"CROWDSEC_API_KEY",
|
||||
"CROWDSEC_BOUNCER_API_KEY",
|
||||
"CERBERUS_SECURITY_CROWDSEC_API_KEY",
|
||||
"CHARON_SECURITY_CROWDSEC_API_KEY",
|
||||
"CPM_SECURITY_CROWDSEC_API_KEY",
|
||||
}
|
||||
|
||||
originals := make(map[string]string)
|
||||
for _, key := range envVars {
|
||||
originals[key] = os.Getenv(key)
|
||||
_ = os.Unsetenv(key)
|
||||
}
|
||||
defer func() {
|
||||
for key, val := range originals {
|
||||
if val != "" {
|
||||
_ = os.Setenv(key, val)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
key := getBouncerAPIKey()
|
||||
assert.Empty(t, key)
|
||||
}
|
||||
|
||||
func TestGetBouncerAPIKey_Fallback(t *testing.T) {
|
||||
// Test fallback to secondary env var
|
||||
envVars := []string{
|
||||
"CROWDSEC_API_KEY",
|
||||
"CROWDSEC_BOUNCER_API_KEY",
|
||||
"CERBERUS_SECURITY_CROWDSEC_API_KEY",
|
||||
"CHARON_SECURITY_CROWDSEC_API_KEY",
|
||||
"CPM_SECURITY_CROWDSEC_API_KEY",
|
||||
}
|
||||
|
||||
originals := make(map[string]string)
|
||||
for _, key := range envVars {
|
||||
originals[key] = os.Getenv(key)
|
||||
_ = os.Unsetenv(key)
|
||||
}
|
||||
defer func() {
|
||||
for key, val := range originals {
|
||||
if val != "" {
|
||||
_ = os.Setenv(key, val)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Set only the fallback env var
|
||||
_ = os.Setenv("CERBERUS_SECURITY_CROWDSEC_API_KEY", "fallback-key-456")
|
||||
|
||||
key := getBouncerAPIKey()
|
||||
assert.Equal(t, "fallback-key-456", key)
|
||||
}
|
||||
|
||||
func TestEnsureBouncerRegistered_UsesEnvKey(t *testing.T) {
|
||||
withEnv(t, "CROWDSEC_API_KEY", "env-key", func() {
|
||||
key, err := EnsureBouncerRegistered(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "env-key", key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnsureBouncerRegistered_NoEnvNoCSCLI(t *testing.T) {
|
||||
// Ensure all key env vars are empty
|
||||
for _, k := range []string{"CROWDSEC_API_KEY", "CROWDSEC_BOUNCER_API_KEY", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"} {
|
||||
withEnv(t, k, "", func() {})
|
||||
}
|
||||
|
||||
withPath(t, "", func() {
|
||||
_, err := EnsureBouncerRegistered(context.Background(), "")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnsureBouncerRegistered_ReturnsExistingBouncerKey(t *testing.T) {
|
||||
for _, k := range []string{"CROWDSEC_API_KEY", "CROWDSEC_BOUNCER_API_KEY", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"} {
|
||||
withEnv(t, k, "", func() {})
|
||||
}
|
||||
|
||||
origPath := os.Getenv("PATH")
|
||||
|
||||
binDir := writeFakeCSCLI(t, `#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
if [[ "$1" == "bouncers" && "$2" == "list" ]]; then
|
||||
echo '[{"name":"caddy-bouncer","api_key":"existing-key","ip_address":"","valid":true,"created_at":"2025-01-01T00:00:00Z"}]'
|
||||
exit 0
|
||||
fi
|
||||
echo "unexpected args" >&2
|
||||
exit 2
|
||||
`)
|
||||
|
||||
withPath(t, binDir+":"+origPath, func() {
|
||||
key, err := EnsureBouncerRegistered(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "existing-key", key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnsureBouncerRegistered_RegistersNewWhenNoneExists(t *testing.T) {
|
||||
for _, k := range []string{"CROWDSEC_API_KEY", "CROWDSEC_BOUNCER_API_KEY", "CERBERUS_SECURITY_CROWDSEC_API_KEY", "CHARON_SECURITY_CROWDSEC_API_KEY", "CPM_SECURITY_CROWDSEC_API_KEY"} {
|
||||
withEnv(t, k, "", func() {})
|
||||
}
|
||||
|
||||
origPath := os.Getenv("PATH")
|
||||
|
||||
binDir := writeFakeCSCLI(t, `#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
if [[ "$1" == "bouncers" && "$2" == "list" ]]; then
|
||||
echo '[]'
|
||||
exit 0
|
||||
fi
|
||||
if [[ "$1" == "bouncers" && "$2" == "add" ]]; then
|
||||
echo 'new-key'
|
||||
exit 0
|
||||
fi
|
||||
echo "unexpected args" >&2
|
||||
exit 2
|
||||
`)
|
||||
|
||||
withPath(t, binDir+":"+origPath, func() {
|
||||
key, err := EnsureBouncerRegistered(context.Background(), "")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "new-key", key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetLAPIVersion_JSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/version" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"version":"1.2.3"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ver, err := GetLAPIVersion(context.Background(), server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1.2.3", ver)
|
||||
}
|
||||
|
||||
func TestGetLAPIVersion_PlainText(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/version" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("vX.Y.Z\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ver, err := GetLAPIVersion(context.Background(), server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "vX.Y.Z", ver)
|
||||
}
|
||||
|
||||
func TestValidateLAPIURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid localhost with port",
|
||||
url: "http://localhost:8085",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid 127.0.0.1",
|
||||
url: "http://127.0.0.1:8085",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "external URL blocked",
|
||||
url: "http://evil.com",
|
||||
wantErr: true,
|
||||
errContains: "must be localhost",
|
||||
},
|
||||
{
|
||||
name: "HTTPS localhost",
|
||||
url: "https://localhost:8085",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid scheme",
|
||||
url: "ftp://localhost:8085",
|
||||
wantErr: true,
|
||||
errContains: "scheme",
|
||||
},
|
||||
{
|
||||
name: "no scheme",
|
||||
url: "localhost:8085",
|
||||
wantErr: true,
|
||||
errContains: "scheme",
|
||||
},
|
||||
{
|
||||
name: "empty URL allowed (defaults to localhost)",
|
||||
url: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
url: "http://[::1]:8085",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "private IP 192.168.x.x blocked (security)",
|
||||
url: "http://192.168.1.100:8085",
|
||||
wantErr: true,
|
||||
errContains: "must be localhost",
|
||||
},
|
||||
{
|
||||
name: "private IP 10.x.x.x blocked (security)",
|
||||
url: "http://10.0.0.50:8085",
|
||||
wantErr: true,
|
||||
errContains: "must be localhost",
|
||||
},
|
||||
{
|
||||
name: "missing hostname",
|
||||
url: "http://:8085",
|
||||
wantErr: true,
|
||||
errContains: "missing hostname",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateLAPIURL(tt.url)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureBouncerRegistered_InvalidURL(t *testing.T) {
|
||||
// Test that SSRF validation is applied
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "external URL rejected",
|
||||
url: "http://attacker.com:8085",
|
||||
errContains: "must be localhost",
|
||||
},
|
||||
{
|
||||
name: "invalid scheme rejected",
|
||||
url: "ftp://localhost:8085",
|
||||
errContains: "scheme",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := EnsureBouncerRegistered(context.Background(), tt.url)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Moved</title></head>
|
||||
<body><h1>Moved</h1><p>Resource moved.</p></body>
|
||||
</html>
|
||||
Reference in New Issue
Block a user