chore: clean .gitignore cache
This commit is contained in:
@@ -1,443 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/metrics"
|
||||
"github.com/Wikid82/charon/backend/internal/network"
|
||||
"github.com/Wikid82/charon/backend/internal/security"
|
||||
)
|
||||
|
||||
func resolveAllowedIP(ctx context.Context, host string, allowLocalhost bool) (net.IP, error) {
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("missing hostname")
|
||||
}
|
||||
|
||||
// Fast-path: IP literal.
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if allowLocalhost && ip.IsLoopback() {
|
||||
return ip, nil
|
||||
}
|
||||
if network.IsPrivateIP(ip) {
|
||||
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip)
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS resolution failed: %w", err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no IP addresses found for host")
|
||||
}
|
||||
|
||||
var selected net.IP
|
||||
for _, ip := range ips {
|
||||
if allowLocalhost && ip.IP.IsLoopback() {
|
||||
if selected == nil {
|
||||
selected = ip.IP
|
||||
}
|
||||
continue
|
||||
}
|
||||
if network.IsPrivateIP(ip.IP) {
|
||||
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP)
|
||||
}
|
||||
if selected == nil {
|
||||
selected = ip.IP
|
||||
}
|
||||
}
|
||||
|
||||
if selected == nil {
|
||||
return nil, fmt.Errorf("no allowed IP addresses found for host")
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// ssrfSafeDialer creates a custom dialer that validates IP addresses at connection time.
|
||||
// This prevents DNS rebinding attacks by validating the IP just before connecting.
|
||||
// Returns a DialContext function suitable for use in http.Transport.
|
||||
func ssrfSafeDialer() func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
// Parse host and port from address
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address format: %w", err)
|
||||
}
|
||||
|
||||
// Resolve DNS with context timeout
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS resolution failed: %w", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no IP addresses found for host")
|
||||
}
|
||||
|
||||
// Validate ALL resolved IPs - if any are private, reject immediately
|
||||
// Using centralized network.IsPrivateIP for consistent SSRF protection
|
||||
for _, ip := range ips {
|
||||
if network.IsPrivateIP(ip.IP) {
|
||||
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP)
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to the first valid IP (prevents DNS rebinding)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
return dialer.DialContext(ctx, netw, net.JoinHostPort(ips[0].IP.String(), port))
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Redirect validation is implemented by validateRedirectTargetStrict.
|
||||
|
||||
// TestURLConnectivity performs a server-side connectivity test with SSRF protection.
|
||||
// For testing purposes, an optional http.RoundTripper can be provided to bypass
|
||||
// DNS resolution and network calls.
|
||||
// Returns:
|
||||
// - reachable: true if URL returned 2xx-3xx status
|
||||
// - latency: round-trip time in milliseconds
|
||||
// - error: validation or connectivity error
|
||||
func TestURLConnectivity(rawURL string) (reachable bool, latency float64, err error) {
|
||||
// NOTE: This wrapper preserves the exported API while enforcing
|
||||
// deny-by-default SSRF-safe behavior.
|
||||
//
|
||||
// Do not add optional transports/options to the exported surface.
|
||||
// Tests can exercise alternative paths via unexported helpers.
|
||||
return testURLConnectivity(rawURL)
|
||||
}
|
||||
|
||||
type urlConnectivityOptions struct {
|
||||
transport http.RoundTripper
|
||||
allowLocalhost bool
|
||||
}
|
||||
|
||||
type urlConnectivityOption func(*urlConnectivityOptions)
|
||||
|
||||
//nolint:unused // Used in test files
|
||||
func withTransportForTesting(rt http.RoundTripper) urlConnectivityOption {
|
||||
return func(o *urlConnectivityOptions) {
|
||||
o.transport = rt
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:unused // Used in test files
|
||||
func withAllowLocalhostForTesting() urlConnectivityOption {
|
||||
return func(o *urlConnectivityOptions) {
|
||||
o.allowLocalhost = true
|
||||
}
|
||||
}
|
||||
|
||||
// testURLConnectivity is the implementation behind TestURLConnectivity.
|
||||
// It supports internal (package-only) hooks to keep unit tests deterministic
|
||||
// without weakening production defaults.
|
||||
func testURLConnectivity(rawURL string, opts ...urlConnectivityOption) (reachable bool, latency float64, err error) {
|
||||
// Track start time for metrics
|
||||
startTime := time.Now()
|
||||
|
||||
// Generate unique request ID for tracing
|
||||
requestID := fmt.Sprintf("test-%d", time.Now().UnixNano())
|
||||
|
||||
options := urlConnectivityOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
isTestMode := options.transport != nil
|
||||
|
||||
// Parse URL first to validate structure
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
// ENHANCEMENT: Record validation failure metric
|
||||
metrics.RecordURLValidation("error", "invalid_format")
|
||||
// ENHANCEMENT: Audit log the failed validation
|
||||
if !isTestMode {
|
||||
security.LogURLTest(rawURL, requestID, "system", "", "error")
|
||||
}
|
||||
return false, 0, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// Validate scheme
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
// ENHANCEMENT: Record validation failure metric
|
||||
metrics.RecordURLValidation("error", "unsupported_scheme")
|
||||
// ENHANCEMENT: Audit log the failed validation
|
||||
if !isTestMode {
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
|
||||
}
|
||||
return false, 0, fmt.Errorf("only http and https schemes are allowed")
|
||||
}
|
||||
|
||||
// Reject URLs containing userinfo (username:password@host)
|
||||
// This is checked again by security.ValidateExternalURL, but we keep it here
|
||||
// to ensure consistent behavior across all call paths.
|
||||
if parsed.User != nil {
|
||||
metrics.RecordURLValidation("error", "userinfo_not_allowed")
|
||||
if !isTestMode {
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
|
||||
}
|
||||
return false, 0, fmt.Errorf("urls with embedded credentials are not allowed")
|
||||
}
|
||||
|
||||
// CRITICAL: Always validate the URL through centralized SSRF protection.
|
||||
//
|
||||
// Production defaults are deny-by-default:
|
||||
// - Only http/https allowed
|
||||
// - No localhost
|
||||
// - No private/reserved IPs
|
||||
//
|
||||
// Tests may opt-in to localhost via withAllowLocalhostForTesting(), without
|
||||
// weakening production behavior.
|
||||
validationOpts := []security.ValidationOption{
|
||||
security.WithAllowHTTP(),
|
||||
}
|
||||
if options.allowLocalhost {
|
||||
validationOpts = append(validationOpts, security.WithAllowLocalhost())
|
||||
}
|
||||
|
||||
validatedURL, err := security.ValidateExternalURL(rawURL, validationOpts...)
|
||||
if err != nil {
|
||||
// ENHANCEMENT: Record SSRF block metrics
|
||||
// Determine the block reason from error message
|
||||
errMsg := err.Error()
|
||||
var blockReason string
|
||||
switch {
|
||||
case strings.Contains(errMsg, "private ip"):
|
||||
blockReason = "private_ip"
|
||||
metrics.RecordSSRFBlock("private", "system") // userID should come from context in production
|
||||
// ENHANCEMENT: Audit log the SSRF block
|
||||
security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "")
|
||||
case strings.Contains(errMsg, "cloud metadata"):
|
||||
blockReason = "metadata_endpoint"
|
||||
metrics.RecordSSRFBlock("metadata", "system")
|
||||
// ENHANCEMENT: Audit log the SSRF block
|
||||
security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "")
|
||||
case strings.Contains(errMsg, "dns resolution"):
|
||||
blockReason = "dns_failed"
|
||||
// ENHANCEMENT: Audit log the DNS failure
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
|
||||
default:
|
||||
blockReason = "validation_failed"
|
||||
// ENHANCEMENT: Audit log the validation failure
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "blocked")
|
||||
}
|
||||
metrics.RecordURLValidation("blocked", blockReason)
|
||||
|
||||
// Transform error message for backward compatibility with existing tests
|
||||
// The security package uses lowercase in error messages, but tests expect mixed case
|
||||
errMsg = strings.Replace(errMsg, "dns resolution failed", "DNS resolution failed", 1)
|
||||
errMsg = strings.ReplaceAll(errMsg, "private ip", "private IP")
|
||||
// Cloud metadata endpoints are considered private IPs for test compatibility
|
||||
if strings.Contains(errMsg, "cloud metadata endpoints") {
|
||||
errMsg = strings.Replace(errMsg, "access to cloud metadata endpoints is blocked for security", "connection to private IP addresses is blocked for security", 1)
|
||||
}
|
||||
return false, 0, fmt.Errorf("security validation failed: %s", errMsg)
|
||||
}
|
||||
// ENHANCEMENT: Record successful validation
|
||||
metrics.RecordURLValidation("allowed", "validated")
|
||||
// ENHANCEMENT: Audit log successful validation (only in production to avoid test noise)
|
||||
if !isTestMode {
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "allowed")
|
||||
}
|
||||
|
||||
// Use validated URL for requests (breaks taint chain)
|
||||
validatedRequestURL := validatedURL
|
||||
|
||||
const (
|
||||
requestTimeout = 5 * time.Second
|
||||
maxRedirects = 2
|
||||
allowHTTPSUpgrade = true
|
||||
)
|
||||
|
||||
transport := &http.Transport{
|
||||
// Explicitly ignore proxy environment variables for SSRF-sensitive requests.
|
||||
Proxy: nil,
|
||||
DialContext: ssrfSafeDialer(),
|
||||
MaxIdleConns: 1,
|
||||
IdleConnTimeout: requestTimeout,
|
||||
TLSHandshakeTimeout: requestTimeout,
|
||||
ResponseHeaderTimeout: requestTimeout,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
|
||||
if isTestMode {
|
||||
// Test-only override: allows deterministic unit tests without real network.
|
||||
transport = &http.Transport{
|
||||
Proxy: nil,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
// If the provided transport is an http.RoundTripper that is not an *http.Transport,
|
||||
// use it directly.
|
||||
}
|
||||
|
||||
var rt http.RoundTripper = transport
|
||||
if isTestMode {
|
||||
rt = options.transport
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: requestTimeout,
|
||||
Transport: rt,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return validateRedirectTargetStrict(req, via, maxRedirects, allowHTTPSUpgrade, options.allowLocalhost)
|
||||
},
|
||||
}
|
||||
|
||||
// Perform HTTP HEAD request with strict timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
|
||||
// Parse the validated URL to construct request from validated components
|
||||
// This breaks the taint chain for static analysis by using parsed URL components
|
||||
validatedParsed, err := url.Parse(validatedRequestURL)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to parse validated URL: %w", err)
|
||||
}
|
||||
|
||||
// Normalize scheme to a constant value derived from an allowlisted set.
|
||||
// This avoids propagating the original input string into request construction.
|
||||
var safeScheme string
|
||||
switch validatedParsed.Scheme {
|
||||
case "http":
|
||||
safeScheme = "http"
|
||||
case "https":
|
||||
safeScheme = "https"
|
||||
default:
|
||||
return false, 0, fmt.Errorf("security validation failed: unsupported scheme")
|
||||
}
|
||||
|
||||
// If we connect to an IP-literal for HTTPS, ensure TLS SNI still uses the hostname.
|
||||
if !isTestMode && safeScheme == "https" {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: validatedParsed.Hostname(),
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve to a concrete, allowed IP for the outbound request URL.
|
||||
// We still preserve the original hostname via Host header and TLS SNI.
|
||||
selectedIP, err := resolveAllowedIP(ctx, validatedParsed.Hostname(), options.allowLocalhost)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("security validation failed: %s", err.Error())
|
||||
}
|
||||
|
||||
port := validatedParsed.Port()
|
||||
if port == "" {
|
||||
if safeScheme == "https" {
|
||||
port = "443"
|
||||
} else {
|
||||
port = "80"
|
||||
}
|
||||
} else {
|
||||
p, convErr := strconv.Atoi(port)
|
||||
if convErr != nil || p < 1 || p > 65535 {
|
||||
return false, 0, fmt.Errorf("security validation failed: invalid port")
|
||||
}
|
||||
port = strconv.Itoa(p)
|
||||
}
|
||||
|
||||
// Construct a request URL from SSRF-safe components.
|
||||
// - Host is a resolved IP (selectedIP) to avoid hostname-based SSRF bypass
|
||||
// - Path is fixed to "/" because this is a connectivity test
|
||||
// nosemgrep: go.lang.security.audit.net.use-tls.use-tls
|
||||
safeURL := &url.URL{
|
||||
Scheme: safeScheme,
|
||||
Host: net.JoinHostPort(selectedIP.String(), port),
|
||||
Path: "/",
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, safeURL.String(), http.NoBody)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Preserve the original hostname in Host header for virtual-hosting.
|
||||
// This does not affect the destination IP (selectedIP), which is used in the URL.
|
||||
req.Host = validatedParsed.Host
|
||||
|
||||
// Add custom User-Agent header
|
||||
req.Header.Set("User-Agent", "Charon-Health-Check/1.0")
|
||||
|
||||
// ENHANCEMENT: Request Tracing Headers
|
||||
// These headers help track and identify URL test requests in logs
|
||||
req.Header.Set("X-Charon-Request-Type", "url-connectivity-test")
|
||||
req.Header.Set("X-Request-ID", requestID) // Use consistent request ID for tracing
|
||||
|
||||
// SSRF Protection Summary:
|
||||
// This HTTP request is protected against SSRF by multiple defense layers:
|
||||
// 1. security.ValidateExternalURL() validates URL format, scheme, and performs
|
||||
// DNS resolution with private IP blocking (RFC 1918, loopback, link-local, metadata)
|
||||
// 2. ssrfSafeDialer() re-validates IPs at connection time (prevents DNS rebinding/TOCTOU)
|
||||
// 3. validateRedirectTarget() validates all redirect URLs in production
|
||||
// 4. safeURL is constructed from parsed/validated components (breaks taint chain)
|
||||
// See: internal/security/url_validator.go, internal/network/safeclient.go
|
||||
resp, err := client.Do(req) //nolint:bodyclose // Body closed via defer below
|
||||
latency = time.Since(start).Seconds() * 1000 // Convert to milliseconds
|
||||
|
||||
// ENHANCEMENT: Record test duration metric (only in production to avoid test noise)
|
||||
if !isTestMode {
|
||||
durationSeconds := time.Since(startTime).Seconds()
|
||||
metrics.RecordURLTestDuration(durationSeconds)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, latency, fmt.Errorf("connection failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
logger.Log().WithError(err).Warn("Failed to close response body")
|
||||
}
|
||||
}()
|
||||
|
||||
// Accept 2xx and 3xx status codes as "reachable"
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
|
||||
return true, latency, nil
|
||||
}
|
||||
|
||||
return false, latency, fmt.Errorf("server returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// validateRedirectTargetStrict validates HTTP redirects with SSRF protection.
|
||||
// It enforces:
|
||||
// - a hard redirect limit
|
||||
// - per-hop URL validation (scheme, userinfo, host, DNS, private/reserved IPs)
|
||||
// - scheme-change policy (deny by default; optionally allow http->https upgrade)
|
||||
func validateRedirectTargetStrict(req *http.Request, via []*http.Request, maxRedirects int, allowHTTPSUpgrade bool, allowLocalhost bool) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("too many redirects (max %d)", maxRedirects)
|
||||
}
|
||||
|
||||
if len(via) > 0 {
|
||||
prevScheme := via[len(via)-1].URL.Scheme
|
||||
newScheme := req.URL.Scheme
|
||||
if newScheme != prevScheme {
|
||||
if !allowHTTPSUpgrade || prevScheme != "http" || newScheme != "https" {
|
||||
return fmt.Errorf("redirect scheme change blocked: %s -> %s", prevScheme, newScheme)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validationOpts := []security.ValidationOption{security.WithAllowHTTP(), security.WithTimeout(3 * time.Second)}
|
||||
if allowLocalhost {
|
||||
validationOpts = append(validationOpts, security.WithAllowLocalhost())
|
||||
}
|
||||
|
||||
_, err := security.ValidateExternalURL(req.URL.String(), validationOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirect target validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user