Files
Charon/backend/internal/security/url_validator_coverage_test.go

310 lines
8.6 KiB
Go

package security
import (
	"os"
	"strings"
	"testing"
)
// TestInternalServiceHostAllowlist tests InternalServiceHostAllowlist, which
// builds the internal service hostname allowlist from the value of the
// InternalServiceHostAllowlistEnvVar environment variable.
//
// The subtests mutate the shared process environment, so they must not run in
// parallel. The variable's original state is captured with os.LookupEnv and
// restored via t.Cleanup. A variable that was unset before the test is unset
// again afterwards, instead of being left defined as an empty string.
func TestInternalServiceHostAllowlist(t *testing.T) {
	originalEnv, hadEnv := os.LookupEnv(InternalServiceHostAllowlistEnvVar)
	t.Cleanup(func() {
		// Best-effort restore: the test outcome is already decided and there is
		// no meaningful recovery, so errors are deliberately ignored.
		if hadEnv {
			_ = os.Setenv(InternalServiceHostAllowlistEnvVar, originalEnv)
		} else {
			_ = os.Unsetenv(InternalServiceHostAllowlistEnvVar)
		}
	})

	// setAllowlistEnv sets the allowlist variable for the current subtest and
	// fails the subtest immediately if the environment cannot be modified.
	setAllowlistEnv := func(t *testing.T, value string) {
		t.Helper()
		if err := os.Setenv(InternalServiceHostAllowlistEnvVar, value); err != nil {
			t.Fatalf("Failed to set %s: %v", InternalServiceHostAllowlistEnvVar, err)
		}
	}

	t.Run("DefaultLocalhostOnly", func(t *testing.T) {
		setAllowlistEnv(t, "")
		allowlist := InternalServiceHostAllowlist()
		// With no extra hosts configured, only the loopback names are expected.
		expected := []string{"localhost", "127.0.0.1", "::1"}
		for _, host := range expected {
			if _, ok := allowlist[host]; !ok {
				t.Errorf("Expected %s to be in default allowlist", host)
			}
		}
		// Nothing beyond the loopback entries may be present.
		if len(allowlist) != len(expected) {
			t.Errorf("Expected %d entries in default allowlist, got %d", len(expected), len(allowlist))
		}
	})

	t.Run("WithAdditionalHosts", func(t *testing.T) {
		setAllowlistEnv(t, "crowdsec,caddy,traefik")
		allowlist := InternalServiceHostAllowlist()
		// Configured hosts are added on top of the loopback defaults.
		expected := []string{"localhost", "127.0.0.1", "::1", "crowdsec", "caddy", "traefik"}
		for _, host := range expected {
			if _, ok := allowlist[host]; !ok {
				t.Errorf("Expected %s to be in allowlist", host)
			}
		}
		if len(allowlist) != len(expected) {
			t.Errorf("Expected %d entries in allowlist, got %d", len(expected), len(allowlist))
		}
	})

	t.Run("WithEmptyAndWhitespaceEntries", func(t *testing.T) {
		setAllowlistEnv(t, " , crowdsec , , caddy , ")
		allowlist := InternalServiceHostAllowlist()
		// Blank entries are expected to be dropped and surrounding whitespace trimmed.
		expected := []string{"localhost", "127.0.0.1", "::1", "crowdsec", "caddy"}
		for _, host := range expected {
			if _, ok := allowlist[host]; !ok {
				t.Errorf("Expected %s to be in allowlist", host)
			}
		}
		if len(allowlist) != len(expected) {
			t.Errorf("Expected %d entries in allowlist, got %d", len(expected), len(allowlist))
		}
	})

	t.Run("WithInvalidEntries", func(t *testing.T) {
		setAllowlistEnv(t, "crowdsec,http://invalid,user@host,/path")
		allowlist := InternalServiceHostAllowlist()
		// A bare hostname is accepted; entries that look like URLs, userinfo, or
		// paths must be rejected rather than inserted verbatim.
		if _, ok := allowlist["crowdsec"]; !ok {
			t.Error("Expected crowdsec to be in allowlist")
		}
		for _, rejected := range []string{"http://invalid", "user@host", "/path"} {
			if _, ok := allowlist[rejected]; ok {
				t.Errorf("Did not expect %s to be in allowlist", rejected)
			}
		}
	})
}
// TestWithMaxRedirects verifies that applying the option returned by
// WithMaxRedirects to a zero-valued ValidationConfig stores the requested
// limit in its MaxRedirects field.
func TestWithMaxRedirects(t *testing.T) {
	cases := []struct {
		name string
		in   int
		want int
	}{
		{"Zero redirects", 0, 0},
		{"Five redirects", 5, 5},
		{"Ten redirects", 10, 10},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var cfg ValidationConfig
			WithMaxRedirects(tc.in)(&cfg)

			if got := cfg.MaxRedirects; got != tc.want {
				t.Errorf("Expected MaxRedirects=%d, got %d", tc.want, got)
			}
		})
	}
}
// TestValidateInternalServiceBaseURL_AdditionalCases tests edge cases for ValidateInternalServiceBaseURL.
func TestValidateInternalServiceBaseURL_AdditionalCases(t *testing.T) {
allowlist := map[string]struct{}{
"localhost": {},
"caddy": {},
}
t.Run("HTTPSWithDefaultPort", func(t *testing.T) {
// HTTPS without explicit port should default to 443
url, err := ValidateInternalServiceBaseURL("https://localhost", 443, allowlist)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if url.String() != "https://localhost:443" {
t.Errorf("Expected https://localhost:443, got %s", url.String())
}
})
t.Run("HTTPWithDefaultPort", func(t *testing.T) {
// HTTP without explicit port should default to 80
url, err := ValidateInternalServiceBaseURL("http://localhost", 80, allowlist)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if url.String() != "http://localhost:80" {
t.Errorf("Expected http://localhost:80, got %s", url.String())
}
})
t.Run("PortMismatchWithDefaultHTTPS", func(t *testing.T) {
// HTTPS defaults to 443, but we expect 2019
_, err := ValidateInternalServiceBaseURL("https://localhost", 2019, allowlist)
if err == nil {
t.Fatal("Expected error for port mismatch, got nil")
}
if !contains(err.Error(), "unexpected port") {
t.Errorf("Expected 'unexpected port' error, got %v", err)
}
})
t.Run("PortMismatchWithDefaultHTTP", func(t *testing.T) {
// HTTP defaults to 80, but we expect 8080
_, err := ValidateInternalServiceBaseURL("http://localhost", 8080, allowlist)
if err == nil {
t.Fatal("Expected error for port mismatch, got nil")
}
if !contains(err.Error(), "unexpected port") {
t.Errorf("Expected 'unexpected port' error, got %v", err)
}
})
t.Run("InvalidPortNumber", func(t *testing.T) {
_, err := ValidateInternalServiceBaseURL("http://localhost:99999", 99999, allowlist)
if err == nil {
t.Fatal("Expected error for invalid port, got nil")
}
if !contains(err.Error(), "invalid port") {
t.Errorf("Expected 'invalid port' error, got %v", err)
}
})
t.Run("NegativePort", func(t *testing.T) {
_, err := ValidateInternalServiceBaseURL("http://localhost:-1", -1, allowlist)
if err == nil {
t.Fatal("Expected error for negative port, got nil")
}
if !contains(err.Error(), "invalid port") {
t.Errorf("Expected 'invalid port' error, got %v", err)
}
})
t.Run("HostNotInAllowlist", func(t *testing.T) {
_, err := ValidateInternalServiceBaseURL("http://evil.com:80", 80, allowlist)
if err == nil {
t.Fatal("Expected error for disallowed host, got nil")
}
if !contains(err.Error(), "hostname not allowed") {
t.Errorf("Expected 'hostname not allowed' error, got %v", err)
}
})
t.Run("EmptyAllowlist", func(t *testing.T) {
emptyList := map[string]struct{}{}
_, err := ValidateInternalServiceBaseURL("http://localhost:80", 80, emptyList)
if err == nil {
t.Fatal("Expected error for empty allowlist, got nil")
}
if !contains(err.Error(), "hostname not allowed") {
t.Errorf("Expected 'hostname not allowed' error, got %v", err)
}
})
t.Run("CaseInsensitiveHostMatching", func(t *testing.T) {
// Hostname should be case-insensitive
url, err := ValidateInternalServiceBaseURL("http://LOCALHOST:2019", 2019, allowlist)
if err != nil {
t.Fatalf("Expected no error for uppercase hostname, got %v", err)
}
if url.Hostname() != "LOCALHOST" {
t.Errorf("Expected hostname preservation, got %s", url.Hostname())
}
})
t.Run("AllowedHostDifferentCase", func(t *testing.T) {
// Caddy in allowlist, CADDY in URL
url, err := ValidateInternalServiceBaseURL("http://CADDY:2019", 2019, allowlist)
if err != nil {
t.Fatalf("Expected no error for case variation, got %v", err)
}
if url.Hostname() != "CADDY" {
t.Errorf("Expected hostname CADDY, got %s", url.Hostname())
}
})
}
// TestSanitizeIPForError_AdditionalCases pins sanitizeIPForError's output for
// additional inputs. Malformed or empty strings map to "invalid-ip". The IPv6
// rows expect only the leading segment followed by "::" to survive.
func TestSanitizeIPForError_AdditionalCases(t *testing.T) {
	type sanitizeCase struct {
		name, input, expected string
	}

	cases := []sanitizeCase{
		{"InvalidIPString", "not-an-ip", "invalid-ip"},
		{"EmptyString", "", "invalid-ip"},
		{"IPv4Malformed", "192.168", "invalid-ip"},
		{"IPv6SingleSegment", "fe80::1", "fe80::"},
		{"IPv6MultipleSegments", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", "2001::"},
		{"IPv6Compressed", "::1", "::"},
		{"IPv4ThreeOctets", "192.168.1", "invalid-ip"},
		{"IPv4FiveOctets", "192.168.1.1.1", "invalid-ip"},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			if got := sanitizeIPForError(tc.input); got != tc.expected {
				t.Errorf("Expected %s, got %s", tc.expected, got)
			}
		})
	}
}
// contains reports whether substr occurs anywhere within s. An empty substr is
// contained in every string, including the empty string. It delegates to
// strings.Contains and is kept as a named helper so existing test call sites
// need not change.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// containsSubstring reports whether substr occurs anywhere within s. It is
// equivalent to contains. It is retained because other test files in the
// package may still call it (not verifiable from this file).
func containsSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}