chore: clean .gitignore cache
This commit is contained in:
@@ -1,95 +0,0 @@
|
||||
// Package security provides audit logging for security-sensitive operations.
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditEvent represents a security audit log entry.
|
||||
// All fields are included in JSON output for structured logging.
|
||||
type AuditEvent struct {
|
||||
Timestamp string `json:"timestamp"` // RFC3339 timestamp of the event
|
||||
Action string `json:"action"` // Action being performed (e.g., "url_validation", "url_test")
|
||||
Host string `json:"host"` // Target hostname from URL
|
||||
RequestID string `json:"request_id"` // Unique request identifier for tracing
|
||||
Result string `json:"result"` // Result of action: "allowed", "blocked", "error"
|
||||
ResolvedIPs []string `json:"resolved_ips"` // DNS resolution results (for debugging)
|
||||
BlockedReason string `json:"blocked_reason"` // Why the request was blocked
|
||||
UserID string `json:"user_id"` // User who made the request (CRITICAL for attribution)
|
||||
SourceIP string `json:"source_ip"` // IP address of the request originator
|
||||
}
|
||||
|
||||
// AuditLogger provides structured security audit logging.
|
||||
type AuditLogger struct {
|
||||
// prefix is prepended to all log messages
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewAuditLogger creates a new security audit logger.
|
||||
func NewAuditLogger() *AuditLogger {
|
||||
return &AuditLogger{
|
||||
prefix: "[SECURITY AUDIT]",
|
||||
}
|
||||
}
|
||||
|
||||
// LogURLValidation logs a URL validation event.
|
||||
func (al *AuditLogger) LogURLValidation(event AuditEvent) {
|
||||
// Ensure timestamp is set
|
||||
if event.Timestamp == "" {
|
||||
event.Timestamp = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// Serialize to JSON for structured logging
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
log.Printf("%s ERROR: Failed to serialize audit event: %v", al.prefix, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Log to standard logger (will be captured by application logger)
|
||||
log.Printf("%s %s", al.prefix, string(eventJSON))
|
||||
}
|
||||
|
||||
// LogURLTest is a convenience method for logging URL connectivity tests.
|
||||
func (al *AuditLogger) LogURLTest(host, requestID, userID, sourceIP, result string) {
|
||||
event := AuditEvent{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Action: "url_connectivity_test",
|
||||
Host: host,
|
||||
RequestID: requestID,
|
||||
Result: result,
|
||||
UserID: userID,
|
||||
SourceIP: sourceIP,
|
||||
}
|
||||
al.LogURLValidation(event)
|
||||
}
|
||||
|
||||
// LogSSRFBlock is a convenience method for logging blocked SSRF attempts.
|
||||
func (al *AuditLogger) LogSSRFBlock(host string, resolvedIPs []string, reason, userID, sourceIP string) {
|
||||
event := AuditEvent{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Action: "ssrf_block",
|
||||
Host: host,
|
||||
ResolvedIPs: resolvedIPs,
|
||||
BlockedReason: reason,
|
||||
Result: "blocked",
|
||||
UserID: userID,
|
||||
SourceIP: sourceIP,
|
||||
}
|
||||
al.LogURLValidation(event)
|
||||
}
|
||||
|
||||
// Global audit logger instance
|
||||
var globalAuditLogger = NewAuditLogger()
|
||||
|
||||
// LogURLTest logs a URL test event using the global logger.
|
||||
func LogURLTest(host, requestID, userID, sourceIP, result string) {
|
||||
globalAuditLogger.LogURLTest(host, requestID, userID, sourceIP, result)
|
||||
}
|
||||
|
||||
// LogSSRFBlock logs a blocked SSRF attempt using the global logger.
|
||||
func LogSSRFBlock(host string, resolvedIPs []string, reason, userID, sourceIP string) {
|
||||
globalAuditLogger.LogSSRFBlock(host, resolvedIPs, reason, userID, sourceIP)
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestAuditEvent_JSONSerialization tests that audit events serialize correctly to JSON.
|
||||
func TestAuditEvent_JSONSerialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
event := AuditEvent{
|
||||
Timestamp: "2025-12-31T12:00:00Z",
|
||||
Action: "url_validation",
|
||||
Host: "example.com",
|
||||
RequestID: "test-123",
|
||||
Result: "blocked",
|
||||
ResolvedIPs: []string{"192.168.1.1", "10.0.0.1"},
|
||||
BlockedReason: "private_ip",
|
||||
UserID: "user123",
|
||||
SourceIP: "203.0.113.1",
|
||||
}
|
||||
|
||||
// Serialize to JSON
|
||||
jsonBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal AuditEvent: %v", err)
|
||||
}
|
||||
|
||||
// Verify all fields are present
|
||||
jsonStr := string(jsonBytes)
|
||||
expectedFields := []string{
|
||||
"timestamp", "action", "host", "request_id", "result",
|
||||
"resolved_ips", "blocked_reason", "user_id", "source_ip",
|
||||
}
|
||||
|
||||
for _, field := range expectedFields {
|
||||
if !strings.Contains(jsonStr, field) {
|
||||
t.Errorf("JSON output missing field: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
// Deserialize and verify
|
||||
var decoded AuditEvent
|
||||
err = json.Unmarshal(jsonBytes, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal AuditEvent: %v", err)
|
||||
}
|
||||
|
||||
if decoded.Timestamp != event.Timestamp {
|
||||
t.Errorf("Timestamp mismatch: got %s, want %s", decoded.Timestamp, event.Timestamp)
|
||||
}
|
||||
if decoded.UserID != event.UserID {
|
||||
t.Errorf("UserID mismatch: got %s, want %s", decoded.UserID, event.UserID)
|
||||
}
|
||||
if len(decoded.ResolvedIPs) != len(event.ResolvedIPs) {
|
||||
t.Errorf("ResolvedIPs length mismatch: got %d, want %d", len(decoded.ResolvedIPs), len(event.ResolvedIPs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuditLogger_LogURLValidation tests audit logging of URL validation events.
|
||||
func TestAuditLogger_LogURLValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := NewAuditLogger()
|
||||
|
||||
event := AuditEvent{
|
||||
Action: "url_test",
|
||||
Host: "malicious.com",
|
||||
RequestID: "req-456",
|
||||
Result: "blocked",
|
||||
ResolvedIPs: []string{"169.254.169.254"},
|
||||
BlockedReason: "metadata_endpoint",
|
||||
UserID: "attacker",
|
||||
SourceIP: "198.51.100.1",
|
||||
}
|
||||
|
||||
// This will log to standard logger, which we can't easily capture in tests
|
||||
// But we can verify it doesn't panic
|
||||
logger.LogURLValidation(event)
|
||||
|
||||
// Verify timestamp was auto-added if missing
|
||||
event2 := AuditEvent{
|
||||
Action: "test",
|
||||
Host: "test.com",
|
||||
}
|
||||
logger.LogURLValidation(event2)
|
||||
}
|
||||
|
||||
// TestAuditLogger_LogURLTest tests the convenience method for URL tests.
|
||||
func TestAuditLogger_LogURLTest(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := NewAuditLogger()
|
||||
|
||||
// Should not panic
|
||||
logger.LogURLTest("example.com", "req-789", "user456", "192.0.2.1", "allowed")
|
||||
}
|
||||
|
||||
// TestAuditLogger_LogSSRFBlock tests the convenience method for SSRF blocks.
|
||||
func TestAuditLogger_LogSSRFBlock(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := NewAuditLogger()
|
||||
|
||||
resolvedIPs := []string{"10.0.0.1", "192.168.1.1"}
|
||||
|
||||
// Should not panic
|
||||
logger.LogSSRFBlock("internal.local", resolvedIPs, "private_ip", "user123", "203.0.113.5")
|
||||
}
|
||||
|
||||
// TestGlobalAuditLogger tests the global audit logger functions.
|
||||
func TestGlobalAuditLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test global functions don't panic
|
||||
LogURLTest("test.com", "req-global", "user-global", "192.0.2.10", "allowed")
|
||||
LogSSRFBlock("blocked.local", []string{"127.0.0.1"}, "loopback", "user-global", "198.51.100.10")
|
||||
}
|
||||
|
||||
// TestAuditEvent_RequiredFields tests that required fields are enforced.
|
||||
func TestAuditEvent_RequiredFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
// CRITICAL: UserID field must be present for attribution
|
||||
event := AuditEvent{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Action: "ssrf_block",
|
||||
Host: "malicious.com",
|
||||
RequestID: "req-security",
|
||||
Result: "blocked",
|
||||
ResolvedIPs: []string{"192.168.1.1"},
|
||||
BlockedReason: "private_ip",
|
||||
UserID: "attacker123", // REQUIRED per Supervisor review
|
||||
SourceIP: "203.0.113.100",
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify UserID is in JSON output
|
||||
if !strings.Contains(string(jsonBytes), "attacker123") {
|
||||
t.Errorf("UserID not found in audit log JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuditLogger_TimestampFormat tests that timestamps use RFC3339 format.
|
||||
func TestAuditLogger_TimestampFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := NewAuditLogger()
|
||||
|
||||
event := AuditEvent{
|
||||
Action: "test",
|
||||
Host: "test.com",
|
||||
// Timestamp intentionally omitted to test auto-generation
|
||||
}
|
||||
|
||||
// Capture the event by marshaling after logging
|
||||
// In real scenario, LogURLValidation sets the timestamp
|
||||
if event.Timestamp == "" {
|
||||
event.Timestamp = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// Parse the timestamp to verify it's valid RFC3339
|
||||
_, err := time.Parse(time.RFC3339, event.Timestamp)
|
||||
if err != nil {
|
||||
t.Errorf("Invalid timestamp format: %s, error: %v", event.Timestamp, err)
|
||||
}
|
||||
|
||||
logger.LogURLValidation(event)
|
||||
}
|
||||
@@ -1,162 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestAuditEvent_JSONSerialization tests that audit events serialize correctly to JSON.
|
||||
func TestAuditEvent_JSONSerialization(t *testing.T) {
|
||||
event := AuditEvent{
|
||||
Timestamp: "2025-12-31T12:00:00Z",
|
||||
Action: "url_validation",
|
||||
Host: "example.com",
|
||||
RequestID: "test-123",
|
||||
Result: "blocked",
|
||||
ResolvedIPs: []string{"192.168.1.1", "10.0.0.1"},
|
||||
BlockedReason: "private_ip",
|
||||
UserID: "user123",
|
||||
SourceIP: "203.0.113.1",
|
||||
}
|
||||
|
||||
// Serialize to JSON
|
||||
jsonBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal AuditEvent: %v", err)
|
||||
}
|
||||
|
||||
// Verify all fields are present
|
||||
jsonStr := string(jsonBytes)
|
||||
expectedFields := []string{
|
||||
"timestamp", "action", "host", "request_id", "result",
|
||||
"resolved_ips", "blocked_reason", "user_id", "source_ip",
|
||||
}
|
||||
|
||||
for _, field := range expectedFields {
|
||||
if !strings.Contains(jsonStr, field) {
|
||||
t.Errorf("JSON output missing field: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
// Deserialize and verify
|
||||
var decoded AuditEvent
|
||||
err = json.Unmarshal(jsonBytes, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal AuditEvent: %v", err)
|
||||
}
|
||||
|
||||
if decoded.Timestamp != event.Timestamp {
|
||||
t.Errorf("Timestamp mismatch: got %s, want %s", decoded.Timestamp, event.Timestamp)
|
||||
}
|
||||
if decoded.UserID != event.UserID {
|
||||
t.Errorf("UserID mismatch: got %s, want %s", decoded.UserID, event.UserID)
|
||||
}
|
||||
if len(decoded.ResolvedIPs) != len(event.ResolvedIPs) {
|
||||
t.Errorf("ResolvedIPs length mismatch: got %d, want %d", len(decoded.ResolvedIPs), len(event.ResolvedIPs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuditLogger_LogURLValidation tests audit logging of URL validation events.
|
||||
func TestAuditLogger_LogURLValidation(t *testing.T) {
|
||||
logger := NewAuditLogger()
|
||||
|
||||
event := AuditEvent{
|
||||
Action: "url_test",
|
||||
Host: "malicious.com",
|
||||
RequestID: "req-456",
|
||||
Result: "blocked",
|
||||
ResolvedIPs: []string{"169.254.169.254"},
|
||||
BlockedReason: "metadata_endpoint",
|
||||
UserID: "attacker",
|
||||
SourceIP: "198.51.100.1",
|
||||
}
|
||||
|
||||
// This will log to standard logger, which we can't easily capture in tests
|
||||
// But we can verify it doesn't panic
|
||||
logger.LogURLValidation(event)
|
||||
|
||||
// Verify timestamp was auto-added if missing
|
||||
event2 := AuditEvent{
|
||||
Action: "test",
|
||||
Host: "test.com",
|
||||
}
|
||||
logger.LogURLValidation(event2)
|
||||
}
|
||||
|
||||
// TestAuditLogger_LogURLTest tests the convenience method for URL tests.
|
||||
func TestAuditLogger_LogURLTest(t *testing.T) {
|
||||
logger := NewAuditLogger()
|
||||
|
||||
// Should not panic
|
||||
logger.LogURLTest("example.com", "req-789", "user456", "192.0.2.1", "allowed")
|
||||
}
|
||||
|
||||
// TestAuditLogger_LogSSRFBlock tests the convenience method for SSRF blocks.
|
||||
func TestAuditLogger_LogSSRFBlock(t *testing.T) {
|
||||
logger := NewAuditLogger()
|
||||
|
||||
resolvedIPs := []string{"10.0.0.1", "192.168.1.1"}
|
||||
|
||||
// Should not panic
|
||||
logger.LogSSRFBlock("internal.local", resolvedIPs, "private_ip", "user123", "203.0.113.5")
|
||||
}
|
||||
|
||||
// TestGlobalAuditLogger tests the global audit logger functions.
|
||||
func TestGlobalAuditLogger(t *testing.T) {
|
||||
// Test global functions don't panic
|
||||
LogURLTest("test.com", "req-global", "user-global", "192.0.2.10", "allowed")
|
||||
LogSSRFBlock("blocked.local", []string{"127.0.0.1"}, "loopback", "user-global", "198.51.100.10")
|
||||
}
|
||||
|
||||
// TestAuditEvent_RequiredFields tests that required fields are enforced.
|
||||
func TestAuditEvent_RequiredFields(t *testing.T) {
|
||||
// CRITICAL: UserID field must be present for attribution
|
||||
event := AuditEvent{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Action: "ssrf_block",
|
||||
Host: "malicious.com",
|
||||
RequestID: "req-security",
|
||||
Result: "blocked",
|
||||
ResolvedIPs: []string{"192.168.1.1"},
|
||||
BlockedReason: "private_ip",
|
||||
UserID: "attacker123", // REQUIRED per Supervisor review
|
||||
SourceIP: "203.0.113.100",
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify UserID is in JSON output
|
||||
if !strings.Contains(string(jsonBytes), "attacker123") {
|
||||
t.Errorf("UserID not found in audit log JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuditLogger_TimestampFormat tests that timestamps use RFC3339 format.
|
||||
func TestAuditLogger_TimestampFormat(t *testing.T) {
|
||||
logger := NewAuditLogger()
|
||||
|
||||
event := AuditEvent{
|
||||
Action: "test",
|
||||
Host: "test.com",
|
||||
// Timestamp intentionally omitted to test auto-generation
|
||||
}
|
||||
|
||||
// Capture the event by marshaling after logging
|
||||
// In real scenario, LogURLValidation sets the timestamp
|
||||
if event.Timestamp == "" {
|
||||
event.Timestamp = time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// Parse the timestamp to verify it's valid RFC3339
|
||||
_, err := time.Parse(time.RFC3339, event.Timestamp)
|
||||
if err != nil {
|
||||
t.Errorf("Invalid timestamp format: %s, error: %v", event.Timestamp, err)
|
||||
}
|
||||
|
||||
logger.LogURLValidation(event)
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseExactHostnameAllowlist(t *testing.T) {
|
||||
allow := ParseExactHostnameAllowlist(" crowdsec , CADDY, ,http://example.com,example.com/path,user@host,::1 ")
|
||||
|
||||
if _, ok := allow["crowdsec"]; !ok {
|
||||
t.Fatalf("expected allowlist to contain crowdsec")
|
||||
}
|
||||
if _, ok := allow["caddy"]; !ok {
|
||||
t.Fatalf("expected allowlist to contain caddy")
|
||||
}
|
||||
if _, ok := allow["::1"]; !ok {
|
||||
t.Fatalf("expected allowlist to contain ::1")
|
||||
}
|
||||
|
||||
if _, ok := allow["http://example.com"]; ok {
|
||||
t.Fatalf("expected scheme-containing entry to be ignored")
|
||||
}
|
||||
if _, ok := allow["example.com/path"]; ok {
|
||||
t.Fatalf("expected path-containing entry to be ignored")
|
||||
}
|
||||
if _, ok := allow["user@host"]; ok {
|
||||
t.Fatalf("expected userinfo-containing entry to be ignored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInternalServiceBaseURL(t *testing.T) {
|
||||
allowed := map[string]struct{}{"localhost": {}, "127.0.0.1": {}, "::1": {}}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
raw string
|
||||
expectedPort int
|
||||
want string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "OK http localhost explicit port",
|
||||
raw: "http://localhost:2019",
|
||||
expectedPort: 2019,
|
||||
want: "http://localhost:2019",
|
||||
},
|
||||
{
|
||||
name: "OK http localhost path normalized",
|
||||
raw: "http://localhost:2019/config/",
|
||||
expectedPort: 2019,
|
||||
want: "http://localhost:2019",
|
||||
},
|
||||
{
|
||||
name: "OK https localhost default port",
|
||||
raw: "https://localhost",
|
||||
expectedPort: 443,
|
||||
want: "https://localhost:443",
|
||||
},
|
||||
{
|
||||
name: "OK ipv6 loopback explicit port",
|
||||
raw: "http://[::1]:2019",
|
||||
expectedPort: 2019,
|
||||
want: "http://[::1]:2019",
|
||||
},
|
||||
{
|
||||
name: "Reject userinfo",
|
||||
raw: "http://user:pass@localhost:2019",
|
||||
expectedPort: 2019,
|
||||
wantErr: true,
|
||||
errContains: "embedded credentials",
|
||||
},
|
||||
{
|
||||
name: "Reject unsupported scheme",
|
||||
raw: "file://localhost:2019",
|
||||
expectedPort: 2019,
|
||||
wantErr: true,
|
||||
errContains: "unsupported scheme",
|
||||
},
|
||||
{
|
||||
name: "Reject missing hostname",
|
||||
raw: "http://:2019",
|
||||
expectedPort: 2019,
|
||||
wantErr: true,
|
||||
errContains: "missing hostname",
|
||||
},
|
||||
{
|
||||
name: "Reject hostname not allowed",
|
||||
raw: "http://evil.example:2019",
|
||||
expectedPort: 2019,
|
||||
wantErr: true,
|
||||
errContains: "hostname not allowed",
|
||||
},
|
||||
{
|
||||
name: "Reject unexpected port when omitted",
|
||||
raw: "http://localhost",
|
||||
expectedPort: 2019,
|
||||
wantErr: true,
|
||||
errContains: "unexpected port",
|
||||
},
|
||||
{
|
||||
name: "Reject invalid port",
|
||||
raw: "http://localhost:0",
|
||||
expectedPort: 2019,
|
||||
wantErr: true,
|
||||
errContains: "invalid port",
|
||||
},
|
||||
{
|
||||
name: "Reject out-of-range port",
|
||||
raw: "http://localhost:99999",
|
||||
expectedPort: 2019,
|
||||
wantErr: true,
|
||||
errContains: "invalid port",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
u, err := ValidateInternalServiceBaseURL(tc.raw, tc.expectedPort, allowed)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Fatalf("expected error to contain %q, got %q", tc.errContains, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if u.String() != tc.want {
|
||||
t.Fatalf("expected %q, got %q", tc.want, u.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,359 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
neturl "net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/network"
|
||||
)
|
||||
|
||||
// InternalServiceHostAllowlistEnvVar controls which *additional* hostnames (exact matches)
|
||||
// are permitted for internal service HTTP calls (CrowdSec LAPI, Caddy Admin, etc.).
|
||||
//
|
||||
// Default policy remains localhost-only.
|
||||
// Example: CHARON_SSRF_INTERNAL_HOST_ALLOWLIST="crowdsec,caddy"
|
||||
const InternalServiceHostAllowlistEnvVar = "CHARON_SSRF_INTERNAL_HOST_ALLOWLIST"
|
||||
|
||||
// ParseExactHostnameAllowlist parses a comma-separated list of hostnames into an exact-match set.
|
||||
//
|
||||
// Notes:
|
||||
// - Hostnames are lowercased for comparison.
|
||||
// - Entries containing schemes/paths are ignored.
|
||||
func ParseExactHostnameAllowlist(csv string) map[string]struct{} {
|
||||
out := make(map[string]struct{})
|
||||
for _, part := range strings.Split(csv, ",") {
|
||||
h := strings.ToLower(strings.TrimSpace(part))
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
// Reject obvious non-hostname inputs.
|
||||
if strings.Contains(h, "://") || strings.ContainsAny(h, "/@") {
|
||||
continue
|
||||
}
|
||||
out[h] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// InternalServiceHostAllowlist returns the deny-by-default internal-service hostname allowlist.
|
||||
//
|
||||
// Defaults: localhost-only. Docker/service-name deployments must opt-in via
|
||||
// CHARON_SSRF_INTERNAL_HOST_ALLOWLIST.
|
||||
func InternalServiceHostAllowlist() map[string]struct{} {
|
||||
allow := map[string]struct{}{
|
||||
"localhost": {},
|
||||
"127.0.0.1": {},
|
||||
"::1": {},
|
||||
}
|
||||
extra := ParseExactHostnameAllowlist(os.Getenv(InternalServiceHostAllowlistEnvVar))
|
||||
for h := range extra {
|
||||
allow[h] = struct{}{}
|
||||
}
|
||||
return allow
|
||||
}
|
||||
|
||||
// ValidateInternalServiceBaseURL validates a configured base URL for an internal service.
|
||||
//
|
||||
// Security model:
|
||||
// - host must be an exact match in allowedHosts
|
||||
// - port must match expectedPort (including default ports if omitted)
|
||||
// - proxy env vars must be ignored by callers (client/transport responsibility)
|
||||
//
|
||||
// Returns a normalized base URL (scheme://host:expectedPort) suitable for safe request construction.
|
||||
func ValidateInternalServiceBaseURL(rawURL string, expectedPort int, allowedHosts map[string]struct{}) (*neturl.URL, error) {
|
||||
u, err := neturl.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid url format: %w", err)
|
||||
}
|
||||
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return nil, fmt.Errorf("unsupported scheme: %s (only http and https are allowed)", u.Scheme)
|
||||
}
|
||||
if u.User != nil {
|
||||
return nil, fmt.Errorf("urls with embedded credentials are not allowed")
|
||||
}
|
||||
|
||||
host := strings.ToLower(u.Hostname())
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("missing hostname in url")
|
||||
}
|
||||
if _, ok := allowedHosts[host]; !ok {
|
||||
return nil, fmt.Errorf("hostname not allowed: %s", host)
|
||||
}
|
||||
|
||||
actualPort := 0
|
||||
if p := u.Port(); p != "" {
|
||||
portNum, perr := strconv.Atoi(p)
|
||||
if perr != nil || portNum < 1 || portNum > 65535 {
|
||||
return nil, fmt.Errorf("invalid port")
|
||||
}
|
||||
actualPort = portNum
|
||||
} else {
|
||||
if u.Scheme == "https" {
|
||||
actualPort = 443
|
||||
} else {
|
||||
actualPort = 80
|
||||
}
|
||||
}
|
||||
if actualPort != expectedPort {
|
||||
return nil, fmt.Errorf("unexpected port: %d (expected %d)", actualPort, expectedPort)
|
||||
}
|
||||
|
||||
// Normalize to a base URL with an explicit expected port.
|
||||
base := &neturl.URL{
|
||||
Scheme: u.Scheme,
|
||||
Host: net.JoinHostPort(u.Hostname(), strconv.Itoa(expectedPort)),
|
||||
}
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// ValidationConfig holds options for URL validation.
|
||||
type ValidationConfig struct {
|
||||
AllowLocalhost bool
|
||||
AllowHTTP bool
|
||||
MaxRedirects int
|
||||
Timeout time.Duration
|
||||
BlockPrivateIPs bool
|
||||
}
|
||||
|
||||
// ValidationOption allows customizing validation behavior.
|
||||
type ValidationOption func(*ValidationConfig)
|
||||
|
||||
// WithAllowLocalhost permits localhost addresses for testing (default: false).
|
||||
func WithAllowLocalhost() ValidationOption {
|
||||
return func(c *ValidationConfig) { c.AllowLocalhost = true }
|
||||
}
|
||||
|
||||
// WithAllowHTTP permits HTTP scheme (default: false, HTTPS only).
|
||||
func WithAllowHTTP() ValidationOption {
|
||||
return func(c *ValidationConfig) { c.AllowHTTP = true }
|
||||
}
|
||||
|
||||
// WithTimeout sets the DNS resolution timeout (default: 3 seconds).
|
||||
func WithTimeout(timeout time.Duration) ValidationOption {
|
||||
return func(c *ValidationConfig) { c.Timeout = timeout }
|
||||
}
|
||||
|
||||
// WithMaxRedirects sets the maximum number of redirects to follow (default: 0).
|
||||
func WithMaxRedirects(maxRedirects int) ValidationOption {
|
||||
return func(c *ValidationConfig) { c.MaxRedirects = maxRedirects }
|
||||
}
|
||||
|
||||
// ValidateExternalURL validates a URL for external HTTP requests with comprehensive SSRF protection.
|
||||
// This function provides defense-in-depth against Server-Side Request Forgery attacks by:
|
||||
// 1. Validating URL format and scheme
|
||||
// 2. Resolving DNS and checking all resolved IPs against private/reserved ranges
|
||||
// 3. Blocking access to cloud metadata endpoints (AWS, GCP, Azure)
|
||||
// 4. Enforcing HTTPS by default (configurable)
|
||||
//
|
||||
// Returns: normalized URL string, error
|
||||
//
|
||||
// Security: This function blocks access to:
|
||||
// - Private IP ranges (RFC 1918: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
|
||||
// - Loopback addresses (127.0.0.0/8, ::1/128) unless AllowLocalhost option is set
|
||||
// - Link-local addresses (169.254.0.0/16, fe80::/10) including cloud metadata endpoints
|
||||
// - Reserved IP ranges (0.0.0.0/8, 240.0.0.0/4, 255.255.255.255/32)
|
||||
// - IPv6 unique local addresses (fc00::/7)
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// // Production use (HTTPS only, no private IPs)
|
||||
// url, err := ValidateExternalURL("https://api.example.com/webhook")
|
||||
//
|
||||
// // Testing use (allow localhost and HTTP)
|
||||
// url, err := ValidateExternalURL("http://localhost:8080/test",
|
||||
// WithAllowLocalhost(),
|
||||
// WithAllowHTTP())
|
||||
func ValidateExternalURL(rawURL string, options ...ValidationOption) (string, error) {
|
||||
// Apply default configuration
|
||||
config := &ValidationConfig{
|
||||
AllowLocalhost: false,
|
||||
AllowHTTP: false,
|
||||
MaxRedirects: 0,
|
||||
Timeout: 3 * time.Second,
|
||||
BlockPrivateIPs: true,
|
||||
}
|
||||
|
||||
// Apply custom options
|
||||
for _, opt := range options {
|
||||
opt(config)
|
||||
}
|
||||
|
||||
// Phase 1: URL Format Validation
|
||||
u, err := neturl.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid url format: %w", err)
|
||||
}
|
||||
|
||||
// Validate scheme - only http/https allowed
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return "", fmt.Errorf("unsupported scheme: %s (only http and https are allowed)", u.Scheme)
|
||||
}
|
||||
|
||||
// Enforce HTTPS unless explicitly allowed
|
||||
if !config.AllowHTTP && u.Scheme != "https" {
|
||||
return "", fmt.Errorf("http scheme not allowed (use https for security)")
|
||||
}
|
||||
|
||||
// Validate hostname exists
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("missing hostname in url")
|
||||
}
|
||||
|
||||
// ENHANCEMENT: Hostname Length Validation (RFC 1035)
|
||||
const maxHostnameLength = 253
|
||||
if len(host) > maxHostnameLength {
|
||||
return "", fmt.Errorf("hostname exceeds maximum length of %d characters", maxHostnameLength)
|
||||
}
|
||||
|
||||
// ENHANCEMENT: Suspicious Pattern Detection
|
||||
if strings.Contains(host, "..") {
|
||||
return "", fmt.Errorf("hostname contains suspicious pattern (..)")
|
||||
}
|
||||
|
||||
// Reject URLs with credentials in authority section
|
||||
if u.User != nil {
|
||||
return "", fmt.Errorf("urls with embedded credentials are not allowed")
|
||||
}
|
||||
|
||||
// ENHANCEMENT: Port Range Validation
|
||||
if port := u.Port(); port != "" {
|
||||
portNum, err := parsePort(port)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
if portNum < 1 || portNum > 65535 {
|
||||
return "", fmt.Errorf("port out of range: %d", portNum)
|
||||
}
|
||||
// CRITICAL FIX: Allow standard ports 80/443, block other privileged ports
|
||||
standardPorts := map[int]bool{80: true, 443: true}
|
||||
if portNum < 1024 && !standardPorts[portNum] && !config.AllowLocalhost {
|
||||
return "", fmt.Errorf("non-standard privileged port blocked: %d", portNum)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Localhost Exception Handling
|
||||
if config.AllowLocalhost {
|
||||
// Check if this is an explicit localhost address
|
||||
if host == "localhost" || host == "127.0.0.1" || host == "::1" {
|
||||
// Normalize and return - localhost is allowed
|
||||
return u.String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: DNS Resolution and IP Validation
|
||||
// Resolve hostname with timeout
|
||||
resolver := &net.Resolver{}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := resolver.LookupIP(ctx, "ip", host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dns resolution failed for %s: %w", host, err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no ip addresses resolved for hostname: %s", host)
|
||||
}
|
||||
|
||||
// Phase 4: Private IP Blocking
|
||||
// Check ALL resolved IPs against private/reserved ranges
|
||||
if config.BlockPrivateIPs {
|
||||
for _, ip := range ips {
|
||||
// ENHANCEMENT: IPv4-mapped IPv6 Detection
|
||||
// Prevent bypass via ::ffff:192.168.1.1 format
|
||||
if ip.To4() != nil && ip.To16() != nil && isIPv4MappedIPv6(ip) {
|
||||
// Extract the IPv4 address from the mapped format
|
||||
ipv4 := ip.To4()
|
||||
if network.IsPrivateIP(ipv4) {
|
||||
return "", fmt.Errorf("connection to private ip addresses is blocked for security (detected IPv4-mapped IPv6: %s)", ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Check if IP is in private/reserved ranges using centralized network.IsPrivateIP
|
||||
// This includes:
|
||||
// - RFC 1918 private networks (10.x, 172.16.x, 192.168.x)
|
||||
// - Loopback (127.x.x.x, ::1)
|
||||
// - Link-local (169.254.x.x, fe80::) including cloud metadata
|
||||
// - Reserved ranges (0.x.x.x, 240.x.x.x, 255.255.255.255)
|
||||
// - IPv6 unique local (fc00::)
|
||||
if network.IsPrivateIP(ip) {
|
||||
// ENHANCEMENT: Sanitize Error Messages
|
||||
// Don't leak internal IPs in error messages to external users
|
||||
sanitizedIP := sanitizeIPForError(ip.String())
|
||||
if ip.String() == "169.254.169.254" {
|
||||
return "", fmt.Errorf("access to cloud metadata endpoints is blocked for security (detected: %s)", sanitizedIP)
|
||||
}
|
||||
return "", fmt.Errorf("connection to private ip addresses is blocked for security (detected: %s)", sanitizedIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize URL (trim trailing slashes, lowercase host)
|
||||
normalized := u.String()
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// isIPv4MappedIPv6 detects IPv4-mapped IPv6 addresses (::ffff:192.168.1.1).
|
||||
// This prevents SSRF bypass via IPv6 notation of private IPv4 addresses.
|
||||
func isIPv4MappedIPv6(ip net.IP) bool {
|
||||
// IPv4-mapped IPv6 addresses have the form ::ffff:a.b.c.d
|
||||
// In binary: 80 bits of zeros, 16 bits of ones, 32 bits of IPv4
|
||||
if len(ip) != net.IPv6len {
|
||||
return false
|
||||
}
|
||||
// Check for ::ffff: prefix (10 zero bytes, 2 0xff bytes)
|
||||
for i := 0; i < 10; i++ {
|
||||
if ip[i] != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return ip[10] == 0xff && ip[11] == 0xff
|
||||
}
|
||||
|
||||
// parsePort safely parses a port string to an integer.
|
||||
func parsePort(port string) (int, error) {
|
||||
if port == "" {
|
||||
return 0, fmt.Errorf("empty port string")
|
||||
}
|
||||
var portNum int
|
||||
_, err := fmt.Sscanf(port, "%d", &portNum)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("port must be numeric: %s", port)
|
||||
}
|
||||
return portNum, nil
|
||||
}
|
||||
|
||||
// sanitizeIPForError removes sensitive details from IP addresses in error messages.
|
||||
// This prevents leaking internal network topology to external users.
|
||||
func sanitizeIPForError(ip string) string {
|
||||
// For private IPs, show only the first octet to avoid leaking network structure
|
||||
// Example: 192.168.1.100 -> 192.x.x.x
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return "invalid-ip"
|
||||
}
|
||||
|
||||
if parsedIP.To4() != nil {
|
||||
// IPv4: show only first octet
|
||||
parts := strings.Split(ip, ".")
|
||||
if len(parts) == 4 {
|
||||
return parts[0] + ".x.x.x"
|
||||
}
|
||||
} else {
|
||||
// IPv6: show only first segment
|
||||
parts := strings.Split(ip, ":")
|
||||
if len(parts) > 0 {
|
||||
return parts[0] + "::"
|
||||
}
|
||||
}
|
||||
return "private-ip"
|
||||
}
|
||||
@@ -1,309 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestInternalServiceHostAllowlist tests the internal service hostname allowlist.
|
||||
func TestInternalServiceHostAllowlist(t *testing.T) {
|
||||
// Save original env var
|
||||
originalEnv := os.Getenv(InternalServiceHostAllowlistEnvVar)
|
||||
defer func() { _ = os.Setenv(InternalServiceHostAllowlistEnvVar, originalEnv) }()
|
||||
|
||||
t.Run("DefaultLocalhostOnly", func(t *testing.T) {
|
||||
_ = os.Setenv(InternalServiceHostAllowlistEnvVar, "")
|
||||
allowlist := InternalServiceHostAllowlist()
|
||||
|
||||
// Should contain localhost entries
|
||||
expected := []string{"localhost", "127.0.0.1", "::1"}
|
||||
for _, host := range expected {
|
||||
if _, ok := allowlist[host]; !ok {
|
||||
t.Errorf("Expected %s to be in default allowlist", host)
|
||||
}
|
||||
}
|
||||
|
||||
// Should only have 3 localhost entries
|
||||
if len(allowlist) != 3 {
|
||||
t.Errorf("Expected 3 entries in default allowlist, got %d", len(allowlist))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WithAdditionalHosts", func(t *testing.T) {
|
||||
_ = os.Setenv(InternalServiceHostAllowlistEnvVar, "crowdsec,caddy,traefik")
|
||||
allowlist := InternalServiceHostAllowlist()
|
||||
|
||||
// Should contain localhost + additional hosts
|
||||
expected := []string{"localhost", "127.0.0.1", "::1", "crowdsec", "caddy", "traefik"}
|
||||
for _, host := range expected {
|
||||
if _, ok := allowlist[host]; !ok {
|
||||
t.Errorf("Expected %s to be in allowlist", host)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowlist) != 6 {
|
||||
t.Errorf("Expected 6 entries in allowlist, got %d", len(allowlist))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WithEmptyAndWhitespaceEntries", func(t *testing.T) {
|
||||
_ = os.Setenv(InternalServiceHostAllowlistEnvVar, " , crowdsec , , caddy , ")
|
||||
allowlist := InternalServiceHostAllowlist()
|
||||
|
||||
// Should contain localhost + valid hosts (empty and whitespace ignored)
|
||||
expected := []string{"localhost", "127.0.0.1", "::1", "crowdsec", "caddy"}
|
||||
for _, host := range expected {
|
||||
if _, ok := allowlist[host]; !ok {
|
||||
t.Errorf("Expected %s to be in allowlist", host)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowlist) != 5 {
|
||||
t.Errorf("Expected 5 entries in allowlist, got %d", len(allowlist))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WithInvalidEntries", func(t *testing.T) {
|
||||
_ = os.Setenv(InternalServiceHostAllowlistEnvVar, "crowdsec,http://invalid,user@host,/path")
|
||||
allowlist := InternalServiceHostAllowlist()
|
||||
|
||||
// Should only have localhost + crowdsec (others rejected)
|
||||
if _, ok := allowlist["crowdsec"]; !ok {
|
||||
t.Error("Expected crowdsec to be in allowlist")
|
||||
}
|
||||
if _, ok := allowlist["http://invalid"]; ok {
|
||||
t.Error("Did not expect http://invalid to be in allowlist")
|
||||
}
|
||||
if _, ok := allowlist["user@host"]; ok {
|
||||
t.Error("Did not expect user@host to be in allowlist")
|
||||
}
|
||||
if _, ok := allowlist["/path"]; ok {
|
||||
t.Error("Did not expect /path to be in allowlist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestWithMaxRedirects tests the WithMaxRedirects validation option.
|
||||
func TestWithMaxRedirects(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value int
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Zero redirects",
|
||||
value: 0,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Five redirects",
|
||||
value: 5,
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "Ten redirects",
|
||||
value: 10,
|
||||
expected: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &ValidationConfig{}
|
||||
opt := WithMaxRedirects(tt.value)
|
||||
opt(config)
|
||||
|
||||
if config.MaxRedirects != tt.expected {
|
||||
t.Errorf("Expected MaxRedirects=%d, got %d", tt.expected, config.MaxRedirects)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateInternalServiceBaseURL_AdditionalCases tests edge cases for ValidateInternalServiceBaseURL.
|
||||
func TestValidateInternalServiceBaseURL_AdditionalCases(t *testing.T) {
|
||||
allowlist := map[string]struct{}{
|
||||
"localhost": {},
|
||||
"caddy": {},
|
||||
}
|
||||
|
||||
t.Run("HTTPSWithDefaultPort", func(t *testing.T) {
|
||||
// HTTPS without explicit port should default to 443
|
||||
url, err := ValidateInternalServiceBaseURL("https://localhost", 443, allowlist)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if url.String() != "https://localhost:443" {
|
||||
t.Errorf("Expected https://localhost:443, got %s", url.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTPWithDefaultPort", func(t *testing.T) {
|
||||
// HTTP without explicit port should default to 80
|
||||
url, err := ValidateInternalServiceBaseURL("http://localhost", 80, allowlist)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if url.String() != "http://localhost:80" {
|
||||
t.Errorf("Expected http://localhost:80, got %s", url.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PortMismatchWithDefaultHTTPS", func(t *testing.T) {
|
||||
// HTTPS defaults to 443, but we expect 2019
|
||||
_, err := ValidateInternalServiceBaseURL("https://localhost", 2019, allowlist)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for port mismatch, got nil")
|
||||
}
|
||||
if !contains(err.Error(), "unexpected port") {
|
||||
t.Errorf("Expected 'unexpected port' error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PortMismatchWithDefaultHTTP", func(t *testing.T) {
|
||||
// HTTP defaults to 80, but we expect 8080
|
||||
_, err := ValidateInternalServiceBaseURL("http://localhost", 8080, allowlist)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for port mismatch, got nil")
|
||||
}
|
||||
if !contains(err.Error(), "unexpected port") {
|
||||
t.Errorf("Expected 'unexpected port' error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InvalidPortNumber", func(t *testing.T) {
|
||||
_, err := ValidateInternalServiceBaseURL("http://localhost:99999", 99999, allowlist)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid port, got nil")
|
||||
}
|
||||
if !contains(err.Error(), "invalid port") {
|
||||
t.Errorf("Expected 'invalid port' error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NegativePort", func(t *testing.T) {
|
||||
_, err := ValidateInternalServiceBaseURL("http://localhost:-1", -1, allowlist)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for negative port, got nil")
|
||||
}
|
||||
if !contains(err.Error(), "invalid port") {
|
||||
t.Errorf("Expected 'invalid port' error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HostNotInAllowlist", func(t *testing.T) {
|
||||
_, err := ValidateInternalServiceBaseURL("http://evil.com:80", 80, allowlist)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for disallowed host, got nil")
|
||||
}
|
||||
if !contains(err.Error(), "hostname not allowed") {
|
||||
t.Errorf("Expected 'hostname not allowed' error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyAllowlist", func(t *testing.T) {
|
||||
emptyList := map[string]struct{}{}
|
||||
_, err := ValidateInternalServiceBaseURL("http://localhost:80", 80, emptyList)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty allowlist, got nil")
|
||||
}
|
||||
if !contains(err.Error(), "hostname not allowed") {
|
||||
t.Errorf("Expected 'hostname not allowed' error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitiveHostMatching", func(t *testing.T) {
|
||||
// Hostname should be case-insensitive
|
||||
url, err := ValidateInternalServiceBaseURL("http://LOCALHOST:2019", 2019, allowlist)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error for uppercase hostname, got %v", err)
|
||||
}
|
||||
if url.Hostname() != "LOCALHOST" {
|
||||
t.Errorf("Expected hostname preservation, got %s", url.Hostname())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AllowedHostDifferentCase", func(t *testing.T) {
|
||||
// Caddy in allowlist, CADDY in URL
|
||||
url, err := ValidateInternalServiceBaseURL("http://CADDY:2019", 2019, allowlist)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error for case variation, got %v", err)
|
||||
}
|
||||
if url.Hostname() != "CADDY" {
|
||||
t.Errorf("Expected hostname CADDY, got %s", url.Hostname())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSanitizeIPForError_AdditionalCases tests additional edge cases for IP sanitization.
|
||||
func TestSanitizeIPForError_AdditionalCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "InvalidIPString",
|
||||
input: "not-an-ip",
|
||||
expected: "invalid-ip",
|
||||
},
|
||||
{
|
||||
name: "EmptyString",
|
||||
input: "",
|
||||
expected: "invalid-ip",
|
||||
},
|
||||
{
|
||||
name: "IPv4Malformed",
|
||||
input: "192.168",
|
||||
expected: "invalid-ip",
|
||||
},
|
||||
{
|
||||
name: "IPv6SingleSegment",
|
||||
input: "fe80::1",
|
||||
expected: "fe80::",
|
||||
},
|
||||
{
|
||||
name: "IPv6MultipleSegments",
|
||||
input: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
|
||||
expected: "2001::",
|
||||
},
|
||||
{
|
||||
name: "IPv6Compressed",
|
||||
input: "::1",
|
||||
expected: "::",
|
||||
},
|
||||
{
|
||||
name: "IPv4ThreeOctets",
|
||||
input: "192.168.1",
|
||||
expected: "invalid-ip",
|
||||
},
|
||||
{
|
||||
name: "IPv4FiveOctets",
|
||||
input: "192.168.1.1.1",
|
||||
expected: "invalid-ip",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizeIPForError(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsSubstring(s, substr))
|
||||
}
|
||||
|
||||
func containsSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user