package network import ( "context" "net" "net/http" "net/http/httptest" "testing" "time" ) func TestIsPrivateIP(t *testing.T) { t.Parallel() tests := []struct { name string ip string expected bool }{ // Private IPv4 ranges {"10.0.0.0/8 start", "10.0.0.1", true}, {"10.0.0.0/8 middle", "10.255.255.255", true}, {"172.16.0.0/12 start", "172.16.0.1", true}, {"172.16.0.0/12 end", "172.31.255.255", true}, {"192.168.0.0/16 start", "192.168.0.1", true}, {"192.168.0.0/16 end", "192.168.255.255", true}, // Link-local {"169.254.0.0/16 start", "169.254.0.1", true}, {"169.254.0.0/16 end", "169.254.255.255", true}, // Loopback {"127.0.0.0/8 localhost", "127.0.0.1", true}, {"127.0.0.0/8 other", "127.0.0.2", true}, {"127.0.0.0/8 end", "127.255.255.255", true}, // Special addresses {"0.0.0.0/8", "0.0.0.1", true}, {"240.0.0.0/4 reserved", "240.0.0.1", true}, {"255.255.255.255 broadcast", "255.255.255.255", true}, // IPv6 private ranges {"IPv6 loopback", "::1", true}, {"fc00::/7 unique local", "fc00::1", true}, {"fd00::/8 unique local", "fd00::1", true}, {"fe80::/10 link-local", "fe80::1", true}, // Public IPs (should return false) {"Public IPv4 1", "8.8.8.8", false}, {"Public IPv4 2", "1.1.1.1", false}, {"Public IPv4 3", "93.184.216.34", false}, {"Public IPv6", "2001:4860:4860::8888", false}, // Edge cases {"Just outside 172.16", "172.15.255.255", false}, {"Just outside 172.31", "172.32.0.0", false}, {"Just outside 192.168", "192.167.255.255", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ip := net.ParseIP(tt.ip) if ip == nil { t.Fatalf("failed to parse IP: %s", tt.ip) } result := IsPrivateIP(ip) if result != tt.expected { t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected) } }) } } func TestIsPrivateIP_NilIP(t *testing.T) { t.Parallel() // nil IP should return true (block by default for safety) result := IsPrivateIP(nil) if result != true { t.Errorf("IsPrivateIP(nil) = %v, want true", result) } } func TestSafeDialer_BlocksPrivateIPs(t *testing.T) { t.Parallel() tests := []struct { name string address string shouldBlock bool }{ {"blocks 10.x.x.x", "10.0.0.1:80", true}, {"blocks 172.16.x.x", "172.16.0.1:80", true}, {"blocks 192.168.x.x", "192.168.1.1:80", true}, {"blocks 127.0.0.1", "127.0.0.1:80", true}, {"blocks localhost", "localhost:80", true}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn, err := dialer(ctx, "tcp", tt.address) if tt.shouldBlock { if err == nil { conn.Close() t.Errorf("expected connection to %s to be blocked", tt.address) } } }) } } func TestSafeDialer_AllowsLocalhost(t *testing.T) { t.Parallel() // Create a local test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) defer server.Close() // Extract host:port from test server URL addr := server.Listener.Addr().String() opts := &ClientOptions{ AllowLocalhost: true, DialTimeout: 5 * time.Second, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := dialer(ctx, "tcp", addr) if err != nil { t.Errorf("expected connection to localhost to be allowed when allowLocalhost=true, got error: %v", err) return } conn.Close() } func TestSafeDialer_AllowedDomains(t *testing.T) { t.Parallel() opts := &ClientOptions{ AllowLocalhost: false, AllowedDomains: []string{"app.crowdsec.net", "hub.crowdsec.net"}, DialTimeout: time.Second, } dialer := safeDialer(opts) // Test that allowed domain passes validation (we can't actually connect) // This is a structural test - we're verifying the domain check passes ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() // This will fail to connect (no server) but should NOT fail validation _, err := dialer(ctx, "tcp", "app.crowdsec.net:443") if err != nil { // Check it's a connection error, not a validation error if _, ok := err.(*net.OpError); !ok { // Context deadline exceeded is also acceptable (DNS/connection timeout) if err != context.DeadlineExceeded { t.Logf("Got expected error type for allowed domain: %T: %v", err, err) } } } } func TestNewSafeHTTPClient_DefaultOptions(t *testing.T) { t.Parallel() client := NewSafeHTTPClient() if client == nil { t.Fatal("NewSafeHTTPClient() returned nil") } if client.Timeout != 10*time.Second { t.Errorf("expected default timeout of 10s, got %v", client.Timeout) } } func TestNewSafeHTTPClient_WithTimeout(t *testing.T) { t.Parallel() client := NewSafeHTTPClient(WithTimeout(10 * time.Second)) if client == nil { t.Fatal("NewSafeHTTPClient() returned nil") } if client.Timeout != 10*time.Second { t.Errorf("expected timeout of 10s, got %v", client.Timeout) } } func TestNewSafeHTTPClient_WithAllowLocalhost(t *testing.T) { t.Parallel() // Create a local test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) })) defer server.Close() client := NewSafeHTTPClient( WithTimeout(5*time.Second), WithAllowLocalhost(), ) resp, err := client.Get(server.URL) if err != nil { t.Fatalf("expected request to localhost to succeed with allowLocalhost, got: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } } func TestNewSafeHTTPClient_BlocksSSRF(t *testing.T) { client := NewSafeHTTPClient( WithTimeout(2 * time.Second), ) // Test that internal IPs are blocked urls := []string{ "http://127.0.0.1/", "http://10.0.0.1/", "http://172.16.0.1/", "http://192.168.1.1/", "http://localhost/", } for _, url := range urls { t.Run(url, func(t *testing.T) { resp, err := client.Get(url) if err == nil { defer resp.Body.Close() t.Errorf("expected request to %s to be blocked", url) } }) } } func TestNewSafeHTTPClient_WithMaxRedirects(t *testing.T) { redirectCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { redirectCount++ if redirectCount < 5 { http.Redirect(w, r, "/redirect", http.StatusFound) return } w.WriteHeader(http.StatusOK) })) defer server.Close() client := NewSafeHTTPClient( WithTimeout(5*time.Second), WithAllowLocalhost(), WithMaxRedirects(2), ) resp, err := client.Get(server.URL) if err == nil { defer resp.Body.Close() t.Error("expected redirect limit to be enforced") } } func TestNewSafeHTTPClient_WithAllowedDomains(t *testing.T) { client := NewSafeHTTPClient( WithTimeout(2*time.Second), WithAllowedDomains("example.com"), ) if client == nil { t.Fatal("NewSafeHTTPClient() returned nil") } // We can't actually connect, but we verify the client is created // with the correct configuration } func TestClientOptions_Defaults(t *testing.T) { opts := defaultOptions() if opts.Timeout != 10*time.Second { t.Errorf("expected default timeout 10s, got %v", opts.Timeout) } if opts.MaxRedirects != 0 { t.Errorf("expected default maxRedirects 0, got %d", opts.MaxRedirects) } if opts.DialTimeout != 5*time.Second { t.Errorf("expected default dialTimeout 5s, got %v", opts.DialTimeout) } } func TestWithDialTimeout(t *testing.T) { client := NewSafeHTTPClient(WithDialTimeout(5 * time.Second)) if client == nil { t.Fatal("NewSafeHTTPClient() returned nil") } } // Benchmark tests func BenchmarkIsPrivateIP_IPv4Private(b *testing.B) { ip := net.ParseIP("192.168.1.1") b.ResetTimer() for i := 0; i < b.N; i++ { IsPrivateIP(ip) } } func BenchmarkIsPrivateIP_IPv4Public(b *testing.B) { ip := net.ParseIP("8.8.8.8") b.ResetTimer() for i := 0; i < b.N; i++ { IsPrivateIP(ip) } } func BenchmarkIsPrivateIP_IPv6(b *testing.B) { ip := net.ParseIP("2001:4860:4860::8888") b.ResetTimer() for i := 0; i < b.N; i++ { IsPrivateIP(ip) } } func BenchmarkNewSafeHTTPClient(b *testing.B) { for i := 0; i < b.N; i++ { NewSafeHTTPClient( WithTimeout(10*time.Second), WithAllowLocalhost(), ) } } // Additional tests to increase coverage func TestSafeDialer_InvalidAddress(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // Test invalid address format (no port) _, err := dialer(ctx, "tcp", "invalid-address-no-port") if err == nil { t.Error("expected error for invalid address format") } } func TestSafeDialer_LoopbackIPv6(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: true, DialTimeout: time.Second, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // Test IPv6 loopback with AllowLocalhost _, err := dialer(ctx, "tcp", "[::1]:80") // Should fail to connect but not due to validation if err != nil { t.Logf("Expected connection error (not validation): %v", err) } } func TestValidateRedirectTarget_EmptyHostname(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } // Create request with empty hostname req, _ := http.NewRequest("GET", "http:///path", http.NoBody) err := validateRedirectTarget(req, opts) if err == nil { t.Error("expected error for empty hostname") } } func TestValidateRedirectTarget_Localhost(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } // Test localhost blocked req, _ := http.NewRequest("GET", "http://localhost/path", http.NoBody) err := validateRedirectTarget(req, opts) if err == nil { t.Error("expected error for localhost when AllowLocalhost=false") } // Test localhost allowed opts.AllowLocalhost = true err = validateRedirectTarget(req, opts) if err != nil { t.Errorf("expected no error for localhost when AllowLocalhost=true, got: %v", err) } } func TestValidateRedirectTarget_127(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } req, _ := http.NewRequest("GET", "http://127.0.0.1/path", http.NoBody) err := validateRedirectTarget(req, opts) if err == nil { t.Error("expected error for 127.0.0.1 when AllowLocalhost=false") } opts.AllowLocalhost = true err = validateRedirectTarget(req, opts) if err != nil { t.Errorf("expected no error for 127.0.0.1 when AllowLocalhost=true, got: %v", err) } } func TestValidateRedirectTarget_IPv6Loopback(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } req, _ := http.NewRequest("GET", "http://[::1]/path", http.NoBody) err := validateRedirectTarget(req, opts) if err == nil { t.Error("expected error for ::1 when AllowLocalhost=false") } opts.AllowLocalhost = true err = validateRedirectTarget(req, opts) if err != nil { t.Errorf("expected no error for ::1 when AllowLocalhost=true, got: %v", err) } } func TestNewSafeHTTPClient_NoRedirectsByDefault(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" { http.Redirect(w, r, "/redirected", http.StatusFound) return } w.WriteHeader(http.StatusOK) })) defer server.Close() client := NewSafeHTTPClient( WithTimeout(5*time.Second), WithAllowLocalhost(), ) resp, err := client.Get(server.URL) if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() // Should not follow redirect - should return 302 if resp.StatusCode != http.StatusFound { t.Errorf("expected status 302 (redirect not followed), got %d", resp.StatusCode) } } func TestIsPrivateIP_IPv4MappedIPv6(t *testing.T) { // Test IPv4-mapped IPv6 addresses tests := []struct { name string ip string expected bool }{ {"IPv4-mapped private", "::ffff:192.168.1.1", true}, {"IPv4-mapped public", "::ffff:8.8.8.8", false}, {"IPv4-mapped loopback", "::ffff:127.0.0.1", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ip := net.ParseIP(tt.ip) if ip == nil { t.Fatalf("failed to parse IP: %s", tt.ip) } result := IsPrivateIP(ip) if result != tt.expected { t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected) } }) } } func TestIsPrivateIP_Multicast(t *testing.T) { // Test multicast addresses tests := []struct { name string ip string expected bool }{ {"IPv4 multicast", "224.0.0.1", true}, {"IPv6 multicast", "ff02::1", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ip := net.ParseIP(tt.ip) if ip == nil { t.Fatalf("failed to parse IP: %s", tt.ip) } result := IsPrivateIP(ip) if result != tt.expected { t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected) } }) } } func TestIsPrivateIP_Unspecified(t *testing.T) { // Test unspecified addresses tests := []struct { name string ip string expected bool }{ {"IPv4 unspecified", "0.0.0.0", true}, {"IPv6 unspecified", "::", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ip := net.ParseIP(tt.ip) if ip == nil { t.Fatalf("failed to parse IP: %s", tt.ip) } result := IsPrivateIP(ip) if result != tt.expected { t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected) } }) } } // Phase 1 Coverage Improvement Tests func TestValidateRedirectTarget_DNSFailure(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: 100 * time.Millisecond, // Short timeout to force DNS failure quickly } // Use a domain that will fail DNS resolution req, _ := http.NewRequest("GET", "http://this-domain-does-not-exist-12345.invalid/path", http.NoBody) err := validateRedirectTarget(req, opts) if err == nil { t.Error("expected error for DNS resolution failure") } // Verify the error is DNS-related if err != nil && !contains(err.Error(), "DNS resolution failed") { t.Errorf("expected DNS resolution failure error, got: %v", err) } } func TestValidateRedirectTarget_PrivateIPInRedirect(t *testing.T) { // Test that redirects to private IPs are properly blocked opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } // Test various private IP redirect scenarios privateHosts := []string{ "http://10.0.0.1/path", "http://172.16.0.1/path", "http://192.168.1.1/path", "http://169.254.169.254/latest/meta-data/", // AWS metadata endpoint } for _, url := range privateHosts { t.Run(url, func(t *testing.T) { req, _ := http.NewRequest("GET", url, http.NoBody) err := validateRedirectTarget(req, opts) if err == nil { t.Errorf("expected error for redirect to private IP: %s", url) } }) } } func TestSafeDialer_AllIPsPrivate(t *testing.T) { // Test that when all resolved IPs are private, the connection is blocked opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // Test dialing addresses that resolve to private IPs privateAddresses := []string{ "10.0.0.1:80", "172.16.0.1:443", "192.168.0.1:8080", "169.254.169.254:80", // Cloud metadata endpoint } for _, addr := range privateAddresses { t.Run(addr, func(t *testing.T) { conn, err := dialer(ctx, "tcp", addr) if err == nil { conn.Close() t.Errorf("expected connection to %s to be blocked (all IPs private)", addr) } }) } } func TestNewSafeHTTPClient_RedirectToPrivateIP(t *testing.T) { // Create a server that redirects to a private IP server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" { // Redirect to a private IP (will be blocked) http.Redirect(w, r, "http://192.168.1.1/internal", http.StatusFound) return } w.WriteHeader(http.StatusOK) })) defer server.Close() // Client with redirects enabled and localhost allowed for the test server client := NewSafeHTTPClient( WithTimeout(5*time.Second), WithAllowLocalhost(), WithMaxRedirects(3), ) // Make request - should fail when trying to follow redirect to private IP resp, err := client.Get(server.URL) if err == nil { defer resp.Body.Close() t.Error("expected error when redirect targets private IP") } } func TestSafeDialer_DNSResolutionFailure(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: 100 * time.Millisecond, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() // Use a domain that will fail DNS resolution _, err := dialer(ctx, "tcp", "nonexistent-domain-xyz123.invalid:80") if err == nil { t.Error("expected error for DNS resolution failure") } if err != nil && !contains(err.Error(), "DNS resolution failed") { t.Errorf("expected DNS resolution failure error, got: %v", err) } } func TestSafeDialer_NoIPsReturned(t *testing.T) { // This tests the edge case where DNS returns no IP addresses // In practice this is rare, but we need to handle it opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // This domain should fail DNS resolution _, err := dialer(ctx, "tcp", "empty-dns-result-test.invalid:80") if err == nil { t.Error("expected error when DNS returns no IPs") } } func TestNewSafeHTTPClient_TooManyRedirects(t *testing.T) { redirectCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { redirectCount++ // Keep redirecting to itself http.Redirect(w, r, "/redirect"+string(rune('0'+redirectCount)), http.StatusFound) })) defer server.Close() client := NewSafeHTTPClient( WithTimeout(5*time.Second), WithAllowLocalhost(), WithMaxRedirects(3), ) resp, err := client.Get(server.URL) if resp != nil { resp.Body.Close() } if err == nil { t.Error("expected error for too many redirects") } if err != nil && !contains(err.Error(), "too many redirects") { t.Logf("Got redirect error: %v", err) } } func TestValidateRedirectTarget_AllowedLocalhost(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: true, DialTimeout: time.Second, } // Test that localhost is allowed when AllowLocalhost is true localhostURLs := []string{ "http://localhost/path", "http://127.0.0.1/path", "http://[::1]/path", } for _, url := range localhostURLs { t.Run(url, func(t *testing.T) { req, _ := http.NewRequest("GET", url, http.NoBody) err := validateRedirectTarget(req, opts) if err != nil { t.Errorf("expected no error for %s when AllowLocalhost=true, got: %v", url, err) } }) } } func TestNewSafeHTTPClient_MetadataEndpoint(t *testing.T) { // Test that cloud metadata endpoints are blocked client := NewSafeHTTPClient( WithTimeout(2 * time.Second), ) // AWS metadata endpoint resp, err := client.Get("http://169.254.169.254/latest/meta-data/") if resp != nil { defer resp.Body.Close() } if err == nil { t.Error("expected cloud metadata endpoint to be blocked") } } func TestSafeDialer_IPv4MappedIPv6(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: time.Second, } dialer := safeDialer(opts) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // Test IPv6-formatted localhost _, err := dialer(ctx, "tcp", "[::ffff:127.0.0.1]:80") if err == nil { t.Error("expected IPv4-mapped IPv6 loopback to be blocked") } } func TestClientOptions_AllFunctionalOptions(t *testing.T) { // Test all functional options together client := NewSafeHTTPClient( WithTimeout(15*time.Second), WithAllowLocalhost(), WithAllowedDomains("example.com", "api.example.com"), WithMaxRedirects(5), WithDialTimeout(3*time.Second), ) if client == nil { t.Fatal("NewSafeHTTPClient() returned nil with all options") } if client.Timeout != 15*time.Second { t.Errorf("expected timeout of 15s, got %v", client.Timeout) } } func TestSafeDialer_ContextCancelled(t *testing.T) { opts := &ClientOptions{ AllowLocalhost: false, DialTimeout: 5 * time.Second, } dialer := safeDialer(opts) // Create an already-cancelled context ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately _, err := dialer(ctx, "tcp", "example.com:80") if err == nil { t.Error("expected error for cancelled context") } } func TestNewSafeHTTPClient_RedirectValidation(t *testing.T) { // Server that redirects to itself (valid redirect) callCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ if callCount == 1 { http.Redirect(w, r, "/final", http.StatusFound) return } w.WriteHeader(http.StatusOK) w.Write([]byte("success")) })) defer server.Close() client := NewSafeHTTPClient( WithTimeout(5*time.Second), WithAllowLocalhost(), WithMaxRedirects(2), ) resp, err := client.Get(server.URL) if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } } // Helper function for error message checking func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || s != "" && containsSubstr(s, substr)) } func containsSubstr(s, substr string) bool { for i := 0; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { return true } } return false }