Files
Charon/backend/internal/services/credential_service.go
2026-03-04 18:34:49 +00:00

654 lines
20 KiB
Go

package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/Wikid82/charon/backend/internal/crypto"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/Wikid82/charon/backend/internal/models"
"github.com/google/uuid"
"golang.org/x/net/idna"
"gorm.io/gorm"
)
// Sentinel errors returned by the credential service. Callers should compare
// against them with errors.Is.
var (
// ErrCredentialNotFound is returned when a credential is not found.
// Get returns it when no row matches both the credential ID and the provider
// ID, and Delete returns it when the delete statement affects no rows.
ErrCredentialNotFound = errors.New("credential not found")
// ErrNoMatchingCredential is returned when no credential matches the domain.
// GetCredentialForDomain returns it when the provider has no enabled
// credentials or none of them match by exact, wildcard or catch-all filter.
ErrNoMatchingCredential = errors.New("no matching credential found for domain")
// ErrMultiCredentialNotEnabled is returned when trying to use multi-credential features on a provider that doesn't have it enabled.
// List and Create return it when the provider's UseMultiCredentials flag is false.
ErrMultiCredentialNotEnabled = errors.New("multi-credential mode not enabled for this provider")
)
// CreateCredentialRequest represents the request to create a new credential.
type CreateCredentialRequest struct {
// Label is the display name of the credential; List orders results by it.
Label string `json:"label" binding:"required"`
// ZoneFilter is a comma-separated list of zones this credential serves.
// Entries may start with "*." to act as wildcards, and an empty filter makes
// the credential a catch-all. Create stores the value with surrounding
// whitespace trimmed.
ZoneFilter string `json:"zone_filter"` // Comma-separated domains
// Credentials holds the provider-specific key/value secrets. Create validates
// them against the provider type, JSON-encodes them and stores only the
// encrypted form.
Credentials map[string]string `json:"credentials" binding:"required"`
// PropagationTimeout of 0 means "inherit the provider's PropagationTimeout".
// NOTE(review): unit is presumably seconds — confirm against the model.
PropagationTimeout int `json:"propagation_timeout"`
// PollingInterval of 0 means "inherit the provider's PollingInterval".
// NOTE(review): unit is presumably seconds — confirm against the model.
PollingInterval int `json:"polling_interval"`
// Enabled is accepted in the payload, but Create currently stores every new
// credential as enabled regardless of this value: a plain bool cannot
// distinguish "omitted" from an explicit false.
Enabled bool `json:"enabled"`
}
// UpdateCredentialRequest represents the request to update a credential.
// Pointer fields use nil to mean "leave this field unchanged"; a non-nil value
// is applied only when it differs from the stored one.
type UpdateCredentialRequest struct {
// Label replaces the credential's display name when non-nil.
Label *string `json:"label"`
// ZoneFilter replaces the comma-separated zone list when non-nil; the stored
// value has surrounding whitespace trimmed.
ZoneFilter *string `json:"zone_filter"`
// Credentials replaces the stored secrets when it contains at least one
// entry; a nil or empty map keeps the existing encrypted credentials.
Credentials map[string]string `json:"credentials,omitempty"`
// PropagationTimeout replaces the stored timeout when non-nil.
PropagationTimeout *int `json:"propagation_timeout"`
// PollingInterval replaces the stored interval when non-nil.
PollingInterval *int `json:"polling_interval"`
// Enabled toggles the credential when non-nil. Only enabled credentials are
// considered by GetCredentialForDomain.
Enabled *bool `json:"enabled"`
}
// CredentialService provides operations for managing DNS provider credentials.
// Every method is scoped to a DNS provider identified by providerID.
type CredentialService interface {
// List returns the provider's credentials ordered by label. It fails with
// ErrDNSProviderNotFound or ErrMultiCredentialNotEnabled when the provider is
// missing or not in multi-credential mode.
List(ctx context.Context, providerID uint) ([]models.DNSProviderCredential, error)
// Get returns the credential with the given ID if it belongs to the provider,
// or ErrCredentialNotFound otherwise.
Get(ctx context.Context, providerID, credentialID uint) (*models.DNSProviderCredential, error)
// Create validates, encrypts and stores a new credential for a provider that
// has multi-credential mode enabled, then records an audit event.
Create(ctx context.Context, providerID uint, req CreateCredentialRequest) (*models.DNSProviderCredential, error)
// Update applies the non-nil fields of req to an existing credential,
// re-encrypting the secrets when new ones are supplied, and records an audit
// event when something changed.
Update(ctx context.Context, providerID, credentialID uint, req UpdateCredentialRequest) (*models.DNSProviderCredential, error)
// Delete removes a credential and records an audit event.
Delete(ctx context.Context, providerID, credentialID uint) error
// Test decrypts the credential and runs the provider connectivity check.
// Decryption and check failures are reported inside the TestResult; the
// error return is reserved for lookup failures.
Test(ctx context.Context, providerID, credentialID uint) (*TestResult, error)
// GetCredentialForDomain picks the enabled credential that best matches
// domain (exact match, then wildcard, then catch-all). It returns (nil, nil)
// when the provider is not in multi-credential mode.
GetCredentialForDomain(ctx context.Context, providerID uint, domain string) (*models.DNSProviderCredential, error)
// EnableMultiCredentials switches a provider to multi-credential mode,
// copying its existing encrypted credentials into a catch-all credential.
EnableMultiCredentials(ctx context.Context, providerID uint) error
}
// credentialService implements the CredentialService interface.
type credentialService struct {
// db is the GORM handle used for all provider and credential queries.
db *gorm.DB
// encryptor is the fallback used to encrypt/decrypt credentials when
// rotationService is nil.
encryptor *crypto.EncryptionService
// rotationService, when non-nil, encrypts with the current key and decrypts
// by stored key version. It is optional and may be nil.
rotationService *crypto.RotationService
// securityService records audit events for credential operations.
securityService *SecurityService
}
// NewCredentialService creates a new credential service.
//
// db is used for all queries and to build the audit-logging SecurityService.
// encryptor is the fallback cipher used whenever the key-rotation service is
// unavailable.
func NewCredentialService(db *gorm.DB, encryptor *crypto.EncryptionService) CredentialService {
	// The rotation service is optional for backward compatibility: when it
	// cannot be initialised the service falls back to the plain encryptor.
	rotationService, err := crypto.NewRotationService(db)
	if err != nil {
		// Use the structured logger like the rest of this file instead of
		// writing to stdout.
		logger.Log().WithError(err).Warn("RotationService initialization failed, using basic encryption")
		// Guarantee the documented fallback even if the constructor returned
		// a non-nil value alongside the error: every call site selects the
		// basic encryptor by checking rotationService for nil.
		rotationService = nil
	}

	return &credentialService{
		db:              db,
		encryptor:       encryptor,
		rotationService: rotationService,
		securityService: NewSecurityService(db),
	}
}
// List returns every credential attached to the given DNS provider, sorted by
// label in ascending order.
//
// It returns ErrDNSProviderNotFound when the provider does not exist and
// ErrMultiCredentialNotEnabled when the provider is not in multi-credential
// mode. Any other database error is returned unchanged.
func (s *credentialService) List(ctx context.Context, providerID uint) ([]models.DNSProviderCredential, error) {
	// The provider must exist and be in multi-credential mode before its
	// credentials may be listed.
	var owner models.DNSProvider
	lookupErr := s.db.WithContext(ctx).Where("id = ?", providerID).First(&owner).Error
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, ErrDNSProviderNotFound
	case lookupErr != nil:
		return nil, lookupErr
	case !owner.UseMultiCredentials:
		return nil, ErrMultiCredentialNotEnabled
	}

	var found []models.DNSProviderCredential
	listing := s.db.WithContext(ctx).
		Where("dns_provider_id = ?", providerID).
		Order("label ASC").
		Find(&found)
	return found, listing.Error
}
// Get looks up a single credential by ID, scoped to the given provider.
//
// It returns ErrCredentialNotFound when no row matches both the credential ID
// and the provider ID; any other database error is returned unchanged.
func (s *credentialService) Get(ctx context.Context, providerID, credentialID uint) (*models.DNSProviderCredential, error) {
	found := new(models.DNSProviderCredential)
	scoped := s.db.WithContext(ctx).Where("id = ? AND dns_provider_id = ?", credentialID, providerID)

	switch lookupErr := scoped.First(found).Error; {
	case lookupErr == nil:
		return found, nil
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, ErrCredentialNotFound
	default:
		return nil, lookupErr
	}
}
// Create creates a new credential for a DNS provider.
//
// The provider must exist (otherwise ErrDNSProviderNotFound) and have
// multi-credential mode enabled (otherwise ErrMultiCredentialNotEnabled). The
// supplied secrets are validated for the provider type, JSON-encoded and
// encrypted; failures in encoding or encryption are wrapped in
// ErrEncryptionFailed. Zero-valued timeouts in the request inherit the
// provider's values. On success an audit event is written on a best-effort
// basis and the stored credential is returned.
func (s *credentialService) Create(ctx context.Context, providerID uint, req CreateCredentialRequest) (*models.DNSProviderCredential, error) {
	// Verify provider exists and has multi-credential enabled
	var provider models.DNSProvider
	if err := s.db.WithContext(ctx).Where("id = ?", providerID).First(&provider).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrDNSProviderNotFound
		}
		return nil, err
	}
	if !provider.UseMultiCredentials {
		return nil, ErrMultiCredentialNotEnabled
	}

	// Validate credentials for provider type
	if err := validateCredentials(provider.ProviderType, req.Credentials); err != nil {
		return nil, err
	}

	credentialsJSON, err := json.Marshal(req.Credentials)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err)
	}

	// Encrypt with the rotation service when available so the key version is
	// tracked; otherwise fall back to the basic encryptor and version 1.
	var encryptedCreds string
	var keyVersion int
	if s.rotationService != nil {
		encryptedCreds, keyVersion, err = s.rotationService.EncryptWithCurrentKey(credentialsJSON)
	} else {
		encryptedCreds, err = s.encryptor.Encrypt(credentialsJSON)
		keyVersion = 1
	}
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err)
	}

	// Zero means "not specified": inherit the provider-level settings.
	propagationTimeout := req.PropagationTimeout
	if propagationTimeout == 0 {
		propagationTimeout = provider.PropagationTimeout
	}
	pollingInterval := req.PollingInterval
	if pollingInterval == 0 {
		pollingInterval = provider.PollingInterval
	}

	zoneFilter := strings.TrimSpace(req.ZoneFilter)

	// New credentials are always created enabled. The previous branching on
	// req.Enabled had an unreachable first arm and produced true for every
	// input, because a plain bool cannot tell an omitted field from an
	// explicit false.
	// NOTE(review): honouring an explicit "enabled": false would require a
	// *bool in CreateCredentialRequest; until then, disable via Update.
	credential := &models.DNSProviderCredential{
		UUID:                 uuid.New().String(),
		DNSProviderID:        providerID,
		Label:                req.Label,
		ZoneFilter:           zoneFilter,
		CredentialsEncrypted: encryptedCreds,
		KeyVersion:           keyVersion,
		PropagationTimeout:   propagationTimeout,
		PollingInterval:      pollingInterval,
		Enabled:              true,
	}
	if err := s.db.WithContext(ctx).Create(credential).Error; err != nil {
		return nil, err
	}

	// Log audit event. The marshal error is ignored because the map only
	// holds strings and an unsigned integer. The zone filter is logged as it
	// was stored (trimmed) so the audit trail matches the database row.
	detailsJSON, _ := json.Marshal(map[string]interface{}{
		"label":       req.Label,
		"zone_filter": zoneFilter,
		"provider_id": providerID,
	})
	if err := s.securityService.LogAudit(&models.SecurityAudit{
		Actor:         getActorFromContext(ctx),
		Action:        "credential_create",
		EventCategory: "dns_provider",
		ResourceID:    &provider.ID,
		ResourceUUID:  provider.UUID,
		Details:       string(detailsJSON),
		IPAddress:     getIPFromContext(ctx),
		UserAgent:     getUserAgentFromContext(ctx),
	}); err != nil {
		// Auditing is best-effort: the credential has already been stored.
		logger.Log().WithError(err).Warn("Failed to log audit event")
	}

	return credential, nil
}
// Update updates an existing credential.
//
// Only non-nil fields of req that differ from the stored value are applied.
// When req.Credentials contains entries they are validated for the provider
// type and re-encrypted (errors wrapped in ErrEncryptionFailed). The record is
// saved even when nothing changed, and an audit event listing the changed
// fields with their old and new values is written only when at least one
// field changed. Secrets never appear in the audit payload; only the fact
// that they changed is recorded.
//
// Returns ErrCredentialNotFound when the credential does not belong to the
// provider.
func (s *credentialService) Update(ctx context.Context, providerID, credentialID uint, req UpdateCredentialRequest) (*models.DNSProviderCredential, error) {
	// Fetch existing credential
	credential, err := s.Get(ctx, providerID, credentialID)
	if err != nil {
		return nil, err
	}

	// Fetch provider for validation and audit logging
	var provider models.DNSProvider
	if findErr := s.db.WithContext(ctx).Where("id = ?", providerID).First(&provider).Error; findErr != nil {
		return nil, findErr
	}

	// Track changed fields for audit log
	changedFields := make(map[string]interface{})
	oldValues := make(map[string]interface{})
	newValues := make(map[string]interface{})

	// Update fields if provided
	if req.Label != nil && *req.Label != credential.Label {
		oldValues["label"] = credential.Label
		newValues["label"] = *req.Label
		changedFields["label"] = true
		credential.Label = *req.Label
	}
	if req.ZoneFilter != nil {
		// Trim before comparing: the stored value is always trimmed, so
		// comparing the raw input would report a change (and write an audit
		// entry with an untrimmed "new" value) for whitespace-only
		// differences.
		if zoneFilter := strings.TrimSpace(*req.ZoneFilter); zoneFilter != credential.ZoneFilter {
			oldValues["zone_filter"] = credential.ZoneFilter
			newValues["zone_filter"] = zoneFilter
			changedFields["zone_filter"] = true
			credential.ZoneFilter = zoneFilter
		}
	}
	if req.PropagationTimeout != nil && *req.PropagationTimeout != credential.PropagationTimeout {
		oldValues["propagation_timeout"] = credential.PropagationTimeout
		newValues["propagation_timeout"] = *req.PropagationTimeout
		changedFields["propagation_timeout"] = true
		credential.PropagationTimeout = *req.PropagationTimeout
	}
	if req.PollingInterval != nil && *req.PollingInterval != credential.PollingInterval {
		oldValues["polling_interval"] = credential.PollingInterval
		newValues["polling_interval"] = *req.PollingInterval
		changedFields["polling_interval"] = true
		credential.PollingInterval = *req.PollingInterval
	}
	if req.Enabled != nil && *req.Enabled != credential.Enabled {
		oldValues["enabled"] = credential.Enabled
		newValues["enabled"] = *req.Enabled
		changedFields["enabled"] = true
		credential.Enabled = *req.Enabled
	}

	// Handle credentials update
	if len(req.Credentials) > 0 {
		if validateErr := validateCredentials(provider.ProviderType, req.Credentials); validateErr != nil {
			return nil, validateErr
		}

		credentialsJSON, marshalErr := json.Marshal(req.Credentials)
		if marshalErr != nil {
			return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, marshalErr)
		}

		// Encrypt new credentials with version tracking; fall back to the
		// basic encryptor and version 1 without a rotation service.
		var (
			encryptedCreds string
			keyVersion     int
			encryptErr     error
		)
		if s.rotationService != nil {
			encryptedCreds, keyVersion, encryptErr = s.rotationService.EncryptWithCurrentKey(credentialsJSON)
		} else {
			encryptedCreds, encryptErr = s.encryptor.Encrypt(credentialsJSON)
			keyVersion = 1
		}
		if encryptErr != nil {
			return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, encryptErr)
		}

		changedFields["credentials"] = true
		credential.CredentialsEncrypted = encryptedCreds
		credential.KeyVersion = keyVersion
	}

	// Save updates
	if saveErr := s.db.WithContext(ctx).Save(credential).Error; saveErr != nil {
		return nil, saveErr
	}

	// Log audit event if any changes were made
	if len(changedFields) > 0 {
		// Marshal error ignored: the maps only hold strings, ints and bools.
		detailsJSON, _ := json.Marshal(map[string]interface{}{
			"credential_id":  credentialID,
			"changed_fields": changedFields,
			"old_values":     oldValues,
			"new_values":     newValues,
		})
		if auditErr := s.securityService.LogAudit(&models.SecurityAudit{
			Actor:         getActorFromContext(ctx),
			Action:        "credential_update",
			EventCategory: "dns_provider",
			ResourceID:    &provider.ID,
			ResourceUUID:  provider.UUID,
			Details:       string(detailsJSON),
			IPAddress:     getIPFromContext(ctx),
			UserAgent:     getUserAgentFromContext(ctx),
		}); auditErr != nil {
			// Auditing is best-effort: the update has already been saved.
			logger.Log().WithError(auditErr).Warn("Failed to log audit event")
		}
	}

	return credential, nil
}
// Delete deletes a credential.
//
// The credential must belong to the provider (otherwise
// ErrCredentialNotFound). The delete statement is retried up to five times
// with a linearly growing backoff when the database reports a transient lock;
// the wait is abandoned with ctx.Err() if the context is cancelled. On success
// an audit event is written on a best-effort basis.
func (s *credentialService) Delete(ctx context.Context, providerID, credentialID uint) error {
	// Fetch credential and provider for audit log
	credential, err := s.Get(ctx, providerID, credentialID)
	if err != nil {
		return err
	}
	var provider models.DNSProvider
	if err := s.db.WithContext(ctx).Where("id = ?", providerID).First(&provider).Error; err != nil {
		return err
	}

	const maxDeleteAttempts = 5
	var result *gorm.DB
	for attempt := 1; attempt <= maxDeleteAttempts; attempt++ {
		result = s.db.WithContext(ctx).Delete(&models.DNSProviderCredential{}, credentialID)
		if result.Error == nil {
			break
		}

		// Lock contention is detected by message text; anything else, or
		// running out of attempts, is returned to the caller as-is.
		errMsg := strings.ToLower(result.Error.Error())
		isTransientLock := strings.Contains(errMsg, "database is locked") ||
			strings.Contains(errMsg, "database table is locked") ||
			strings.Contains(errMsg, "busy")
		if !isTransientLock || attempt == maxDeleteAttempts {
			return result.Error
		}

		// Back off 10ms, 20ms, ... before retrying, but stop waiting as soon
		// as the caller's context is cancelled instead of sleeping blindly.
		backoff := time.NewTimer(time.Duration(attempt) * 10 * time.Millisecond)
		select {
		case <-ctx.Done():
			backoff.Stop()
			return ctx.Err()
		case <-backoff.C:
		}
	}
	if result == nil || result.RowsAffected == 0 {
		// The row vanished between the lookup above and the delete.
		return ErrCredentialNotFound
	}

	// Log audit event. Marshal error ignored: the map only holds an unsigned
	// integer and strings.
	detailsJSON, _ := json.Marshal(map[string]interface{}{
		"credential_id": credentialID,
		"label":         credential.Label,
		"zone_filter":   credential.ZoneFilter,
	})
	if err := s.securityService.LogAudit(&models.SecurityAudit{
		Actor:         getActorFromContext(ctx),
		Action:        "credential_delete",
		EventCategory: "dns_provider",
		ResourceID:    &provider.ID,
		ResourceUUID:  provider.UUID,
		Details:       string(detailsJSON),
		IPAddress:     getIPFromContext(ctx),
		UserAgent:     getUserAgentFromContext(ctx),
	}); err != nil {
		// Auditing is best-effort: the credential is already gone.
		logger.Log().WithError(err).Warn("Failed to log audit event")
	}

	return nil
}
// Test tests a credential's connectivity.
//
// The credential is decrypted (by key version when the rotation service is
// available) and handed to the shared provider test function. Decryption or
// decoding problems are reported as an unsuccessful TestResult with code
// DECRYPTION_ERROR or INVALID_FORMAT rather than as an error; the error return
// is used only when the credential or provider cannot be loaded.
//
// After the check, the credential's success/failure counters and last error
// are updated and an audit event is written; both are best-effort and never
// change the returned result.
func (s *credentialService) Test(ctx context.Context, providerID, credentialID uint) (*TestResult, error) {
	credential, err := s.Get(ctx, providerID, credentialID)
	if err != nil {
		return nil, err
	}
	var provider models.DNSProvider
	if findErr := s.db.WithContext(ctx).Where("id = ?", providerID).First(&provider).Error; findErr != nil {
		return nil, findErr
	}

	// Decrypt credentials
	var decryptedData []byte
	if s.rotationService != nil {
		decryptedData, err = s.rotationService.DecryptWithVersion(credential.CredentialsEncrypted, credential.KeyVersion)
	} else {
		decryptedData, err = s.encryptor.Decrypt(credential.CredentialsEncrypted)
	}
	if err != nil {
		// Deliberately generic: the underlying cause is not exposed to the
		// caller.
		return &TestResult{
			Success: false,
			Error:   "Failed to decrypt credentials",
			Code:    "DECRYPTION_ERROR",
		}, nil
	}

	var credentials map[string]string
	if err := json.Unmarshal(decryptedData, &credentials); err != nil {
		return &TestResult{
			Success: false,
			Error:   "Invalid credential format",
			Code:    "INVALID_FORMAT",
		}, nil
	}

	// Perform test using the shared test function
	result := testDNSProviderCredentials(provider.ProviderType, credentials)

	// Update credential statistics
	if result.Success {
		credential.SuccessCount++
		credential.LastError = ""
	} else {
		credential.FailureCount++
		credential.LastError = result.Error
	}
	// Persisting statistics must not fail the test itself, but a failed save
	// should be visible instead of being silently discarded.
	if saveErr := s.db.WithContext(ctx).Save(credential).Error; saveErr != nil {
		logger.Log().WithError(saveErr).Warn("Failed to update credential test statistics")
	}

	// Log audit event. Marshal error ignored: the map only holds an unsigned
	// integer, strings and a bool.
	detailsJSON, _ := json.Marshal(map[string]interface{}{
		"credential_id": credentialID,
		"label":         credential.Label,
		"test_result":   result.Success,
		"error":         result.Error,
	})
	if err := s.securityService.LogAudit(&models.SecurityAudit{
		Actor:         getActorFromContext(ctx),
		Action:        "credential_test",
		EventCategory: "dns_provider",
		ResourceID:    &provider.ID,
		ResourceUUID:  provider.UUID,
		Details:       string(detailsJSON),
		IPAddress:     getIPFromContext(ctx),
		UserAgent:     getUserAgentFromContext(ctx),
	}); err != nil {
		logger.Log().WithError(err).Warn("Failed to log audit event")
	}

	return result, nil
}
// GetCredentialForDomain selects the best credential match for a domain.
// Priority: exact match > wildcard match > catch-all (empty zone_filter)
//
// The domain is trimmed, lower-cased and converted to its ASCII (punycode)
// form before matching. Only enabled credentials of the provider are
// considered. Within one priority tier the first credential returned by the
// query wins.
//
// Returns ErrDNSProviderNotFound when the provider does not exist,
// ErrNoMatchingCredential when nothing matches, and (nil, nil) when the
// provider is not in multi-credential mode, in which case the caller should
// use the provider's own credentials.
func (s *credentialService) GetCredentialForDomain(ctx context.Context, providerID uint, domain string) (*models.DNSProviderCredential, error) {
	var owner models.DNSProvider
	lookupErr := s.db.WithContext(ctx).Where("id = ?", providerID).First(&owner).Error
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, ErrDNSProviderNotFound
	case lookupErr != nil:
		return nil, lookupErr
	case !owner.UseMultiCredentials:
		// Single-credential provider: nothing to select here.
		return nil, nil
	}

	// Bring the domain into the same canonical form matchesDomain uses for
	// zone entries.
	cleaned := strings.ToLower(strings.TrimSpace(domain))
	normalizedDomain, convErr := idna.ToASCII(cleaned)
	if convErr != nil {
		return nil, fmt.Errorf("failed to normalize domain: %w", convErr)
	}

	var candidates []models.DNSProviderCredential
	query := s.db.WithContext(ctx).Where("dns_provider_id = ? AND enabled = ?", providerID, true)
	if findErr := query.Find(&candidates).Error; findErr != nil {
		return nil, findErr
	}
	if len(candidates) == 0 {
		return nil, ErrNoMatchingCredential
	}

	// Matching tiers, from most to least specific. Every candidate is tried
	// against one tier before the next tier is consulted.
	tiers := []func(zoneFilter string) bool{
		// Priority 1: exact match.
		func(zoneFilter string) bool { return matchesDomain(zoneFilter, normalizedDomain, true) },
		// Priority 2: wildcard match.
		func(zoneFilter string) bool { return matchesDomain(zoneFilter, normalizedDomain, false) },
		// Priority 3: catch-all (empty zone_filter).
		func(zoneFilter string) bool { return strings.TrimSpace(zoneFilter) == "" },
	}
	for _, accepts := range tiers {
		for i := range candidates {
			if accepts(candidates[i].ZoneFilter) {
				selected := candidates[i]
				return &selected, nil
			}
		}
	}

	return nil, ErrNoMatchingCredential
}
// matchesDomain reports whether domain is covered by any entry of the
// comma-separated zoneFilter.
//
// Each entry is trimmed, lower-cased and converted to ASCII (punycode); empty
// entries and entries that fail conversion are skipped. An entry matches when
// it equals domain. When exactOnly is false, an entry of the form "*.suffix"
// additionally matches suffix itself and any name ending in ".suffix".
//
// An empty or whitespace-only filter never matches here: it denotes a
// catch-all, which the caller handles separately. domain is compared as
// given, so it is expected to be normalized by the caller.
func matchesDomain(zoneFilter, domain string, exactOnly bool) bool {
	if strings.TrimSpace(zoneFilter) == "" {
		return false
	}

	for _, entry := range strings.Split(zoneFilter, ",") {
		candidate := strings.ToLower(strings.TrimSpace(entry))
		if candidate == "" {
			continue
		}

		asciiZone, convErr := idna.ToASCII(candidate)
		if convErr != nil {
			// Unconvertible entries are ignored rather than failing the match.
			continue
		}

		switch {
		case asciiZone == domain:
			return true
		case exactOnly, !strings.HasPrefix(asciiZone, "*."):
			// No wildcard handling requested, or the entry is not a wildcard.
			continue
		}

		base := strings.TrimPrefix(asciiZone, "*.")
		if domain == base || strings.HasSuffix(domain, "."+base) {
			return true
		}
	}

	return false
}
// EnableMultiCredentials migrates a provider from single to multi-credential mode.
//
// The provider's existing encrypted credentials, key version and timing
// settings are copied into a new catch-all credential labelled
// "Default (migrated)", and the provider's use_multi_credentials flag is set;
// both writes happen in one transaction. The call is a no-op when the
// provider is already in multi-credential mode.
//
// Returns ErrDNSProviderNotFound when the provider does not exist and an
// error when the provider has no credentials to migrate, when the transaction
// cannot be started or committed, when either write fails, or when a panic
// occurs before the commit (the transaction is rolled back in that case).
func (s *credentialService) EnableMultiCredentials(ctx context.Context, providerID uint) (retErr error) {
	// Fetch provider
	var provider models.DNSProvider
	if err := s.db.WithContext(ctx).Where("id = ?", providerID).First(&provider).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrDNSProviderNotFound
		}
		return err
	}

	// Already enabled
	if provider.UseMultiCredentials {
		return nil
	}

	// Check if provider has existing credentials
	if provider.CredentialsEncrypted == "" {
		return errors.New("provider has no credentials to migrate")
	}

	// Create a default credential with existing credentials. The ciphertext
	// is copied as-is together with its key version, so no re-encryption is
	// needed.
	credential := &models.DNSProviderCredential{
		UUID:                 uuid.New().String(),
		DNSProviderID:        provider.ID,
		Label:                "Default (migrated)",
		ZoneFilter:           "", // Empty = catch-all
		CredentialsEncrypted: provider.CredentialsEncrypted,
		KeyVersion:           provider.KeyVersion,
		PropagationTimeout:   provider.PropagationTimeout,
		PollingInterval:      provider.PollingInterval,
		Enabled:              true,
	}

	// Start transaction; Begin reports failures through tx.Error.
	tx := s.db.WithContext(ctx).Begin()
	if tx.Error != nil {
		return fmt.Errorf("failed to begin transaction: %w", tx.Error)
	}

	committed := false
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if committed {
			// The migration is already durable; a panic here can only come
			// from the best-effort audit logging, so keep reporting success.
			logger.Log().WithError(fmt.Errorf("%v", r)).Warn("Recovered from panic after enabling multi-credential mode")
			return
		}
		// A panic before commit must not be reported as success: roll back
		// and surface it to the caller as an error.
		tx.Rollback()
		retErr = fmt.Errorf("panic while enabling multi-credential mode: %v", r)
	}()

	// Create default credential
	if err := tx.Create(credential).Error; err != nil {
		tx.Rollback()
		return fmt.Errorf("failed to create default credential: %w", err)
	}

	// Enable multi-credential mode
	if err := tx.Model(&provider).Update("use_multi_credentials", true).Error; err != nil {
		tx.Rollback()
		return fmt.Errorf("failed to enable multi-credential mode: %w", err)
	}

	// Commit transaction
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	committed = true

	// Log audit event. Marshal error ignored: the map only holds an unsigned
	// integer and strings.
	detailsJSON, _ := json.Marshal(map[string]interface{}{
		"provider_id":               providerID,
		"provider_name":             provider.Name,
		"migrated_credential_label": credential.Label,
	})
	if err := s.securityService.LogAudit(&models.SecurityAudit{
		Actor:         getActorFromContext(ctx),
		Action:        "multi_credential_enabled",
		EventCategory: "dns_provider",
		ResourceID:    &provider.ID,
		ResourceUUID:  provider.UUID,
		Details:       string(detailsJSON),
		IPAddress:     getIPFromContext(ctx),
		UserAgent:     getUserAgentFromContext(ctx),
	}); err != nil {
		logger.Log().WithError(err).Warn("Failed to log audit event")
	}

	return nil
}