package utils import ( "context" "crypto/tls" "fmt" "net" "net/http" "net/url" "strconv" "strings" "time" "github.com/Wikid82/charon/backend/internal/logger" "github.com/Wikid82/charon/backend/internal/metrics" "github.com/Wikid82/charon/backend/internal/network" "github.com/Wikid82/charon/backend/internal/security" ) func resolveAllowedIP(ctx context.Context, host string, allowLocalhost bool) (net.IP, error) { if host == "" { return nil, fmt.Errorf("missing hostname") } // Fast-path: IP literal. if ip := net.ParseIP(host); ip != nil { if allowLocalhost && ip.IsLoopback() { return ip, nil } if network.IsPrivateIP(ip) { return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip) } return ip, nil } ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) if err != nil { return nil, fmt.Errorf("DNS resolution failed: %w", err) } if len(ips) == 0 { return nil, fmt.Errorf("no IP addresses found for host") } var selected net.IP for _, ip := range ips { if allowLocalhost && ip.IP.IsLoopback() { if selected == nil { selected = ip.IP } continue } if network.IsPrivateIP(ip.IP) { return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP) } if selected == nil { selected = ip.IP } } if selected == nil { return nil, fmt.Errorf("no allowed IP addresses found for host") } return selected, nil } // ssrfSafeDialer creates a custom dialer that validates IP addresses at connection time. // This prevents DNS rebinding attacks by validating the IP just before connecting. // Returns a DialContext function suitable for use in http.Transport. func ssrfSafeDialer() func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, netw, addr string) (net.Conn, error) { // Parse host and port from address host, port, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("invalid address format: %w", err) } // Resolve DNS with context timeout ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) if err != nil { return nil, fmt.Errorf("DNS resolution failed: %w", err) } if len(ips) == 0 { return nil, fmt.Errorf("no IP addresses found for host") } // Validate ALL resolved IPs - if any are private, reject immediately // Using centralized network.IsPrivateIP for consistent SSRF protection for _, ip := range ips { if network.IsPrivateIP(ip.IP) { return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP) } } // Connect to the first valid IP (prevents DNS rebinding) dialer := &net.Dialer{ Timeout: 5 * time.Second, } return dialer.DialContext(ctx, netw, net.JoinHostPort(ips[0].IP.String(), port)) } } // NOTE: Redirect validation is implemented by validateRedirectTargetStrict. // TestURLConnectivity performs a server-side connectivity test with SSRF protection. // For testing purposes, an optional http.RoundTripper can be provided to bypass // DNS resolution and network calls. // Returns: // - reachable: true if URL returned 2xx-3xx status // - latency: round-trip time in milliseconds // - error: validation or connectivity error func TestURLConnectivity(rawURL string) (reachable bool, latency float64, err error) { // NOTE: This wrapper preserves the exported API while enforcing // deny-by-default SSRF-safe behavior. // // Do not add optional transports/options to the exported surface. // Tests can exercise alternative paths via unexported helpers. return testURLConnectivity(rawURL) } type urlConnectivityOptions struct { transport http.RoundTripper allowLocalhost bool } type urlConnectivityOption func(*urlConnectivityOptions) //nolint:unused // Used in test files func withTransportForTesting(rt http.RoundTripper) urlConnectivityOption { return func(o *urlConnectivityOptions) { o.transport = rt } } //nolint:unused // Used in test files func withAllowLocalhostForTesting() urlConnectivityOption { return func(o *urlConnectivityOptions) { o.allowLocalhost = true } } // testURLConnectivity is the implementation behind TestURLConnectivity. // It supports internal (package-only) hooks to keep unit tests deterministic // without weakening production defaults. func testURLConnectivity(rawURL string, opts ...urlConnectivityOption) (reachable bool, latency float64, err error) { // Track start time for metrics startTime := time.Now() // Generate unique request ID for tracing requestID := fmt.Sprintf("test-%d", time.Now().UnixNano()) options := urlConnectivityOptions{} for _, opt := range opts { opt(&options) } isTestMode := options.transport != nil // Parse URL first to validate structure parsed, err := url.Parse(rawURL) if err != nil { // ENHANCEMENT: Record validation failure metric metrics.RecordURLValidation("error", "invalid_format") // ENHANCEMENT: Audit log the failed validation if !isTestMode { security.LogURLTest(rawURL, requestID, "system", "", "error") } return false, 0, fmt.Errorf("invalid URL: %w", err) } // Validate scheme if parsed.Scheme != "http" && parsed.Scheme != "https" { // ENHANCEMENT: Record validation failure metric metrics.RecordURLValidation("error", "unsupported_scheme") // ENHANCEMENT: Audit log the failed validation if !isTestMode { security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error") } return false, 0, fmt.Errorf("only http and https schemes are allowed") } // Reject URLs containing userinfo (username:password@host) // This is checked again by security.ValidateExternalURL, but we keep it here // to ensure consistent behavior across all call paths. if parsed.User != nil { metrics.RecordURLValidation("error", "userinfo_not_allowed") if !isTestMode { security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error") } return false, 0, fmt.Errorf("urls with embedded credentials are not allowed") } // CRITICAL: Always validate the URL through centralized SSRF protection. // // Production defaults are deny-by-default: // - Only http/https allowed // - No localhost // - No private/reserved IPs // // Tests may opt-in to localhost via withAllowLocalhostForTesting(), without // weakening production behavior. validationOpts := []security.ValidationOption{ security.WithAllowHTTP(), } if options.allowLocalhost { validationOpts = append(validationOpts, security.WithAllowLocalhost()) } validatedURL, err := security.ValidateExternalURL(rawURL, validationOpts...) if err != nil { // ENHANCEMENT: Record SSRF block metrics // Determine the block reason from error message errMsg := err.Error() var blockReason string switch { case strings.Contains(errMsg, "private ip"): blockReason = "private_ip" metrics.RecordSSRFBlock("private", "system") // userID should come from context in production // ENHANCEMENT: Audit log the SSRF block security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "") case strings.Contains(errMsg, "cloud metadata"): blockReason = "metadata_endpoint" metrics.RecordSSRFBlock("metadata", "system") // ENHANCEMENT: Audit log the SSRF block security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "") case strings.Contains(errMsg, "dns resolution"): blockReason = "dns_failed" // ENHANCEMENT: Audit log the DNS failure security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error") default: blockReason = "validation_failed" // ENHANCEMENT: Audit log the validation failure security.LogURLTest(parsed.Hostname(), requestID, "system", "", "blocked") } metrics.RecordURLValidation("blocked", blockReason) // Transform error message for backward compatibility with existing tests // The security package uses lowercase in error messages, but tests expect mixed case errMsg = strings.Replace(errMsg, "dns resolution failed", "DNS resolution failed", 1) errMsg = strings.ReplaceAll(errMsg, "private ip", "private IP") // Cloud metadata endpoints are considered private IPs for test compatibility if strings.Contains(errMsg, "cloud metadata endpoints") { errMsg = strings.Replace(errMsg, "access to cloud metadata endpoints is blocked for security", "connection to private IP addresses is blocked for security", 1) } return false, 0, fmt.Errorf("security validation failed: %s", errMsg) } // ENHANCEMENT: Record successful validation metrics.RecordURLValidation("allowed", "validated") // ENHANCEMENT: Audit log successful validation (only in production to avoid test noise) if !isTestMode { security.LogURLTest(parsed.Hostname(), requestID, "system", "", "allowed") } // Use validated URL for requests (breaks taint chain) validatedRequestURL := validatedURL const ( requestTimeout = 5 * time.Second maxRedirects = 2 allowHTTPSUpgrade = true ) transport := &http.Transport{ // Explicitly ignore proxy environment variables for SSRF-sensitive requests. Proxy: nil, DialContext: ssrfSafeDialer(), MaxIdleConns: 1, IdleConnTimeout: requestTimeout, TLSHandshakeTimeout: requestTimeout, ResponseHeaderTimeout: requestTimeout, DisableKeepAlives: true, } if isTestMode { // Test-only override: allows deterministic unit tests without real network. transport = &http.Transport{ Proxy: nil, DisableKeepAlives: true, } // If the provided transport is an http.RoundTripper that is not an *http.Transport, // use it directly. } var rt http.RoundTripper = transport if isTestMode { rt = options.transport } client := &http.Client{ Timeout: requestTimeout, Transport: rt, CheckRedirect: func(req *http.Request, via []*http.Request) error { return validateRedirectTargetStrict(req, via, maxRedirects, allowHTTPSUpgrade, options.allowLocalhost) }, } // Perform HTTP HEAD request with strict timeout ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() start := time.Now() // Parse the validated URL to construct request from validated components // This breaks the taint chain for static analysis by using parsed URL components validatedParsed, err := url.Parse(validatedRequestURL) if err != nil { return false, 0, fmt.Errorf("failed to parse validated URL: %w", err) } // Normalize scheme to a constant value derived from an allowlisted set. // This avoids propagating the original input string into request construction. var safeScheme string switch validatedParsed.Scheme { case "http": safeScheme = "http" case "https": safeScheme = "https" default: return false, 0, fmt.Errorf("security validation failed: unsupported scheme") } // If we connect to an IP-literal for HTTPS, ensure TLS SNI still uses the hostname. if !isTestMode && safeScheme == "https" { transport.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, ServerName: validatedParsed.Hostname(), } } // Resolve to a concrete, allowed IP for the outbound request URL. // We still preserve the original hostname via Host header and TLS SNI. selectedIP, err := resolveAllowedIP(ctx, validatedParsed.Hostname(), options.allowLocalhost) if err != nil { return false, 0, fmt.Errorf("security validation failed: %s", err.Error()) } port := validatedParsed.Port() if port == "" { if safeScheme == "https" { port = "443" } else { port = "80" } } else { p, convErr := strconv.Atoi(port) if convErr != nil || p < 1 || p > 65535 { return false, 0, fmt.Errorf("security validation failed: invalid port") } port = strconv.Itoa(p) } // Construct a request URL from SSRF-safe components. // - Host is a resolved IP (selectedIP) to avoid hostname-based SSRF bypass // - Path is fixed to "/" because this is a connectivity test // nosemgrep: go.lang.security.audit.net.use-tls.use-tls safeURL := &url.URL{ Scheme: safeScheme, Host: net.JoinHostPort(selectedIP.String(), port), Path: "/", } req, err := http.NewRequestWithContext(ctx, http.MethodHead, safeURL.String(), http.NoBody) if err != nil { return false, 0, fmt.Errorf("failed to create request: %w", err) } // Preserve the original hostname in Host header for virtual-hosting. // This does not affect the destination IP (selectedIP), which is used in the URL. req.Host = validatedParsed.Host // Add custom User-Agent header req.Header.Set("User-Agent", "Charon-Health-Check/1.0") // ENHANCEMENT: Request Tracing Headers // These headers help track and identify URL test requests in logs req.Header.Set("X-Charon-Request-Type", "url-connectivity-test") req.Header.Set("X-Request-ID", requestID) // Use consistent request ID for tracing // SSRF Protection Summary: // This HTTP request is protected against SSRF by multiple defense layers: // 1. security.ValidateExternalURL() validates URL format, scheme, and performs // DNS resolution with private IP blocking (RFC 1918, loopback, link-local, metadata) // 2. ssrfSafeDialer() re-validates IPs at connection time (prevents DNS rebinding/TOCTOU) // 3. validateRedirectTarget() validates all redirect URLs in production // 4. safeURL is constructed from parsed/validated components (breaks taint chain) // See: internal/security/url_validator.go, internal/network/safeclient.go resp, err := client.Do(req) //nolint:bodyclose // Body closed via defer below latency = time.Since(start).Seconds() * 1000 // Convert to milliseconds // ENHANCEMENT: Record test duration metric (only in production to avoid test noise) if !isTestMode { durationSeconds := time.Since(startTime).Seconds() metrics.RecordURLTestDuration(durationSeconds) } if err != nil { return false, latency, fmt.Errorf("connection failed: %w", err) } defer func() { if err := resp.Body.Close(); err != nil { logger.Log().WithError(err).Warn("Failed to close response body") } }() // Accept 2xx and 3xx status codes as "reachable" if resp.StatusCode >= 200 && resp.StatusCode < 400 { return true, latency, nil } return false, latency, fmt.Errorf("server returned status %d", resp.StatusCode) } // validateRedirectTargetStrict validates HTTP redirects with SSRF protection. // It enforces: // - a hard redirect limit // - per-hop URL validation (scheme, userinfo, host, DNS, private/reserved IPs) // - scheme-change policy (deny by default; optionally allow http->https upgrade) func validateRedirectTargetStrict(req *http.Request, via []*http.Request, maxRedirects int, allowHTTPSUpgrade bool, allowLocalhost bool) error { if len(via) >= maxRedirects { return fmt.Errorf("too many redirects (max %d)", maxRedirects) } if len(via) > 0 { prevScheme := via[len(via)-1].URL.Scheme newScheme := req.URL.Scheme if newScheme != prevScheme { if !allowHTTPSUpgrade || prevScheme != "http" || newScheme != "https" { return fmt.Errorf("redirect scheme change blocked: %s -> %s", prevScheme, newScheme) } } } validationOpts := []security.ValidationOption{security.WithAllowHTTP(), security.WithTimeout(3 * time.Second)} if allowLocalhost { validationOpts = append(validationOpts, security.WithAllowLocalhost()) } _, err := security.ValidateExternalURL(req.URL.String(), validationOpts...) if err != nil { return fmt.Errorf("redirect target validation failed: %w", err) } return nil }