- Added a reset of the `models.Setting` struct before querying for settings in both the Manager and Cerberus components to avoid ID leakage from previous queries. - Introduced new functions in Cerberus for checking admin authentication and admin whitelist status. - Enhanced middleware logic to allow admin users to bypass ACL checks if their IP is whitelisted. - Added tests to verify the behavior of the middleware with respect to ACLs and admin whitelisting. - Created a new utility for checking if an IP is in a CIDR list. - Updated various services to use a `Where` clause when fetching records by ID instead of passing the ID directly to `First`, ensuring consistent query patterns. - Added comprehensive tests for settings queries to demonstrate and verify the fix for ID leakage issues.
450 lines
13 KiB
Go
450 lines
13 KiB
Go
package services
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/Wikid82/charon/backend/internal/models"
|
|
"github.com/google/uuid"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Sentinel errors returned by AccessListService. Compare against them with
// errors.Is, since some are wrapped with the offending value.
var (
	// ErrAccessListNotFound is returned when no access list matches the
	// requested ID or UUID.
	ErrAccessListNotFound = errors.New("access list not found")

	// ErrInvalidAccessListType is returned when an access list's Type is not
	// one of ValidAccessListTypes.
	ErrInvalidAccessListType = errors.New("invalid access list type")

	// ErrInvalidIPAddress is returned for an address or CIDR that cannot be
	// parsed. Validation wraps it together with the rejected value.
	ErrInvalidIPAddress = errors.New("invalid IP address or CIDR")

	// ErrInvalidCountryCode is returned for a country code outside the
	// accepted set. Validation wraps it together with the rejected value.
	ErrInvalidCountryCode = errors.New("invalid country code")

	// ErrAccessListInUse is returned by Delete while proxy hosts still
	// reference the access list.
	ErrAccessListInUse = errors.New("access list is in use by proxy hosts")
)
|
|
|
|
// ValidAccessListTypes enumerates every value accepted for an access list's
// Type. Kinds prefixed with "geo_" match on the client's country; the others
// match on IP rules.
var ValidAccessListTypes = []string{
	"whitelist",
	"blacklist",
	"geo_whitelist",
	"geo_blacklist",
}
|
|
|
|
// RFC1918PrivateNetworks lists the CIDR ranges that isPrivateIP treats as
// "local" (used by the LocalNetworkOnly check in TestIP). Despite the name it
// is broader than RFC 1918: it also includes IPv4 loopback and link-local
// space, plus IPv6 unique-local, link-local and loopback addresses.
var RFC1918PrivateNetworks = []string{
	// RFC 1918 private IPv4 ranges.
	"10.0.0.0/8",
	"172.16.0.0/12",
	"192.168.0.0/16",
	// Other IPv4 ranges treated as local.
	"127.0.0.0/8",    // loopback
	"169.254.0.0/16", // link-local
	// IPv6 counterparts.
	"fc00::/7",  // unique local addresses (ULA)
	"fe80::/10", // link-local
	"::1/128",   // loopback
}
|
|
|
|
// validCountryCodes is the set of ISO 3166-1 alpha-2 country codes accepted in
// geo access lists (see isValidCountryCode). It contains all 249 officially
// assigned codes, so any assigned country can be allow- or block-listed.
// User-assigned and reserved codes (for example "XK", "EU", "UK") are not
// included.
//
// Keys are upper case; isValidCountryCode upper-cases input before the lookup.
// The codes are kept as an alphabetical, whitespace-separated list with one row
// per initial letter. This is easier to audit than a map literal and cannot
// trigger duplicate-key compile errors.
var validCountryCodes = func() map[string]bool {
	codes := strings.Fields(`
		AD AE AF AG AI AL AM AO AQ AR AS AT AU AW AX AZ
		BA BB BD BE BF BG BH BI BJ BL BM BN BO BQ BR BS BT BV BW BY BZ
		CA CC CD CF CG CH CI CK CL CM CN CO CR CU CV CW CX CY CZ
		DE DJ DK DM DO DZ
		EC EE EG EH ER ES ET
		FI FJ FK FM FO FR
		GA GB GD GE GF GG GH GI GL GM GN GP GQ GR GS GT GU GW GY
		HK HM HN HR HT HU
		ID IE IL IM IN IO IQ IR IS IT
		JE JM JO JP
		KE KG KH KI KM KN KP KR KW KY KZ
		LA LB LC LI LK LR LS LT LU LV LY
		MA MC MD ME MF MG MH MK ML MM MN MO MP MQ MR MS MT MU MV MW MX MY MZ
		NA NC NE NF NG NI NL NO NP NR NU NZ
		OM
		PA PE PF PG PH PK PL PM PN PR PS PT PW PY
		QA
		RE RO RS RU RW
		SA SB SC SD SE SG SH SI SJ SK SL SM SN SO SR SS ST SV SX SY SZ
		TC TD TF TG TH TJ TK TL TM TN TO TR TT TV TW TZ
		UA UG UM US UY UZ
		VA VC VE VG VI VN VU
		WF WS
		YE YT
		ZA ZM ZW
	`)
	set := make(map[string]bool, len(codes))
	for _, code := range codes {
		set[code] = true
	}
	return set
}()
|
|
|
|
// AccessListService handles access list CRUD and IP testing operations.
|
|
type AccessListService struct {
|
|
db *gorm.DB
|
|
geoipSvc *GeoIPService
|
|
}
|
|
|
|
// NewAccessListService creates a new AccessListService.
|
|
func NewAccessListService(db *gorm.DB) *AccessListService {
|
|
return &AccessListService{db: db}
|
|
}
|
|
|
|
// SetGeoIPService sets the GeoIP service for geo-based access list lookups.
|
|
// This method allows optional injection of the GeoIP service.
|
|
func (s *AccessListService) SetGeoIPService(geoipSvc *GeoIPService) {
|
|
s.geoipSvc = geoipSvc
|
|
}
|
|
|
|
// GetGeoIPService returns the configured GeoIP service (may be nil).
|
|
func (s *AccessListService) GetGeoIPService() *GeoIPService {
|
|
return s.geoipSvc
|
|
}
|
|
|
|
// Create creates a new access list with validation
|
|
func (s *AccessListService) Create(acl *models.AccessList) error {
|
|
if err := s.validateAccessList(acl); err != nil {
|
|
return err
|
|
}
|
|
|
|
acl.UUID = uuid.New().String()
|
|
return s.db.Create(acl).Error
|
|
}
|
|
|
|
// GetByID retrieves an access list by ID
|
|
func (s *AccessListService) GetByID(id uint) (*models.AccessList, error) {
|
|
var acl models.AccessList
|
|
if err := s.db.Where("id = ?", id).First(&acl).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrAccessListNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &acl, nil
|
|
}
|
|
|
|
// GetByUUID retrieves an access list by UUID
|
|
func (s *AccessListService) GetByUUID(uuidStr string) (*models.AccessList, error) {
|
|
var acl models.AccessList
|
|
if err := s.db.Where("uuid = ?", uuidStr).First(&acl).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrAccessListNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &acl, nil
|
|
}
|
|
|
|
// List retrieves all access lists sorted by updated_at desc
|
|
func (s *AccessListService) List() ([]models.AccessList, error) {
|
|
var acls []models.AccessList
|
|
if err := s.db.Order("updated_at desc").Find(&acls).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return acls, nil
|
|
}
|
|
|
|
// Update updates an existing access list with validation
|
|
func (s *AccessListService) Update(id uint, updates *models.AccessList) error {
|
|
acl, err := s.GetByID(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Apply updates
|
|
acl.Name = updates.Name
|
|
acl.Description = updates.Description
|
|
acl.Type = updates.Type
|
|
acl.IPRules = updates.IPRules
|
|
acl.CountryCodes = updates.CountryCodes
|
|
acl.LocalNetworkOnly = updates.LocalNetworkOnly
|
|
acl.Enabled = updates.Enabled
|
|
|
|
if err := s.validateAccessList(acl); err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.db.Save(acl).Error
|
|
}
|
|
|
|
// Delete deletes an access list if not in use
|
|
func (s *AccessListService) Delete(id uint) error {
|
|
// Check if ACL is in use by any proxy hosts
|
|
var count int64
|
|
if err := s.db.Model(&models.ProxyHost{}).Where("access_list_id = ?", id).Count(&count).Error; err != nil {
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
return ErrAccessListInUse
|
|
}
|
|
|
|
result := s.db.Delete(&models.AccessList{}, id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrAccessListNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// TestIP tests if an IP address would be allowed/blocked by the access list
|
|
func (s *AccessListService) TestIP(aclID uint, ipAddress string) (allowed bool, reason string, err error) {
|
|
acl, err := s.GetByID(aclID)
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
|
|
if !acl.Enabled {
|
|
return true, "Access list is disabled - all traffic allowed", nil
|
|
}
|
|
|
|
ip := net.ParseIP(ipAddress)
|
|
if ip == nil {
|
|
return false, "", ErrInvalidIPAddress
|
|
}
|
|
|
|
// Test local network only
|
|
if acl.LocalNetworkOnly {
|
|
if !s.isPrivateIP(ip) {
|
|
return false, "Not a private network IP (RFC1918)", nil
|
|
}
|
|
return true, "Allowed by local network only rule", nil
|
|
}
|
|
|
|
// Handle geo-based access lists
|
|
if strings.HasPrefix(acl.Type, "geo_") {
|
|
return s.testGeoIP(acl, ipAddress)
|
|
}
|
|
|
|
// Test IP rules
|
|
if acl.IPRules != "" {
|
|
var rules []models.AccessListRule
|
|
if err := json.Unmarshal([]byte(acl.IPRules), &rules); err == nil {
|
|
for _, rule := range rules {
|
|
if s.ipMatchesCIDR(ip, rule.CIDR) {
|
|
if acl.Type == "whitelist" {
|
|
return true, fmt.Sprintf("Allowed by whitelist rule: %s", rule.CIDR), nil
|
|
}
|
|
if acl.Type == "blacklist" {
|
|
return false, fmt.Sprintf("Blocked by blacklist rule: %s", rule.CIDR), nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Default behavior based on type
|
|
if acl.Type == "whitelist" {
|
|
return false, "Not in whitelist", nil
|
|
}
|
|
return true, "Not in blacklist", nil
|
|
}
|
|
|
|
// testGeoIP tests an IP against geo-based access list rules.
|
|
func (s *AccessListService) testGeoIP(acl *models.AccessList, ipAddress string) (allowed bool, reason string, err error) {
|
|
// Check if GeoIP service is available
|
|
if s.geoipSvc == nil || !s.geoipSvc.IsLoaded() {
|
|
// Graceful degradation: if GeoIP is not available, allow traffic with warning
|
|
return true, "GeoIP database not available - allowing by default", nil
|
|
}
|
|
|
|
// Look up the country for this IP
|
|
country, err := s.geoipSvc.LookupCountry(ipAddress)
|
|
if err != nil {
|
|
// Handle specific errors
|
|
if errors.Is(err, ErrCountryNotFound) {
|
|
// IP has no associated country (e.g., private IP, reserved range)
|
|
if acl.Type == "geo_whitelist" {
|
|
return false, "No country found for IP - not in geo whitelist", nil
|
|
}
|
|
// For blacklist, unknown country means not blocked
|
|
return true, "No country found for IP - not in geo blacklist", nil
|
|
}
|
|
// Other errors: graceful degradation
|
|
return true, fmt.Sprintf("GeoIP lookup error: %v - allowing by default", err), nil
|
|
}
|
|
|
|
// Parse the country codes from the ACL
|
|
allowedCountries := s.parseCountryCodes(acl.CountryCodes)
|
|
|
|
// Check if the country is in the list
|
|
countryInList := false
|
|
for _, code := range allowedCountries {
|
|
if strings.EqualFold(code, country) {
|
|
countryInList = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// Apply whitelist/blacklist logic
|
|
if acl.Type == "geo_whitelist" {
|
|
if countryInList {
|
|
return true, fmt.Sprintf("Allowed by geo whitelist: IP is from %s", country), nil
|
|
}
|
|
return false, fmt.Sprintf("Blocked by geo whitelist: IP is from %s (not in allowed countries)", country), nil
|
|
}
|
|
|
|
// geo_blacklist
|
|
if countryInList {
|
|
return false, fmt.Sprintf("Blocked by geo blacklist: IP is from %s", country), nil
|
|
}
|
|
return true, fmt.Sprintf("Allowed: IP is from %s (not in blocked countries)", country), nil
|
|
}
|
|
|
|
// parseCountryCodes parses a comma-separated list of country codes.
|
|
func (s *AccessListService) parseCountryCodes(codes string) []string {
|
|
if codes == "" {
|
|
return nil
|
|
}
|
|
parts := strings.Split(codes, ",")
|
|
result := make([]string, 0, len(parts))
|
|
for _, part := range parts {
|
|
code := strings.TrimSpace(strings.ToUpper(part))
|
|
if code != "" {
|
|
result = append(result, code)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// validateAccessList validates access list fields
|
|
func (s *AccessListService) validateAccessList(acl *models.AccessList) error {
|
|
// Validate name
|
|
if strings.TrimSpace(acl.Name) == "" {
|
|
return errors.New("name is required")
|
|
}
|
|
|
|
// Validate type
|
|
if !s.isValidType(acl.Type) {
|
|
return ErrInvalidAccessListType
|
|
}
|
|
|
|
// Validate IP rules
|
|
if acl.IPRules != "" {
|
|
var rules []models.AccessListRule
|
|
if err := json.Unmarshal([]byte(acl.IPRules), &rules); err != nil {
|
|
return fmt.Errorf("invalid IP rules JSON: %w", err)
|
|
}
|
|
|
|
for _, rule := range rules {
|
|
if !s.isValidCIDR(rule.CIDR) {
|
|
return fmt.Errorf("%w: %s", ErrInvalidIPAddress, rule.CIDR)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Validate country codes for geo types
|
|
if strings.HasPrefix(acl.Type, "geo_") {
|
|
if acl.CountryCodes == "" {
|
|
return errors.New("country codes are required for geo-blocking")
|
|
}
|
|
codes := strings.Split(acl.CountryCodes, ",")
|
|
for _, code := range codes {
|
|
code = strings.TrimSpace(strings.ToUpper(code))
|
|
if !s.isValidCountryCode(code) {
|
|
return fmt.Errorf("%w: %s", ErrInvalidCountryCode, code)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isValidType checks if access list type is valid
|
|
func (s *AccessListService) isValidType(aclType string) bool {
|
|
for _, valid := range ValidAccessListTypes {
|
|
if aclType == valid {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isValidCIDR validates IP address or CIDR notation
|
|
func (s *AccessListService) isValidCIDR(cidr string) bool {
|
|
// Try parsing as single IP
|
|
if ip := net.ParseIP(cidr); ip != nil {
|
|
return true
|
|
}
|
|
|
|
// Try parsing as CIDR
|
|
_, _, err := net.ParseCIDR(cidr)
|
|
return err == nil
|
|
}
|
|
|
|
// isValidCountryCode validates ISO 3166-1 alpha-2 country code
|
|
func (s *AccessListService) isValidCountryCode(code string) bool {
|
|
code = strings.ToUpper(strings.TrimSpace(code))
|
|
if len(code) != 2 {
|
|
return false
|
|
}
|
|
matched, _ := regexp.MatchString("^[A-Z]{2}$", code)
|
|
return matched && validCountryCodes[code]
|
|
}
|
|
|
|
// ipMatchesCIDR checks if an IP matches a CIDR block
|
|
func (s *AccessListService) ipMatchesCIDR(ip net.IP, cidr string) bool {
|
|
// Check if it's a single IP
|
|
if singleIP := net.ParseIP(cidr); singleIP != nil {
|
|
return ip.Equal(singleIP)
|
|
}
|
|
|
|
// Check CIDR range
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return ipNet.Contains(ip)
|
|
}
|
|
|
|
// isPrivateIP checks if an IP is in RFC1918 private ranges
|
|
func (s *AccessListService) isPrivateIP(ip net.IP) bool {
|
|
for _, cidr := range RFC1918PrivateNetworks {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if ipNet.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// GetTemplates returns predefined ACL templates
|
|
func (s *AccessListService) GetTemplates() []map[string]any {
|
|
return []map[string]any{
|
|
{
|
|
"id": "local-network",
|
|
"name": "Local Network Only",
|
|
"description": "Allow only RFC1918 private network IPs (home/office networks)",
|
|
"type": "whitelist",
|
|
"local_network_only": true,
|
|
"category": "security",
|
|
},
|
|
{
|
|
"id": "us-only",
|
|
"name": "US Only",
|
|
"description": "Allow only United States IPs",
|
|
"type": "geo_whitelist",
|
|
"country_codes": "US",
|
|
"category": "security",
|
|
},
|
|
{
|
|
"id": "eu-only",
|
|
"name": "EU Only",
|
|
"description": "Allow only European Union IPs",
|
|
"type": "geo_whitelist",
|
|
"country_codes": "AT,BE,BG,HR,CY,CZ,DK,EE,FI,FR,DE,GR,HU,IE,IT,LV,LT,LU,MT,NL,PL,PT,RO,SK,SI,ES,SE",
|
|
"category": "security",
|
|
},
|
|
{
|
|
"id": "high-risk-countries",
|
|
"name": "Block High-Risk Countries",
|
|
"description": "Block OFAC sanctioned countries and known attack sources",
|
|
"type": "geo_blacklist",
|
|
"country_codes": "RU,CN,KP,IR,BY,SY,VE,CU,SD",
|
|
"category": "security",
|
|
},
|
|
{
|
|
"id": "expanded-threat-countries",
|
|
"name": "Block Expanded Threat List",
|
|
"description": "Block high-risk countries plus additional bot/spam sources",
|
|
"type": "geo_blacklist",
|
|
"country_codes": "RU,CN,KP,IR,BY,SY,VE,CU,SD,PK,BD,NG,UA,VN,ID",
|
|
"category": "security",
|
|
},
|
|
// IP-based presets removed: IP blocklists and scanner ranges
|
|
// These are better handled by CrowdSec, WAF, or rate limiting.
|
|
}
|
|
}
|