chore: clean .gitignore cache

This commit is contained in:
GitHub Actions
2026-01-26 19:21:33 +00:00
parent 1b1b3a70b1
commit e5f0fec5db
1483 changed files with 0 additions and 472793 deletions

View File

@@ -1,34 +0,0 @@
package network
import (
"net/http"
"time"
)
// NewInternalServiceHTTPClient returns an HTTP client intended for internal service calls
// that are already constrained by an explicit hostname allowlist + expected port policy.
//
// Security posture:
// - Ignores proxy environment variables.
// - Disables redirects.
// - Uses strict, caller-provided timeouts.
func NewInternalServiceHTTPClient(timeout time.Duration) *http.Client {
transport := &http.Transport{
// Explicitly ignore proxy environment variables for SSRF-sensitive requests.
Proxy: nil,
DisableKeepAlives: true,
MaxIdleConns: 1,
IdleConnTimeout: timeout,
TLSHandshakeTimeout: timeout,
ResponseHeaderTimeout: timeout,
}
return &http.Client{
Timeout: timeout,
Transport: transport,
// Explicit redirect policy per call site: disable.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}

View File

@@ -1,267 +0,0 @@
package network
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewInternalServiceHTTPClient(t *testing.T) {
t.Parallel()
tests := []struct {
name string
timeout time.Duration
}{
{"with 1 second timeout", 1 * time.Second},
{"with 5 second timeout", 5 * time.Second},
{"with 30 second timeout", 30 * time.Second},
{"with 100ms timeout", 100 * time.Millisecond},
{"with zero timeout", 0},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := NewInternalServiceHTTPClient(tt.timeout)
if client == nil {
t.Fatal("NewInternalServiceHTTPClient() returned nil")
}
if client.Timeout != tt.timeout {
t.Errorf("expected timeout %v, got %v", tt.timeout, client.Timeout)
}
})
}
}
func TestNewInternalServiceHTTPClient_TransportConfiguration(t *testing.T) {
t.Parallel()
timeout := 5 * time.Second
client := NewInternalServiceHTTPClient(timeout)
if client.Transport == nil {
t.Fatal("expected Transport to be set")
}
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatal("expected Transport to be *http.Transport")
}
// Verify proxy is nil (ignores proxy environment variables)
if transport.Proxy != nil {
t.Error("expected Proxy to be nil for SSRF protection")
}
// Verify keep-alives are disabled
if !transport.DisableKeepAlives {
t.Error("expected DisableKeepAlives to be true")
}
// Verify MaxIdleConns
if transport.MaxIdleConns != 1 {
t.Errorf("expected MaxIdleConns to be 1, got %d", transport.MaxIdleConns)
}
// Verify timeout settings
if transport.IdleConnTimeout != timeout {
t.Errorf("expected IdleConnTimeout %v, got %v", timeout, transport.IdleConnTimeout)
}
if transport.TLSHandshakeTimeout != timeout {
t.Errorf("expected TLSHandshakeTimeout %v, got %v", timeout, transport.TLSHandshakeTimeout)
}
if transport.ResponseHeaderTimeout != timeout {
t.Errorf("expected ResponseHeaderTimeout %v, got %v", timeout, transport.ResponseHeaderTimeout)
}
}
func TestNewInternalServiceHTTPClient_RedirectsDisabled(t *testing.T) {
t.Parallel()
// Create a test server that redirects
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
if r.URL.Path == "/" {
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("redirected"))
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer func() { _ = resp.Body.Close() }()
// Should receive the redirect response, not follow it
if resp.StatusCode != http.StatusFound {
t.Errorf("expected status %d (redirect not followed), got %d", http.StatusFound, resp.StatusCode)
}
// Verify only one request was made (redirect was not followed)
if redirectCount != 1 {
t.Errorf("expected exactly 1 request, got %d (redirect was followed)", redirectCount)
}
}
func TestNewInternalServiceHTTPClient_CheckRedirectReturnsErrUseLastResponse(t *testing.T) {
t.Parallel()
client := NewInternalServiceHTTPClient(5 * time.Second)
if client.CheckRedirect == nil {
t.Fatal("expected CheckRedirect to be set")
}
// Create a dummy request to test CheckRedirect
req, _ := http.NewRequest("GET", "http://example.com", http.NoBody)
err := client.CheckRedirect(req, nil)
if err != http.ErrUseLastResponse {
t.Errorf("expected CheckRedirect to return http.ErrUseLastResponse, got %v", err)
}
}
func TestNewInternalServiceHTTPClient_ActualRequest(t *testing.T) {
t.Parallel()
// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestNewInternalServiceHTTPClient_TimeoutEnforced(t *testing.T) {
t.Parallel()
// Create a slow server that delays longer than the timeout
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(500 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Use a very short timeout
client := NewInternalServiceHTTPClient(100 * time.Millisecond)
resp, err := client.Get(server.URL)
if resp != nil {
_ = resp.Body.Close()
}
if err == nil {
t.Error("expected timeout error, got nil")
}
}
func TestNewInternalServiceHTTPClient_MultipleClients(t *testing.T) {
t.Parallel()
// Verify that multiple clients can be created with different timeouts
client1 := NewInternalServiceHTTPClient(1 * time.Second)
client2 := NewInternalServiceHTTPClient(10 * time.Second)
if client1 == client2 {
t.Error("expected different client instances")
}
if client1.Timeout != 1*time.Second {
t.Errorf("client1 expected timeout 1s, got %v", client1.Timeout)
}
if client2.Timeout != 10*time.Second {
t.Errorf("client2 expected timeout 10s, got %v", client2.Timeout)
}
}
func TestNewInternalServiceHTTPClient_ProxyIgnored(t *testing.T) {
t.Parallel()
// Set up a server to verify no proxy is used
directServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("direct"))
}))
defer directServer.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
// Even if environment has proxy settings, this client should ignore them
// because transport.Proxy is set to nil
transport := client.Transport.(*http.Transport)
if transport.Proxy != nil {
t.Error("expected Proxy to be nil (proxy env vars should be ignored)")
}
resp, err := client.Get(directServer.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestNewInternalServiceHTTPClient_PostRequest(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
}
w.WriteHeader(http.StatusCreated)
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
resp, err := client.Post(server.URL, "application/json", nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusCreated {
t.Errorf("expected status 201, got %d", resp.StatusCode)
}
}
// Benchmark tests
func BenchmarkNewInternalServiceHTTPClient(b *testing.B) {
for i := 0; i < b.N; i++ {
NewInternalServiceHTTPClient(5 * time.Second)
}
}
func BenchmarkNewInternalServiceHTTPClient_Request(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := client.Get(server.URL)
if err == nil {
_ = resp.Body.Close()
}
}
}

View File

@@ -1,253 +0,0 @@
package network
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewInternalServiceHTTPClient(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
}{
{"with 1 second timeout", 1 * time.Second},
{"with 5 second timeout", 5 * time.Second},
{"with 30 second timeout", 30 * time.Second},
{"with 100ms timeout", 100 * time.Millisecond},
{"with zero timeout", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewInternalServiceHTTPClient(tt.timeout)
if client == nil {
t.Fatal("NewInternalServiceHTTPClient() returned nil")
}
if client.Timeout != tt.timeout {
t.Errorf("expected timeout %v, got %v", tt.timeout, client.Timeout)
}
})
}
}
func TestNewInternalServiceHTTPClient_TransportConfiguration(t *testing.T) {
timeout := 5 * time.Second
client := NewInternalServiceHTTPClient(timeout)
if client.Transport == nil {
t.Fatal("expected Transport to be set")
}
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatal("expected Transport to be *http.Transport")
}
// Verify proxy is nil (ignores proxy environment variables)
if transport.Proxy != nil {
t.Error("expected Proxy to be nil for SSRF protection")
}
// Verify keep-alives are disabled
if !transport.DisableKeepAlives {
t.Error("expected DisableKeepAlives to be true")
}
// Verify MaxIdleConns
if transport.MaxIdleConns != 1 {
t.Errorf("expected MaxIdleConns to be 1, got %d", transport.MaxIdleConns)
}
// Verify timeout settings
if transport.IdleConnTimeout != timeout {
t.Errorf("expected IdleConnTimeout %v, got %v", timeout, transport.IdleConnTimeout)
}
if transport.TLSHandshakeTimeout != timeout {
t.Errorf("expected TLSHandshakeTimeout %v, got %v", timeout, transport.TLSHandshakeTimeout)
}
if transport.ResponseHeaderTimeout != timeout {
t.Errorf("expected ResponseHeaderTimeout %v, got %v", timeout, transport.ResponseHeaderTimeout)
}
}
func TestNewInternalServiceHTTPClient_RedirectsDisabled(t *testing.T) {
// Create a test server that redirects
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
if r.URL.Path == "/" {
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("redirected"))
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
// Should receive the redirect response, not follow it
if resp.StatusCode != http.StatusFound {
t.Errorf("expected status %d (redirect not followed), got %d", http.StatusFound, resp.StatusCode)
}
// Verify only one request was made (redirect was not followed)
if redirectCount != 1 {
t.Errorf("expected exactly 1 request, got %d (redirect was followed)", redirectCount)
}
}
func TestNewInternalServiceHTTPClient_CheckRedirectReturnsErrUseLastResponse(t *testing.T) {
client := NewInternalServiceHTTPClient(5 * time.Second)
if client.CheckRedirect == nil {
t.Fatal("expected CheckRedirect to be set")
}
// Create a dummy request to test CheckRedirect
req, _ := http.NewRequest("GET", "http://example.com", http.NoBody)
err := client.CheckRedirect(req, nil)
if err != http.ErrUseLastResponse {
t.Errorf("expected CheckRedirect to return http.ErrUseLastResponse, got %v", err)
}
}
func TestNewInternalServiceHTTPClient_ActualRequest(t *testing.T) {
// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestNewInternalServiceHTTPClient_TimeoutEnforced(t *testing.T) {
// Create a slow server that delays longer than the timeout
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(500 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Use a very short timeout
client := NewInternalServiceHTTPClient(100 * time.Millisecond)
_, err := client.Get(server.URL)
if err == nil {
t.Error("expected timeout error, got nil")
}
}
func TestNewInternalServiceHTTPClient_MultipleClients(t *testing.T) {
// Verify that multiple clients can be created with different timeouts
client1 := NewInternalServiceHTTPClient(1 * time.Second)
client2 := NewInternalServiceHTTPClient(10 * time.Second)
if client1 == client2 {
t.Error("expected different client instances")
}
if client1.Timeout != 1*time.Second {
t.Errorf("client1 expected timeout 1s, got %v", client1.Timeout)
}
if client2.Timeout != 10*time.Second {
t.Errorf("client2 expected timeout 10s, got %v", client2.Timeout)
}
}
func TestNewInternalServiceHTTPClient_ProxyIgnored(t *testing.T) {
// Set up a server to verify no proxy is used
directServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("direct"))
}))
defer directServer.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
// Even if environment has proxy settings, this client should ignore them
// because transport.Proxy is set to nil
transport := client.Transport.(*http.Transport)
if transport.Proxy != nil {
t.Error("expected Proxy to be nil (proxy env vars should be ignored)")
}
resp, err := client.Get(directServer.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestNewInternalServiceHTTPClient_PostRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
}
w.WriteHeader(http.StatusCreated)
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
resp, err := client.Post(server.URL, "application/json", nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Errorf("expected status 201, got %d", resp.StatusCode)
}
}
// Benchmark tests
func BenchmarkNewInternalServiceHTTPClient(b *testing.B) {
for i := 0; i < b.N; i++ {
NewInternalServiceHTTPClient(5 * time.Second)
}
}
func BenchmarkNewInternalServiceHTTPClient_Request(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewInternalServiceHTTPClient(5 * time.Second)
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := client.Get(server.URL)
if err == nil {
resp.Body.Close()
}
}
}

View File

@@ -1,353 +0,0 @@
// Package network provides SSRF-safe HTTP client and networking utilities.
// This package implements comprehensive Server-Side Request Forgery (SSRF) protection
// by validating IP addresses at multiple layers: URL validation, DNS resolution, and connection time.
package network
import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"
)
// privateBlocks holds pre-parsed CIDR blocks for private/reserved IP ranges.
// These are parsed once at package initialization for performance.
var (
privateBlocks []*net.IPNet
initOnce sync.Once
)
// privateCIDRs defines all private and reserved IP ranges to block for SSRF protection.
// This list covers:
// - RFC 1918 private networks (10.x, 172.16-31.x, 192.168.x)
// - Loopback addresses (127.x.x.x, ::1)
// - Link-local addresses (169.254.x.x, fe80::) including cloud metadata endpoints
// - Reserved ranges (0.x.x.x, 240.x.x.x, 255.255.255.255)
// - IPv6 unique local addresses (fc00::)
var privateCIDRs = []string{
// IPv4 Private Networks (RFC 1918)
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
// IPv4 Link-Local (RFC 3927) - includes AWS/GCP/Azure metadata service (169.254.169.254)
"169.254.0.0/16",
// IPv4 Loopback
"127.0.0.0/8",
// IPv4 Reserved ranges
"0.0.0.0/8", // "This network"
"240.0.0.0/4", // Reserved for future use
"255.255.255.255/32", // Broadcast
// IPv6 Loopback
"::1/128",
// IPv6 Unique Local Addresses (RFC 4193)
"fc00::/7",
// IPv6 Link-Local
"fe80::/10",
}
// initPrivateBlocks parses all CIDR blocks once at startup.
func initPrivateBlocks() {
initOnce.Do(func() {
privateBlocks = make([]*net.IPNet, 0, len(privateCIDRs))
for _, cidr := range privateCIDRs {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
// This should never happen with valid CIDR strings
continue
}
privateBlocks = append(privateBlocks, block)
}
})
}
// IsPrivateIP checks if an IP address is private, loopback, link-local, or otherwise restricted.
// This function implements comprehensive SSRF protection by blocking:
// - Private IPv4 ranges (RFC 1918): 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
// - Loopback addresses: 127.0.0.0/8, ::1/128
// - Link-local addresses: 169.254.0.0/16, fe80::/10 (includes cloud metadata endpoints)
// - Reserved ranges: 0.0.0.0/8, 240.0.0.0/4, 255.255.255.255/32
// - IPv6 unique local addresses: fc00::/7
//
// IPv4-mapped IPv6 addresses (::ffff:x.x.x.x) are correctly handled by extracting
// the IPv4 portion and validating it.
//
// Returns true if the IP should be blocked, false if it's safe for external requests.
func IsPrivateIP(ip net.IP) bool {
if ip == nil {
return true // nil IPs should be blocked
}
// Ensure private blocks are initialized
initPrivateBlocks()
// Handle IPv4-mapped IPv6 addresses (::ffff:x.x.x.x)
// Convert to IPv4 for consistent checking
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
// Check built-in Go functions for common cases (fast path)
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
ip.IsMulticast() || ip.IsUnspecified() {
return true
}
// Check against all private/reserved CIDR blocks
for _, block := range privateBlocks {
if block.Contains(ip) {
return true
}
}
return false
}
// ClientOptions configures the behavior of the safe HTTP client.
type ClientOptions struct {
// Timeout is the total request timeout (default: 10s)
Timeout time.Duration
// AllowLocalhost permits connections to localhost/127.0.0.1 (default: false)
// Use only for testing or when connecting to known-safe local services.
AllowLocalhost bool
// AllowedDomains restricts requests to specific domains (optional).
// If set, only these domains will be allowed (in addition to localhost if AllowLocalhost is true).
AllowedDomains []string
// MaxRedirects sets the maximum number of redirects to follow (default: 0)
// Set to 0 to disable redirects entirely.
MaxRedirects int
// DialTimeout is the connection timeout for individual dial attempts (default: 5s)
DialTimeout time.Duration
}
// Option is a functional option for configuring ClientOptions.
type Option func(*ClientOptions)
// defaultOptions returns the default safe client configuration.
func defaultOptions() ClientOptions {
return ClientOptions{
Timeout: 10 * time.Second,
AllowLocalhost: false,
AllowedDomains: nil,
MaxRedirects: 0,
DialTimeout: 5 * time.Second,
}
}
// WithTimeout sets the total request timeout.
func WithTimeout(timeout time.Duration) Option {
return func(opts *ClientOptions) {
opts.Timeout = timeout
}
}
// WithAllowLocalhost permits connections to localhost addresses.
// Use this option only when connecting to known-safe local services (e.g., CrowdSec LAPI).
func WithAllowLocalhost() Option {
return func(opts *ClientOptions) {
opts.AllowLocalhost = true
}
}
// WithAllowedDomains restricts requests to specific domains.
// When set, only requests to these domains will be permitted.
func WithAllowedDomains(domains ...string) Option {
return func(opts *ClientOptions) {
opts.AllowedDomains = append(opts.AllowedDomains, domains...)
}
}
// WithMaxRedirects sets the maximum number of redirects to follow.
// Default is 0 (no redirects). Each redirect target is validated against SSRF rules.
func WithMaxRedirects(maxRedirects int) Option {
return func(opts *ClientOptions) {
opts.MaxRedirects = maxRedirects
}
}
// WithDialTimeout sets the connection timeout for individual dial attempts.
func WithDialTimeout(timeout time.Duration) Option {
return func(opts *ClientOptions) {
opts.DialTimeout = timeout
}
}
// safeDialer creates a custom dial function that validates IP addresses at connection time.
// This prevents DNS rebinding attacks by:
// 1. Resolving the hostname to IP addresses
// 2. Validating ALL resolved IPs against IsPrivateIP
// 3. Connecting directly to the validated IP (not the hostname)
//
// This approach defeats Time-of-Check to Time-of-Use (TOCTOU) attacks where
// DNS could return different IPs between validation and connection.
func safeDialer(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
// Parse host:port from address
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("invalid address format: %w", err)
}
// Check if this is an allowed localhost address
isLocalhost := host == "localhost" || host == "127.0.0.1" || host == "::1"
if isLocalhost && opts.AllowLocalhost {
// Allow localhost connections when explicitly permitted
dialer := &net.Dialer{Timeout: opts.DialTimeout}
return dialer.DialContext(ctx, network, addr)
}
// Resolve DNS with context timeout
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("DNS resolution failed for %s: %w", host, err)
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IP addresses found for host: %s", host)
}
// Validate ALL resolved IPs - if ANY are private, reject the entire request
// This prevents attackers from using DNS load balancing to mix private/public IPs
for _, ip := range ips {
// Allow localhost IPs if AllowLocalhost is set
if opts.AllowLocalhost && ip.IP.IsLoopback() {
continue
}
if IsPrivateIP(ip.IP) {
return nil, fmt.Errorf("connection to private IP blocked: %s resolved to %s", host, ip.IP)
}
}
// Find first valid IP to connect to
var selectedIP net.IP
for _, ip := range ips {
if opts.AllowLocalhost && ip.IP.IsLoopback() {
selectedIP = ip.IP
break
}
if !IsPrivateIP(ip.IP) {
selectedIP = ip.IP
break
}
}
if selectedIP == nil {
return nil, fmt.Errorf("no valid IP addresses found for host: %s", host)
}
// Connect to the validated IP (prevents DNS rebinding TOCTOU attacks)
dialer := &net.Dialer{Timeout: opts.DialTimeout}
return dialer.DialContext(ctx, network, net.JoinHostPort(selectedIP.String(), port))
}
}
// validateRedirectTarget checks if a redirect URL is safe to follow.
// Returns an error if the redirect target resolves to private IPs.
func validateRedirectTarget(req *http.Request, opts *ClientOptions) error {
host := req.URL.Hostname()
if host == "" {
return fmt.Errorf("missing hostname in redirect URL")
}
// Check localhost
if host == "localhost" || host == "127.0.0.1" || host == "::1" {
if opts.AllowLocalhost {
return nil
}
return fmt.Errorf("redirect to localhost blocked")
}
// Resolve and validate IPs
ctx, cancel := context.WithTimeout(context.Background(), opts.DialTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("DNS resolution failed for redirect target %s: %w", host, err)
}
for _, ip := range ips {
if opts.AllowLocalhost && ip.IP.IsLoopback() {
continue
}
if IsPrivateIP(ip.IP) {
return fmt.Errorf("redirect to private IP blocked: %s resolved to %s", host, ip.IP)
}
}
return nil
}
// NewSafeHTTPClient creates an HTTP client with comprehensive SSRF protection.
// The client validates IP addresses at connection time to prevent:
// - Direct connections to private/reserved IP ranges
// - DNS rebinding attacks (TOCTOU)
// - Redirects to private IP addresses
// - Cloud metadata endpoint access (169.254.169.254)
//
// Default configuration:
// - 10 second timeout
// - No redirects (returns http.ErrUseLastResponse)
// - Keep-alives disabled
// - Private IPs blocked
//
// Use functional options to customize behavior:
//
// // Allow localhost for local service communication
// client := network.NewSafeHTTPClient(network.WithAllowLocalhost())
//
// // Set custom timeout
// client := network.NewSafeHTTPClient(network.WithTimeout(5 * time.Second))
//
// // Allow specific redirects
// client := network.NewSafeHTTPClient(network.WithMaxRedirects(2))
func NewSafeHTTPClient(opts ...Option) *http.Client {
cfg := defaultOptions()
for _, opt := range opts {
opt(&cfg)
}
return &http.Client{
Timeout: cfg.Timeout,
Transport: &http.Transport{
// Explicitly ignore proxy environment variables for SSRF-sensitive requests.
Proxy: nil,
DialContext: safeDialer(&cfg),
DisableKeepAlives: true,
MaxIdleConns: 1,
IdleConnTimeout: cfg.Timeout,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: cfg.Timeout,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// No redirects allowed by default
if cfg.MaxRedirects == 0 {
return http.ErrUseLastResponse
}
// Check redirect count
if len(via) >= cfg.MaxRedirects {
return fmt.Errorf("too many redirects (max %d)", cfg.MaxRedirects)
}
// Validate redirect target for SSRF
if err := validateRedirectTarget(req, &cfg); err != nil {
return err
}
return nil
},
}
}

View File

@@ -1,922 +0,0 @@
package network
import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestIsPrivateIP(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ip string
expected bool
}{
// Private IPv4 ranges
{"10.0.0.0/8 start", "10.0.0.1", true},
{"10.0.0.0/8 middle", "10.255.255.255", true},
{"172.16.0.0/12 start", "172.16.0.1", true},
{"172.16.0.0/12 end", "172.31.255.255", true},
{"192.168.0.0/16 start", "192.168.0.1", true},
{"192.168.0.0/16 end", "192.168.255.255", true},
// Link-local
{"169.254.0.0/16 start", "169.254.0.1", true},
{"169.254.0.0/16 end", "169.254.255.255", true},
// Loopback
{"127.0.0.0/8 localhost", "127.0.0.1", true},
{"127.0.0.0/8 other", "127.0.0.2", true},
{"127.0.0.0/8 end", "127.255.255.255", true},
// Special addresses
{"0.0.0.0/8", "0.0.0.1", true},
{"240.0.0.0/4 reserved", "240.0.0.1", true},
{"255.255.255.255 broadcast", "255.255.255.255", true},
// IPv6 private ranges
{"IPv6 loopback", "::1", true},
{"fc00::/7 unique local", "fc00::1", true},
{"fd00::/8 unique local", "fd00::1", true},
{"fe80::/10 link-local", "fe80::1", true},
// Public IPs (should return false)
{"Public IPv4 1", "8.8.8.8", false},
{"Public IPv4 2", "1.1.1.1", false},
{"Public IPv4 3", "93.184.216.34", false},
{"Public IPv6", "2001:4860:4860::8888", false},
// Edge cases
{"Just outside 172.16", "172.15.255.255", false},
{"Just outside 172.31", "172.32.0.0", false},
{"Just outside 192.168", "192.167.255.255", false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
func TestIsPrivateIP_NilIP(t *testing.T) {
t.Parallel()
// nil IP should return true (block by default for safety)
result := IsPrivateIP(nil)
if result != true {
t.Errorf("IsPrivateIP(nil) = %v, want true", result)
}
}
func TestSafeDialer_BlocksPrivateIPs(t *testing.T) {
t.Parallel()
tests := []struct {
name string
address string
shouldBlock bool
}{
{"blocks 10.x.x.x", "10.0.0.1:80", true},
{"blocks 172.16.x.x", "172.16.0.1:80", true},
{"blocks 192.168.x.x", "192.168.1.1:80", true},
{"blocks 127.0.0.1", "127.0.0.1:80", true},
{"blocks localhost", "localhost:80", true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := dialer(ctx, "tcp", tt.address)
if tt.shouldBlock {
if err == nil {
_ = conn.Close()
t.Errorf("expected connection to %s to be blocked", tt.address)
}
}
})
}
}
func TestSafeDialer_AllowsLocalhost(t *testing.T) {
t.Parallel()
// Create a local test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Extract host:port from test server URL
addr := server.Listener.Addr().String()
opts := &ClientOptions{
AllowLocalhost: true,
DialTimeout: 5 * time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := dialer(ctx, "tcp", addr)
if err != nil {
t.Errorf("expected connection to localhost to be allowed when allowLocalhost=true, got error: %v", err)
return
}
_ = conn.Close()
}
func TestSafeDialer_AllowedDomains(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
AllowedDomains: []string{"app.crowdsec.net", "hub.crowdsec.net"},
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
// Test that allowed domain passes validation (we can't actually connect)
// This is a structural test - we're verifying the domain check passes
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
// This will fail to connect (no server) but should NOT fail validation
_, err := dialer(ctx, "tcp", "app.crowdsec.net:443")
if err != nil {
// Check it's a connection error, not a validation error
if _, ok := err.(*net.OpError); !ok {
// Context deadline exceeded is also acceptable (DNS/connection timeout)
if err != context.DeadlineExceeded {
t.Logf("Got expected error type for allowed domain: %T: %v", err, err)
}
}
}
}
func TestNewSafeHTTPClient_DefaultOptions(t *testing.T) {
t.Parallel()
client := NewSafeHTTPClient()
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
if client.Timeout != 10*time.Second {
t.Errorf("expected default timeout of 10s, got %v", client.Timeout)
}
}
func TestNewSafeHTTPClient_WithTimeout(t *testing.T) {
t.Parallel()
client := NewSafeHTTPClient(WithTimeout(10 * time.Second))
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
if client.Timeout != 10*time.Second {
t.Errorf("expected timeout of 10s, got %v", client.Timeout)
}
}
func TestNewSafeHTTPClient_WithAllowLocalhost(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
// Create a local test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("expected request to localhost to succeed with allowLocalhost, got: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestNewSafeHTTPClient_BlocksSSRF(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
client := NewSafeHTTPClient(
WithTimeout(2 * time.Second),
)
// Test that internal IPs are blocked
urls := []string{
"http://127.0.0.1/",
"http://10.0.0.1/",
"http://172.16.0.1/",
"http://192.168.1.1/",
"http://localhost/",
}
for _, url := range urls {
t.Run(url, func(t *testing.T) {
t.Parallel()
resp, err := client.Get(url)
if err == nil {
defer func() { _ = resp.Body.Close() }()
t.Errorf("expected request to %s to be blocked", url)
}
})
}
}
func TestNewSafeHTTPClient_WithMaxRedirects(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
if redirectCount < 5 {
http.Redirect(w, r, "/redirect", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(2),
)
resp, err := client.Get(server.URL)
if err == nil {
defer func() { _ = resp.Body.Close() }()
t.Error("expected redirect limit to be enforced")
}
}
func TestNewSafeHTTPClient_WithAllowedDomains(t *testing.T) {
t.Parallel()
client := NewSafeHTTPClient(
WithTimeout(2*time.Second),
WithAllowedDomains("example.com"),
)
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
// We can't actually connect, but we verify the client is created
// with the correct configuration
}
func TestClientOptions_Defaults(t *testing.T) {
t.Parallel()
opts := defaultOptions()
if opts.Timeout != 10*time.Second {
t.Errorf("expected default timeout 10s, got %v", opts.Timeout)
}
if opts.MaxRedirects != 0 {
t.Errorf("expected default maxRedirects 0, got %d", opts.MaxRedirects)
}
if opts.DialTimeout != 5*time.Second {
t.Errorf("expected default dialTimeout 5s, got %v", opts.DialTimeout)
}
}
func TestWithDialTimeout(t *testing.T) {
t.Parallel()
client := NewSafeHTTPClient(WithDialTimeout(5 * time.Second))
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
}
// Benchmark tests
func BenchmarkIsPrivateIP_IPv4Private(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsPrivateIP(ip)
}
}
func BenchmarkIsPrivateIP_IPv4Public(b *testing.B) {
ip := net.ParseIP("8.8.8.8")
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsPrivateIP(ip)
}
}
func BenchmarkIsPrivateIP_IPv6(b *testing.B) {
ip := net.ParseIP("2001:4860:4860::8888")
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsPrivateIP(ip)
}
}
func BenchmarkNewSafeHTTPClient(b *testing.B) {
for i := 0; i < b.N; i++ {
NewSafeHTTPClient(
WithTimeout(10*time.Second),
WithAllowLocalhost(),
)
}
}
// Additional tests to increase coverage
func TestSafeDialer_InvalidAddress(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test invalid address format (no port)
_, err := dialer(ctx, "tcp", "invalid-address-no-port")
if err == nil {
t.Error("expected error for invalid address format")
}
}
func TestSafeDialer_LoopbackIPv6(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: true,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test IPv6 loopback with AllowLocalhost
_, err := dialer(ctx, "tcp", "[::1]:80")
// Should fail to connect but not due to validation
if err != nil {
t.Logf("Expected connection error (not validation): %v", err)
}
}
func TestValidateRedirectTarget_EmptyHostname(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
// Create request with empty hostname
req, _ := http.NewRequest("GET", "http:///path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for empty hostname")
}
}
func TestValidateRedirectTarget_Localhost(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
// Test localhost blocked
req, _ := http.NewRequest("GET", "http://localhost/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for localhost when AllowLocalhost=false")
}
// Test localhost allowed
opts.AllowLocalhost = true
err = validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for localhost when AllowLocalhost=true, got: %v", err)
}
}
func TestValidateRedirectTarget_127(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
req, _ := http.NewRequest("GET", "http://127.0.0.1/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for 127.0.0.1 when AllowLocalhost=false")
}
opts.AllowLocalhost = true
err = validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for 127.0.0.1 when AllowLocalhost=true, got: %v", err)
}
}
func TestValidateRedirectTarget_IPv6Loopback(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
req, _ := http.NewRequest("GET", "http://[::1]/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for ::1 when AllowLocalhost=false")
}
opts.AllowLocalhost = true
err = validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for ::1 when AllowLocalhost=true, got: %v", err)
}
}
func TestNewSafeHTTPClient_NoRedirectsByDefault(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer func() { _ = resp.Body.Close() }()
// Should not follow redirect - should return 302
if resp.StatusCode != http.StatusFound {
t.Errorf("expected status 302 (redirect not followed), got %d", resp.StatusCode)
}
}
func TestIsPrivateIP_IPv4MappedIPv6(t *testing.T) {
t.Parallel()
// Test IPv4-mapped IPv6 addresses
tests := []struct {
name string
ip string
expected bool
}{
{"IPv4-mapped private", "::ffff:192.168.1.1", true},
{"IPv4-mapped public", "::ffff:8.8.8.8", false},
{"IPv4-mapped loopback", "::ffff:127.0.0.1", true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
func TestIsPrivateIP_Multicast(t *testing.T) {
t.Parallel()
// Test multicast addresses
tests := []struct {
name string
ip string
expected bool
}{
{"IPv4 multicast", "224.0.0.1", true},
{"IPv6 multicast", "ff02::1", true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
func TestIsPrivateIP_Unspecified(t *testing.T) {
t.Parallel()
// Test unspecified addresses
tests := []struct {
name string
ip string
expected bool
}{
{"IPv4 unspecified", "0.0.0.0", true},
{"IPv6 unspecified", "::", true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
// Phase 1 Coverage Improvement Tests
func TestValidateRedirectTarget_DNSFailure(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: 100 * time.Millisecond, // Short timeout to force DNS failure quickly
}
// Use a domain that will fail DNS resolution
req, _ := http.NewRequest("GET", "http://this-domain-does-not-exist-12345.invalid/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for DNS resolution failure")
}
// Verify the error is DNS-related
if err != nil && !contains(err.Error(), "DNS resolution failed") {
t.Errorf("expected DNS resolution failure error, got: %v", err)
}
}
func TestValidateRedirectTarget_PrivateIPInRedirect(t *testing.T) {
t.Parallel()
// Test that redirects to private IPs are properly blocked
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
// Test various private IP redirect scenarios
privateHosts := []string{
"http://10.0.0.1/path",
"http://172.16.0.1/path",
"http://192.168.1.1/path",
"http://169.254.169.254/latest/meta-data/", // AWS metadata endpoint
}
for _, url := range privateHosts {
t.Run(url, func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequest("GET", url, http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Errorf("expected error for redirect to private IP: %s", url)
}
})
}
}
func TestSafeDialer_AllIPsPrivate(t *testing.T) {
t.Parallel()
// Test that when all resolved IPs are private, the connection is blocked
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test dialing addresses that resolve to private IPs
privateAddresses := []string{
"10.0.0.1:80",
"172.16.0.1:443",
"192.168.0.1:8080",
"169.254.169.254:80", // Cloud metadata endpoint
}
for _, addr := range privateAddresses {
t.Run(addr, func(t *testing.T) {
t.Parallel()
conn, err := dialer(ctx, "tcp", addr)
if err == nil {
_ = conn.Close()
t.Errorf("expected connection to %s to be blocked (all IPs private)", addr)
}
})
}
}
func TestNewSafeHTTPClient_RedirectToPrivateIP(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
// Create a server that redirects to a private IP
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
// Redirect to a private IP (will be blocked)
http.Redirect(w, r, "http://192.168.1.1/internal", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Client with redirects enabled and localhost allowed for the test server
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(3),
)
// Make request - should fail when trying to follow redirect to private IP
resp, err := client.Get(server.URL)
if err == nil {
defer func() { _ = resp.Body.Close() }()
t.Error("expected error when redirect targets private IP")
}
}
func TestSafeDialer_DNSResolutionFailure(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: 100 * time.Millisecond,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
// Use a domain that will fail DNS resolution
_, err := dialer(ctx, "tcp", "nonexistent-domain-xyz123.invalid:80")
if err == nil {
t.Error("expected error for DNS resolution failure")
}
if err != nil && !contains(err.Error(), "DNS resolution failed") {
t.Errorf("expected DNS resolution failure error, got: %v", err)
}
}
func TestSafeDialer_NoIPsReturned(t *testing.T) {
t.Parallel()
// This tests the edge case where DNS returns no IP addresses
// In practice this is rare, but we need to handle it
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// This domain should fail DNS resolution
_, err := dialer(ctx, "tcp", "empty-dns-result-test.invalid:80")
if err == nil {
t.Error("expected error when DNS returns no IPs")
}
}
func TestNewSafeHTTPClient_TooManyRedirects(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
// Keep redirecting to itself
http.Redirect(w, r, "/redirect"+string(rune('0'+redirectCount)), http.StatusFound)
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(3),
)
resp, err := client.Get(server.URL)
if resp != nil {
defer func() { _ = resp.Body.Close() }()
}
if err == nil {
t.Error("expected error for too many redirects")
}
if err != nil && !contains(err.Error(), "too many redirects") {
t.Logf("Got redirect error: %v", err)
}
}
func TestValidateRedirectTarget_AllowedLocalhost(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: true,
DialTimeout: time.Second,
}
// Test that localhost is allowed when AllowLocalhost is true
localhostURLs := []string{
"http://localhost/path",
"http://127.0.0.1/path",
"http://[::1]/path",
}
for _, url := range localhostURLs {
t.Run(url, func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequest("GET", url, http.NoBody)
err := validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for %s when AllowLocalhost=true, got: %v", url, err)
}
})
}
}
func TestNewSafeHTTPClient_MetadataEndpoint(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
// Test that cloud metadata endpoints are blocked
client := NewSafeHTTPClient(
WithTimeout(2 * time.Second),
)
// AWS metadata endpoint
resp, err := client.Get("http://169.254.169.254/latest/meta-data/")
if resp != nil {
defer func() { _ = resp.Body.Close() }()
}
if err == nil {
t.Error("expected cloud metadata endpoint to be blocked")
}
}
func TestSafeDialer_IPv4MappedIPv6(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test IPv6-formatted localhost
_, err := dialer(ctx, "tcp", "[::ffff:127.0.0.1]:80")
if err == nil {
t.Error("expected IPv4-mapped IPv6 loopback to be blocked")
}
}
func TestClientOptions_AllFunctionalOptions(t *testing.T) {
t.Parallel()
// Test all functional options together
client := NewSafeHTTPClient(
WithTimeout(15*time.Second),
WithAllowLocalhost(),
WithAllowedDomains("example.com", "api.example.com"),
WithMaxRedirects(5),
WithDialTimeout(3*time.Second),
)
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil with all options")
}
if client.Timeout != 15*time.Second {
t.Errorf("expected timeout of 15s, got %v", client.Timeout)
}
}
func TestSafeDialer_ContextCancelled(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: 5 * time.Second,
}
dialer := safeDialer(opts)
// Create an already-cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := dialer(ctx, "tcp", "example.com:80")
if err == nil {
t.Error("expected error for cancelled context")
}
}
func TestNewSafeHTTPClient_RedirectValidation(t *testing.T) {
if testing.Short() {
t.Skip("Skipping network I/O test in short mode")
}
t.Parallel()
// Server that redirects to itself (valid redirect)
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if callCount == 1 {
http.Redirect(w, r, "/final", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("success"))
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(2),
)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// Helper function for error message checking
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || s != "" && containsSubstr(s, substr))
}
func containsSubstr(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -1,854 +0,0 @@
package network
import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestIsPrivateIP(t *testing.T) { t.Parallel() tests := []struct {
name string
ip string
expected bool
}{
// Private IPv4 ranges
{"10.0.0.0/8 start", "10.0.0.1", true},
{"10.0.0.0/8 middle", "10.255.255.255", true},
{"172.16.0.0/12 start", "172.16.0.1", true},
{"172.16.0.0/12 end", "172.31.255.255", true},
{"192.168.0.0/16 start", "192.168.0.1", true},
{"192.168.0.0/16 end", "192.168.255.255", true},
// Link-local
{"169.254.0.0/16 start", "169.254.0.1", true},
{"169.254.0.0/16 end", "169.254.255.255", true},
// Loopback
{"127.0.0.0/8 localhost", "127.0.0.1", true},
{"127.0.0.0/8 other", "127.0.0.2", true},
{"127.0.0.0/8 end", "127.255.255.255", true},
// Special addresses
{"0.0.0.0/8", "0.0.0.1", true},
{"240.0.0.0/4 reserved", "240.0.0.1", true},
{"255.255.255.255 broadcast", "255.255.255.255", true},
// IPv6 private ranges
{"IPv6 loopback", "::1", true},
{"fc00::/7 unique local", "fc00::1", true},
{"fd00::/8 unique local", "fd00::1", true},
{"fe80::/10 link-local", "fe80::1", true},
// Public IPs (should return false)
{"Public IPv4 1", "8.8.8.8", false},
{"Public IPv4 2", "1.1.1.1", false},
{"Public IPv4 3", "93.184.216.34", false},
{"Public IPv6", "2001:4860:4860::8888", false},
// Edge cases
{"Just outside 172.16", "172.15.255.255", false},
{"Just outside 172.31", "172.32.0.0", false},
{"Just outside 192.168", "192.167.255.255", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
func TestIsPrivateIP_NilIP(t *testing.T) {
t.Parallel()
// nil IP should return true (block by default for safety)
result := IsPrivateIP(nil)
if result != true {
t.Errorf("IsPrivateIP(nil) = %v, want true", result)
}
}
func TestSafeDialer_BlocksPrivateIPs(t *testing.T) { t.Parallel() tests := []struct {
name string
address string
shouldBlock bool
}{
{"blocks 10.x.x.x", "10.0.0.1:80", true},
{"blocks 172.16.x.x", "172.16.0.1:80", true},
{"blocks 192.168.x.x", "192.168.1.1:80", true},
{"blocks 127.0.0.1", "127.0.0.1:80", true},
{"blocks localhost", "localhost:80", true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
conn, err := dialer(ctx, "tcp", tt.address)
if tt.shouldBlock {
if err == nil {
conn.Close()
t.Errorf("expected connection to %s to be blocked", tt.address)
}
}
})
}
}
func TestSafeDialer_AllowsLocalhost(t *testing.T) {
t.Parallel()
// Create a local test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Extract host:port from test server URL
addr := server.Listener.Addr().String()
opts := &ClientOptions{
AllowLocalhost: true,
DialTimeout: 5 * time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := dialer(ctx, "tcp", addr)
if err != nil {
t.Errorf("expected connection to localhost to be allowed when allowLocalhost=true, got error: %v", err)
return
}
conn.Close()
}
func TestSafeDialer_AllowedDomains(t *testing.T) {
t.Parallel()
opts := &ClientOptions{
AllowLocalhost: false,
AllowedDomains: []string{"app.crowdsec.net", "hub.crowdsec.net"},
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
// Test that allowed domain passes validation (we can't actually connect)
// This is a structural test - we're verifying the domain check passes
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
// This will fail to connect (no server) but should NOT fail validation
_, err := dialer(ctx, "tcp", "app.crowdsec.net:443")
if err != nil {
// Check it's a connection error, not a validation error
if _, ok := err.(*net.OpError); !ok {
// Context deadline exceeded is also acceptable (DNS/connection timeout)
if err != context.DeadlineExceeded {
t.Logf("Got expected error type for allowed domain: %T: %v", err, err)
}
}
}
}
func TestNewSafeHTTPClient_DefaultOptions(t *testing.T) {
t.Parallel()
client := NewSafeHTTPClient()
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
if client.Timeout != 10*time.Second {
t.Errorf("expected default timeout of 10s, got %v", client.Timeout)
}
}
func TestNewSafeHTTPClient_WithTimeout(t *testing.T) {
t.Parallel()
client := NewSafeHTTPClient(WithTimeout(10 * time.Second))
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
if client.Timeout != 10*time.Second {
t.Errorf("expected timeout of 10s, got %v", client.Timeout)
}
}
func TestNewSafeHTTPClient_WithAllowLocalhost(t *testing.T) {
t.Parallel()
// Create a local test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("expected request to localhost to succeed with allowLocalhost, got: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestNewSafeHTTPClient_BlocksSSRF(t *testing.T) {
client := NewSafeHTTPClient(
WithTimeout(2 * time.Second),
)
// Test that internal IPs are blocked
urls := []string{
"http://127.0.0.1/",
"http://10.0.0.1/",
"http://172.16.0.1/",
"http://192.168.1.1/",
"http://localhost/",
}
for _, url := range urls {
t.Run(url, func(t *testing.T) {
resp, err := client.Get(url)
if err == nil {
defer resp.Body.Close()
t.Errorf("expected request to %s to be blocked", url)
}
})
}
}
func TestNewSafeHTTPClient_WithMaxRedirects(t *testing.T) {
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
if redirectCount < 5 {
http.Redirect(w, r, "/redirect", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(2),
)
resp, err := client.Get(server.URL)
if err == nil {
defer resp.Body.Close()
t.Error("expected redirect limit to be enforced")
}
}
func TestNewSafeHTTPClient_WithAllowedDomains(t *testing.T) {
client := NewSafeHTTPClient(
WithTimeout(2*time.Second),
WithAllowedDomains("example.com"),
)
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
// We can't actually connect, but we verify the client is created
// with the correct configuration
}
func TestClientOptions_Defaults(t *testing.T) {
opts := defaultOptions()
if opts.Timeout != 10*time.Second {
t.Errorf("expected default timeout 10s, got %v", opts.Timeout)
}
if opts.MaxRedirects != 0 {
t.Errorf("expected default maxRedirects 0, got %d", opts.MaxRedirects)
}
if opts.DialTimeout != 5*time.Second {
t.Errorf("expected default dialTimeout 5s, got %v", opts.DialTimeout)
}
}
func TestWithDialTimeout(t *testing.T) {
client := NewSafeHTTPClient(WithDialTimeout(5 * time.Second))
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil")
}
}
// Benchmark tests
func BenchmarkIsPrivateIP_IPv4Private(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsPrivateIP(ip)
}
}
func BenchmarkIsPrivateIP_IPv4Public(b *testing.B) {
ip := net.ParseIP("8.8.8.8")
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsPrivateIP(ip)
}
}
func BenchmarkIsPrivateIP_IPv6(b *testing.B) {
ip := net.ParseIP("2001:4860:4860::8888")
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsPrivateIP(ip)
}
}
func BenchmarkNewSafeHTTPClient(b *testing.B) {
for i := 0; i < b.N; i++ {
NewSafeHTTPClient(
WithTimeout(10*time.Second),
WithAllowLocalhost(),
)
}
}
// Additional tests to increase coverage
func TestSafeDialer_InvalidAddress(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test invalid address format (no port)
_, err := dialer(ctx, "tcp", "invalid-address-no-port")
if err == nil {
t.Error("expected error for invalid address format")
}
}
func TestSafeDialer_LoopbackIPv6(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: true,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test IPv6 loopback with AllowLocalhost
_, err := dialer(ctx, "tcp", "[::1]:80")
// Should fail to connect but not due to validation
if err != nil {
t.Logf("Expected connection error (not validation): %v", err)
}
}
func TestValidateRedirectTarget_EmptyHostname(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
// Create request with empty hostname
req, _ := http.NewRequest("GET", "http:///path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for empty hostname")
}
}
func TestValidateRedirectTarget_Localhost(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
// Test localhost blocked
req, _ := http.NewRequest("GET", "http://localhost/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for localhost when AllowLocalhost=false")
}
// Test localhost allowed
opts.AllowLocalhost = true
err = validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for localhost when AllowLocalhost=true, got: %v", err)
}
}
func TestValidateRedirectTarget_127(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
req, _ := http.NewRequest("GET", "http://127.0.0.1/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for 127.0.0.1 when AllowLocalhost=false")
}
opts.AllowLocalhost = true
err = validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for 127.0.0.1 when AllowLocalhost=true, got: %v", err)
}
}
func TestValidateRedirectTarget_IPv6Loopback(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
req, _ := http.NewRequest("GET", "http://[::1]/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for ::1 when AllowLocalhost=false")
}
opts.AllowLocalhost = true
err = validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for ::1 when AllowLocalhost=true, got: %v", err)
}
}
func TestNewSafeHTTPClient_NoRedirectsByDefault(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
// Should not follow redirect - should return 302
if resp.StatusCode != http.StatusFound {
t.Errorf("expected status 302 (redirect not followed), got %d", resp.StatusCode)
}
}
func TestIsPrivateIP_IPv4MappedIPv6(t *testing.T) {
// Test IPv4-mapped IPv6 addresses
tests := []struct {
name string
ip string
expected bool
}{
{"IPv4-mapped private", "::ffff:192.168.1.1", true},
{"IPv4-mapped public", "::ffff:8.8.8.8", false},
{"IPv4-mapped loopback", "::ffff:127.0.0.1", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
func TestIsPrivateIP_Multicast(t *testing.T) {
// Test multicast addresses
tests := []struct {
name string
ip string
expected bool
}{
{"IPv4 multicast", "224.0.0.1", true},
{"IPv6 multicast", "ff02::1", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
func TestIsPrivateIP_Unspecified(t *testing.T) {
// Test unspecified addresses
tests := []struct {
name string
ip string
expected bool
}{
{"IPv4 unspecified", "0.0.0.0", true},
{"IPv6 unspecified", "::", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
t.Fatalf("failed to parse IP: %s", tt.ip)
}
result := IsPrivateIP(ip)
if result != tt.expected {
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}
// Phase 1 Coverage Improvement Tests
func TestValidateRedirectTarget_DNSFailure(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: 100 * time.Millisecond, // Short timeout to force DNS failure quickly
}
// Use a domain that will fail DNS resolution
req, _ := http.NewRequest("GET", "http://this-domain-does-not-exist-12345.invalid/path", http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Error("expected error for DNS resolution failure")
}
// Verify the error is DNS-related
if err != nil && !contains(err.Error(), "DNS resolution failed") {
t.Errorf("expected DNS resolution failure error, got: %v", err)
}
}
func TestValidateRedirectTarget_PrivateIPInRedirect(t *testing.T) {
// Test that redirects to private IPs are properly blocked
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
// Test various private IP redirect scenarios
privateHosts := []string{
"http://10.0.0.1/path",
"http://172.16.0.1/path",
"http://192.168.1.1/path",
"http://169.254.169.254/latest/meta-data/", // AWS metadata endpoint
}
for _, url := range privateHosts {
t.Run(url, func(t *testing.T) {
req, _ := http.NewRequest("GET", url, http.NoBody)
err := validateRedirectTarget(req, opts)
if err == nil {
t.Errorf("expected error for redirect to private IP: %s", url)
}
})
}
}
func TestSafeDialer_AllIPsPrivate(t *testing.T) {
// Test that when all resolved IPs are private, the connection is blocked
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test dialing addresses that resolve to private IPs
privateAddresses := []string{
"10.0.0.1:80",
"172.16.0.1:443",
"192.168.0.1:8080",
"169.254.169.254:80", // Cloud metadata endpoint
}
for _, addr := range privateAddresses {
t.Run(addr, func(t *testing.T) {
conn, err := dialer(ctx, "tcp", addr)
if err == nil {
conn.Close()
t.Errorf("expected connection to %s to be blocked (all IPs private)", addr)
}
})
}
}
func TestNewSafeHTTPClient_RedirectToPrivateIP(t *testing.T) {
// Create a server that redirects to a private IP
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
// Redirect to a private IP (will be blocked)
http.Redirect(w, r, "http://192.168.1.1/internal", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Client with redirects enabled and localhost allowed for the test server
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(3),
)
// Make request - should fail when trying to follow redirect to private IP
resp, err := client.Get(server.URL)
if err == nil {
defer resp.Body.Close()
t.Error("expected error when redirect targets private IP")
}
}
func TestSafeDialer_DNSResolutionFailure(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: 100 * time.Millisecond,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
// Use a domain that will fail DNS resolution
_, err := dialer(ctx, "tcp", "nonexistent-domain-xyz123.invalid:80")
if err == nil {
t.Error("expected error for DNS resolution failure")
}
if err != nil && !contains(err.Error(), "DNS resolution failed") {
t.Errorf("expected DNS resolution failure error, got: %v", err)
}
}
func TestSafeDialer_NoIPsReturned(t *testing.T) {
// This tests the edge case where DNS returns no IP addresses
// In practice this is rare, but we need to handle it
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// This domain should fail DNS resolution
_, err := dialer(ctx, "tcp", "empty-dns-result-test.invalid:80")
if err == nil {
t.Error("expected error when DNS returns no IPs")
}
}
func TestNewSafeHTTPClient_TooManyRedirects(t *testing.T) {
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
// Keep redirecting to itself
http.Redirect(w, r, "/redirect"+string(rune('0'+redirectCount)), http.StatusFound)
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(3),
)
resp, err := client.Get(server.URL)
if resp != nil {
resp.Body.Close()
}
if err == nil {
t.Error("expected error for too many redirects")
}
if err != nil && !contains(err.Error(), "too many redirects") {
t.Logf("Got redirect error: %v", err)
}
}
func TestValidateRedirectTarget_AllowedLocalhost(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: true,
DialTimeout: time.Second,
}
// Test that localhost is allowed when AllowLocalhost is true
localhostURLs := []string{
"http://localhost/path",
"http://127.0.0.1/path",
"http://[::1]/path",
}
for _, url := range localhostURLs {
t.Run(url, func(t *testing.T) {
req, _ := http.NewRequest("GET", url, http.NoBody)
err := validateRedirectTarget(req, opts)
if err != nil {
t.Errorf("expected no error for %s when AllowLocalhost=true, got: %v", url, err)
}
})
}
}
func TestNewSafeHTTPClient_MetadataEndpoint(t *testing.T) {
// Test that cloud metadata endpoints are blocked
client := NewSafeHTTPClient(
WithTimeout(2 * time.Second),
)
// AWS metadata endpoint
resp, err := client.Get("http://169.254.169.254/latest/meta-data/")
if resp != nil {
defer resp.Body.Close()
}
if err == nil {
t.Error("expected cloud metadata endpoint to be blocked")
}
}
func TestSafeDialer_IPv4MappedIPv6(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: time.Second,
}
dialer := safeDialer(opts)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Test IPv6-formatted localhost
_, err := dialer(ctx, "tcp", "[::ffff:127.0.0.1]:80")
if err == nil {
t.Error("expected IPv4-mapped IPv6 loopback to be blocked")
}
}
func TestClientOptions_AllFunctionalOptions(t *testing.T) {
// Test all functional options together
client := NewSafeHTTPClient(
WithTimeout(15*time.Second),
WithAllowLocalhost(),
WithAllowedDomains("example.com", "api.example.com"),
WithMaxRedirects(5),
WithDialTimeout(3*time.Second),
)
if client == nil {
t.Fatal("NewSafeHTTPClient() returned nil with all options")
}
if client.Timeout != 15*time.Second {
t.Errorf("expected timeout of 15s, got %v", client.Timeout)
}
}
func TestSafeDialer_ContextCancelled(t *testing.T) {
opts := &ClientOptions{
AllowLocalhost: false,
DialTimeout: 5 * time.Second,
}
dialer := safeDialer(opts)
// Create an already-cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := dialer(ctx, "tcp", "example.com:80")
if err == nil {
t.Error("expected error for cancelled context")
}
}
func TestNewSafeHTTPClient_RedirectValidation(t *testing.T) {
// Server that redirects to itself (valid redirect)
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if callCount == 1 {
http.Redirect(w, r, "/final", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer server.Close()
client := NewSafeHTTPClient(
WithTimeout(5*time.Second),
WithAllowLocalhost(),
WithMaxRedirects(2),
)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// Helper function for error message checking
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || s != "" && containsSubstr(s, substr))
}
func containsSubstr(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}