chore: clean .gitignore cache

This commit is contained in:
GitHub Actions
2026-01-26 19:21:33 +00:00
parent 1b1b3a70b1
commit e5f0fec5db
1483 changed files with 0 additions and 472793 deletions

View File

@@ -1,52 +0,0 @@
package utils
import (
"net"
"github.com/Wikid82/charon/backend/internal/network"
)
// IsPrivateIP checks if the given host string is a private IPv4 address.
// Returns false for hostnames, invalid IPs, or public IP addresses.
//
// Deprecated: This function only checks IPv4. For comprehensive SSRF protection,
// use network.IsPrivateIP() directly which handles IPv4, IPv6, and IPv4-mapped IPv6.
func IsPrivateIP(host string) bool {
ip := net.ParseIP(host)
if ip == nil {
return false
}
// Ensure it's IPv4 (for backward compatibility)
ip4 := ip.To4()
if ip4 == nil {
return false
}
// Use centralized network.IsPrivateIP for consistent checking
return network.IsPrivateIP(ip)
}
// IsDockerBridgeIP checks if the given host string is likely a Docker bridge network IP.
// Docker typically uses 172.17.x.x for the default bridge and 172.18-31.x.x for user-defined networks.
// Returns false for hostnames, invalid IPs, or non-Docker IP addresses.
func IsDockerBridgeIP(host string) bool {
ip := net.ParseIP(host)
if ip == nil {
return false
}
// Ensure it's IPv4
ip4 := ip.To4()
if ip4 == nil {
return false
}
// Docker bridge network CIDR range: 172.16.0.0/12
_, dockerNetwork, err := net.ParseCIDR("172.16.0.0/12")
if err != nil {
return false
}
return dockerNetwork.Contains(ip4)
}

View File

@@ -1,294 +0,0 @@
package utils
import "testing"
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
// Private IP ranges - Class A (10.0.0.0/8)
{"10.0.0.1 is private", "10.0.0.1", true},
{"10.255.255.255 is private", "10.255.255.255", true},
{"10.10.10.10 is private", "10.10.10.10", true},
// Private IP ranges - Class B (172.16.0.0/12)
{"172.16.0.1 is private", "172.16.0.1", true},
{"172.31.255.255 is private", "172.31.255.255", true},
{"172.20.0.1 is private", "172.20.0.1", true},
// Private IP ranges - Class C (192.168.0.0/16)
{"192.168.1.1 is private", "192.168.1.1", true},
{"192.168.0.1 is private", "192.168.0.1", true},
{"192.168.255.255 is private", "192.168.255.255", true},
// Docker bridge IPs (subset of 172.16.0.0/12)
{"172.17.0.2 is private", "172.17.0.2", true},
{"172.18.0.5 is private", "172.18.0.5", true},
// Public IPs - should return false
{"8.8.8.8 is public", "8.8.8.8", false},
{"1.1.1.1 is public", "1.1.1.1", false},
{"142.250.80.14 is public", "142.250.80.14", false},
{"203.0.113.50 is public", "203.0.113.50", false},
// Edge cases for 172.x range (outside 172.16-31)
{"172.15.0.1 is public", "172.15.0.1", false},
{"172.32.0.1 is public", "172.32.0.1", false},
// Hostnames - should return false
{"nginx hostname", "nginx", false},
{"my-app hostname", "my-app", false},
{"app.local hostname", "app.local", false},
{"example.com hostname", "example.com", false},
{"my-container.internal hostname", "my-container.internal", false},
// Invalid inputs - should return false
{"empty string", "", false},
{"malformed IP", "192.168.1", false},
{"too many octets", "192.168.1.1.1", false},
{"negative octet", "192.168.-1.1", false},
{"octet out of range", "192.168.256.1", false},
{"letters in IP", "192.168.a.1", false},
{"IPv6 address", "::1", false},
{"IPv6 full address", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", false},
// Localhost and special addresses - these are now blocked for SSRF protection
{"localhost 127.0.0.1", "127.0.0.1", true},
{"0.0.0.0", "0.0.0.0", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsPrivateIP(tt.host)
if result != tt.expected {
t.Errorf("IsPrivateIP(%q) = %v, want %v", tt.host, result, tt.expected)
}
})
}
}
func TestIsDockerBridgeIP(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
// Docker default bridge (172.17.x.x)
{"172.17.0.1 is Docker bridge", "172.17.0.1", true},
{"172.17.0.2 is Docker bridge", "172.17.0.2", true},
{"172.17.255.255 is Docker bridge", "172.17.255.255", true},
// Docker user-defined networks (172.18-31.x.x)
{"172.18.0.1 is Docker bridge", "172.18.0.1", true},
{"172.18.0.5 is Docker bridge", "172.18.0.5", true},
{"172.20.0.1 is Docker bridge", "172.20.0.1", true},
{"172.31.0.1 is Docker bridge", "172.31.0.1", true},
{"172.31.255.255 is Docker bridge", "172.31.255.255", true},
// Also matches 172.16.x.x (part of 172.16.0.0/12)
{"172.16.0.1 is in Docker range", "172.16.0.1", true},
// Private IPs NOT in Docker bridge range
{"10.0.0.1 is not Docker bridge", "10.0.0.1", false},
{"192.168.1.1 is not Docker bridge", "192.168.1.1", false},
// Public IPs - should return false
{"8.8.8.8 is public", "8.8.8.8", false},
{"1.1.1.1 is public", "1.1.1.1", false},
// Edge cases for 172.x range (outside 172.16-31)
{"172.15.0.1 is outside Docker range", "172.15.0.1", false},
{"172.32.0.1 is outside Docker range", "172.32.0.1", false},
// Hostnames - should return false
{"nginx hostname", "nginx", false},
{"my-app hostname", "my-app", false},
{"container-name hostname", "container-name", false},
// Invalid inputs - should return false
{"empty string", "", false},
{"malformed IP", "172.17.0", false},
{"too many octets", "172.17.0.2.1", false},
{"letters in IP", "172.17.a.1", false},
{"IPv6 address", "::1", false},
// Special addresses
{"localhost 127.0.0.1", "127.0.0.1", false},
{"0.0.0.0", "0.0.0.0", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsDockerBridgeIP(tt.host)
if result != tt.expected {
t.Errorf("IsDockerBridgeIP(%q) = %v, want %v", tt.host, result, tt.expected)
}
})
}
}
// TestIsPrivateIP_IPv4Mapped tests IPv4-mapped IPv6 addresses
func TestIsPrivateIP_IPv4Mapped(t *testing.T) {
// IPv4-mapped IPv6 addresses should be handled correctly
tests := []struct {
name string
host string
expected bool
}{
// net.ParseIP converts IPv4-mapped IPv6 to IPv4
{"::ffff:10.0.0.1 mapped", "::ffff:10.0.0.1", true},
{"::ffff:192.168.1.1 mapped", "::ffff:192.168.1.1", true},
{"::ffff:8.8.8.8 mapped", "::ffff:8.8.8.8", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsPrivateIP(tt.host)
if result != tt.expected {
t.Errorf("IsPrivateIP(%q) = %v, want %v", tt.host, result, tt.expected)
}
})
}
}
// ============== Phase 3.3: Additional IP Helpers Tests ==============
func TestIsPrivateIP_CIDRParseError(t *testing.T) {
// Temporarily modify the private IP ranges to include an invalid CIDR
// This tests graceful handling of CIDR parse errors
// Since we can't modify the package-level variable, we test the function behavior
// with edge cases that might trigger parsing issues
// Test with various invalid IP formats (should return false gracefully)
invalidInputs := []string{
"10.0.0.1/8", // CIDR notation (not a raw IP)
"10.0.0.256", // Invalid octet
"999.999.999.999", // Out of range
"10.0.0", // Incomplete
"not-an-ip", // Hostname
"", // Empty
"10.0.0.1.1", // Too many octets
}
for _, input := range invalidInputs {
t.Run(input, func(t *testing.T) {
result := IsPrivateIP(input)
// All invalid inputs should return false (not panic)
if result {
t.Errorf("IsPrivateIP(%q) = true, want false for invalid input", input)
}
})
}
}
func TestIsDockerBridgeIP_CIDRParseError(t *testing.T) {
// Test graceful handling of invalid inputs
invalidInputs := []string{
"172.17.0.1/16", // CIDR notation
"172.17.0.256", // Invalid octet
"999.999.999.999", // Out of range
"172.17", // Incomplete
"not-an-ip", // Hostname
"", // Empty
}
for _, input := range invalidInputs {
t.Run(input, func(t *testing.T) {
result := IsDockerBridgeIP(input)
// All invalid inputs should return false (not panic)
if result {
t.Errorf("IsDockerBridgeIP(%q) = true, want false for invalid input", input)
}
})
}
}
func TestIsPrivateIP_IPv6Comprehensive(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
// IPv6 Loopback
{"IPv6 loopback", "::1", false}, // Current implementation treats loopback as non-private
{"IPv6 loopback expanded", "0000:0000:0000:0000:0000:0000:0000:0001", false},
// IPv6 Link-Local (fe80::/10)
{"IPv6 link-local", "fe80::1", false},
{"IPv6 link-local 2", "fe80::abcd:ef01:2345:6789", false},
// IPv6 Unique Local (fc00::/7)
{"IPv6 unique local fc00", "fc00::1", false},
{"IPv6 unique local fd00", "fd00::1", false},
{"IPv6 unique local fdff", "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", false},
// IPv6 Public addresses
{"IPv6 public Google DNS", "2001:4860:4860::8888", false},
{"IPv6 public Cloudflare", "2606:4700:4700::1111", false},
// IPv6 mapped IPv4
{"IPv6 mapped private", "::ffff:10.0.0.1", true},
{"IPv6 mapped public", "::ffff:8.8.8.8", false},
// Invalid IPv6
{"Invalid IPv6", "gggg::1", false},
{"Incomplete IPv6", "2001::", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsPrivateIP(tt.host)
if result != tt.expected {
t.Errorf("IsPrivateIP(%q) = %v, want %v", tt.host, result, tt.expected)
}
})
}
}
func TestIsDockerBridgeIP_EdgeCases(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
// Boundaries of 172.16.0.0/12 range
{"Lower boundary - 1", "172.15.255.255", false}, // Just outside
{"Lower boundary", "172.16.0.0", true}, // Start of range
{"Lower boundary + 1", "172.16.0.1", true},
{"Upper boundary - 1", "172.31.255.254", true},
{"Upper boundary", "172.31.255.255", true}, // End of range
{"Upper boundary + 1", "172.32.0.0", false}, // Just outside
{"Upper boundary + 2", "172.32.0.1", false},
// Docker default bridge (172.17.0.0/16)
{"Docker default bridge start", "172.17.0.0", true},
{"Docker default bridge gateway", "172.17.0.1", true},
{"Docker default bridge host", "172.17.0.2", true},
{"Docker default bridge end", "172.17.255.255", true},
// Docker user-defined networks
{"User network 1", "172.18.0.1", true},
{"User network 2", "172.19.0.1", true},
{"User network 30", "172.30.0.1", true},
{"User network 31", "172.31.0.1", true},
// Edge of 172.x range
{"172.0.0.1", "172.0.0.1", false}, // Below range
{"172.15.0.1", "172.15.0.1", false}, // Below range
{"172.32.0.1", "172.32.0.1", false}, // Above range
{"172.255.255.255", "172.255.255.255", false}, // Above range
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsDockerBridgeIP(tt.host)
if result != tt.expected {
t.Errorf("IsDockerBridgeIP(%q) = %v, want %v", tt.host, result, tt.expected)
}
})
}
}

View File

@@ -1,133 +0,0 @@
package utils
import (
"errors"
"net/url"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"github.com/Wikid82/charon/backend/internal/models"
)
// GetConfiguredPublicURL returns the configured, normalized public URL.
//
// Security note:
// This function intentionally never derives URLs from request data (Host/X-Forwarded-*),
// so it is safe to use for embedding external links (e.g., invite emails).
func GetConfiguredPublicURL(db *gorm.DB) (string, bool) {
var setting models.Setting
if err := db.Where("key = ?", "app.public_url").First(&setting).Error; err != nil {
return "", false
}
normalized, err := normalizeConfiguredPublicURL(setting.Value)
if err != nil {
return "", false
}
return normalized, true
}
// GetPublicURL retrieves the configured public URL or falls back to request host.
// This should be used for all user-facing URLs (emails, invite links).
func GetPublicURL(db *gorm.DB, c *gin.Context) string {
var setting models.Setting
if err := db.Where("key = ?", "app.public_url").First(&setting).Error; err == nil {
if setting.Value != "" {
return strings.TrimSuffix(setting.Value, "/")
}
}
// Fallback to request-derived URL
return getBaseURL(c)
}
func normalizeConfiguredPublicURL(raw string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", errors.New("public URL is empty")
}
if strings.ContainsAny(raw, "\r\n") {
return "", errors.New("public URL contains invalid characters")
}
parsed, err := url.Parse(raw)
if err != nil {
return "", err
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", errors.New("public URL must use http or https")
}
if parsed.Host == "" {
return "", errors.New("public URL must include a host")
}
if parsed.User != nil {
return "", errors.New("public URL must not include userinfo")
}
if parsed.RawQuery != "" || parsed.Fragment != "" {
return "", errors.New("public URL must not include query or fragment")
}
if parsed.Path != "" && parsed.Path != "/" {
return "", errors.New("public URL must not include a path")
}
if parsed.Opaque != "" {
return "", errors.New("public URL must not be opaque")
}
normalized := (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host}).String()
return normalized, nil
}
// getBaseURL extracts the base URL from the request.
func getBaseURL(c *gin.Context) string {
scheme := "https"
if c.Request.TLS == nil {
// Check for X-Forwarded-Proto header
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
}
return scheme + "://" + c.Request.Host
}
// ValidateURL validates that a URL is properly formatted for use as an application URL.
// Returns error message if invalid, empty string if valid.
func ValidateURL(rawURL string) (normalized, warning string, err error) {
// Parse URL
parsed, parseErr := url.Parse(rawURL)
if parseErr != nil {
return "", "", parseErr
}
// Validate scheme
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", "", &url.Error{
Op: "parse",
URL: rawURL,
Err: nil,
}
}
// Warn if HTTP
if parsed.Scheme == "http" {
warning = "Using HTTP is not recommended. Consider using HTTPS for security."
}
// Reject URLs with path components beyond "/"
if parsed.Path != "" && parsed.Path != "/" {
return "", "", &url.Error{
Op: "validate",
URL: rawURL,
Err: nil,
}
}
// Normalize URL (remove trailing slash, keep scheme and host)
normalized = strings.TrimSuffix(rawURL, "/")
return normalized, warning, nil
}

View File

@@ -1,426 +0,0 @@
package utils
import (
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wikid82/charon/backend/internal/network"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockTransport is a custom http.RoundTripper for testing that bypasses network calls
type mockTransport struct {
statusCode int
headers http.Header
body string
err error
handler http.HandlerFunc // For dynamic responses
}
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if m.err != nil {
return nil, m.err
}
// Use handler if provided (for dynamic responses like redirects)
if m.handler != nil {
w := httptest.NewRecorder()
m.handler(w, req)
return w.Result(), nil
}
// Static response
resp := &http.Response{
StatusCode: m.statusCode,
Header: m.headers,
Body: io.NopCloser(strings.NewReader(m.body)),
Request: req,
}
return resp, nil
}
// TestTestURLConnectivity_Success verifies that valid public URLs are reachable
func TestTestURLConnectivity_Success(t *testing.T) {
transport := &mockTransport{
statusCode: http.StatusOK,
headers: http.Header{"Content-Type": []string{"text/html"}},
body: "",
}
reachable, latency, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
assert.NoError(t, err)
assert.True(t, reachable)
assert.Greater(t, latency, 0.0, "latency should be positive")
assert.Less(t, latency, 5000.0, "latency should be reasonable (< 5s)")
}
// TestTestURLConnectivity_Redirect verifies redirect handling
func TestTestURLConnectivity_Redirect(t *testing.T) {
redirectCount := 0
transport := &mockTransport{
handler: func(w http.ResponseWriter, r *http.Request) {
redirectCount++
// Only redirect once, then return OK
if redirectCount == 1 {
w.Header().Set("Location", "http://example.com/final")
w.WriteHeader(http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
},
}
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
assert.NoError(t, err)
assert.True(t, reachable)
assert.LessOrEqual(t, redirectCount, 3, "should follow max 2 redirects")
}
// TestTestURLConnectivity_TooManyRedirects verifies redirect limit enforcement
func TestTestURLConnectivity_TooManyRedirects(t *testing.T) {
transport := &mockTransport{
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", "/redirect")
w.WriteHeader(http.StatusFound)
},
}
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
assert.Error(t, err)
assert.False(t, reachable)
assert.Contains(t, err.Error(), "redirect", "error should mention redirects")
}
// TestTestURLConnectivity_StatusCodes verifies handling of different HTTP status codes
func TestTestURLConnectivity_StatusCodes(t *testing.T) {
testCases := []struct {
name string
statusCode int
expected bool
}{
{"200 OK", http.StatusOK, true},
{"201 Created", http.StatusCreated, true},
{"204 No Content", http.StatusNoContent, true},
{"301 Moved Permanently", http.StatusMovedPermanently, true},
{"302 Found", http.StatusFound, true},
{"400 Bad Request", http.StatusBadRequest, false},
{"401 Unauthorized", http.StatusUnauthorized, false},
{"403 Forbidden", http.StatusForbidden, false},
{"404 Not Found", http.StatusNotFound, false},
{"500 Internal Server Error", http.StatusInternalServerError, false},
{"503 Service Unavailable", http.StatusServiceUnavailable, false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
transport := &mockTransport{
statusCode: tc.statusCode,
}
reachable, latency, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
if tc.expected {
assert.NoError(t, err)
assert.True(t, reachable)
assert.Greater(t, latency, 0.0)
} else {
assert.Error(t, err)
assert.False(t, reachable)
assert.Contains(t, err.Error(), fmt.Sprintf("status %d", tc.statusCode))
}
})
}
}
// TestTestURLConnectivity_InvalidURL verifies invalid URL handling
func TestTestURLConnectivity_InvalidURL(t *testing.T) {
testCases := []struct {
name string
url string
}{
{"Empty URL", ""},
{"Invalid scheme", "ftp://example.com"},
{"Malformed URL", "http://[invalid"},
{"No scheme", "example.com"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// No transport needed - these fail at URL parsing
reachable, _, err := TestURLConnectivity(tc.url)
assert.Error(t, err)
assert.False(t, reachable)
// Latency varies depending on error type
// Some errors may still measure time before failing
})
}
}
// TestTestURLConnectivity_DNSFailure verifies DNS resolution error handling
func TestTestURLConnectivity_DNSFailure(t *testing.T) {
// Without transport, this will try real DNS and should fail
reachable, _, err := TestURLConnectivity("http://nonexistent-domain-12345.invalid")
assert.Error(t, err)
assert.False(t, reachable)
assert.Contains(t, err.Error(), "DNS resolution failed", "error should mention DNS failure")
}
// TestTestURLConnectivity_Timeout verifies timeout enforcement
func TestTestURLConnectivity_Timeout(t *testing.T) {
transport := &mockTransport{
err: fmt.Errorf("context deadline exceeded"),
}
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
assert.Error(t, err)
assert.False(t, reachable)
assert.Contains(t, err.Error(), "connection failed", "error should mention connection failure")
}
// TestIsPrivateIP_PrivateIPv4Ranges verifies blocking of private IPv4 ranges
func TestIsPrivateIP_PrivateIPv4Ranges(t *testing.T) {
testCases := []struct {
name string
ip string
expected bool
}{
// RFC 1918 Private Networks
{"10.0.0.0/8 start", "10.0.0.1", true},
{"10.0.0.0/8 mid", "10.128.0.1", true},
{"10.0.0.0/8 end", "10.255.255.254", true},
{"172.16.0.0/12 start", "172.16.0.1", true},
{"172.16.0.0/12 mid", "172.20.0.1", true},
{"172.16.0.0/12 end", "172.31.255.254", true},
{"192.168.0.0/16 start", "192.168.0.1", true},
{"192.168.0.0/16 end", "192.168.255.254", true},
// Loopback
{"127.0.0.1 localhost", "127.0.0.1", true},
{"127.0.0.0/8 start", "127.0.0.0", true},
{"127.0.0.0/8 end", "127.255.255.255", true},
// Link-Local (includes AWS/GCP metadata)
{"169.254.0.0/16 start", "169.254.0.1", true},
{"169.254.169.254 AWS metadata", "169.254.169.254", true},
{"169.254.0.0/16 end", "169.254.255.254", true},
// Reserved ranges
{"0.0.0.0/8", "0.0.0.1", true},
{"240.0.0.0/4", "240.0.0.1", true},
{"255.255.255.255 broadcast", "255.255.255.255", true},
// Public IPs (should NOT be blocked)
{"8.8.8.8 Google DNS", "8.8.8.8", false},
{"1.1.1.1 Cloudflare DNS", "1.1.1.1", false},
{"93.184.216.34 example.com", "93.184.216.34", false},
{"151.101.1.140 GitHub", "151.101.1.140", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ip := net.ParseIP(tc.ip)
require.NotNil(t, ip, "IP should parse successfully")
result := network.IsPrivateIP(ip)
assert.Equal(t, tc.expected, result,
"IP %s should be private=%v", tc.ip, tc.expected)
})
}
}
// TestIsPrivateIP_PrivateIPv6Ranges verifies blocking of private IPv6 ranges
func TestIsPrivateIP_PrivateIPv6Ranges(t *testing.T) {
testCases := []struct {
name string
ip string
expected bool
}{
// IPv6 Loopback
{"::1 loopback", "::1", true},
// IPv6 Link-Local
{"fe80::/10 start", "fe80::1", true},
{"fe80::/10 mid", "fe80:1234::5678", true},
// IPv6 Unique Local (RFC 4193)
{"fc00::/7 start", "fc00::1", true},
{"fc00::/7 mid", "fd12:3456:789a::1", true},
// Public IPv6 (should NOT be blocked)
{"2001:4860:4860::8888 Google DNS", "2001:4860:4860::8888", false},
{"2606:4700:4700::1111 Cloudflare DNS", "2606:4700:4700::1111", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ip := net.ParseIP(tc.ip)
require.NotNil(t, ip, "IP should parse successfully")
result := network.IsPrivateIP(ip)
assert.Equal(t, tc.expected, result,
"IP %s should be private=%v", tc.ip, tc.expected)
})
}
}
// TestTestURLConnectivity_PrivateIP_Blocked verifies SSRF protection
func TestTestURLConnectivity_PrivateIP_Blocked(t *testing.T) {
// Note: This test will fail if run on a system that actually resolves
// these hostnames to private IPs. In a production test environment,
// you might want to mock DNS resolution.
testCases := []struct {
name string
url string
}{
{"localhost", "http://localhost"},
{"127.0.0.1", "http://127.0.0.1"},
{"Private IP 10.x", "http://10.0.0.1"},
{"Private IP 192.168.x", "http://192.168.1.1"},
{"AWS metadata", "http://169.254.169.254"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
reachable, _, err := TestURLConnectivity(tc.url)
// Should fail with private IP error
assert.Error(t, err)
assert.False(t, reachable)
assert.Contains(t, err.Error(), "private IP", "error should mention private IP blocking")
})
}
}
// TestTestURLConnectivity_SSRF_Protection_Comprehensive performs comprehensive SSRF tests
func TestTestURLConnectivity_SSRF_Protection_Comprehensive(t *testing.T) {
if testing.Short() {
t.Skip("Skipping comprehensive SSRF test in short mode")
}
// Test various SSRF attack vectors
attackVectors := []string{
"http://localhost:8080",
"http://127.0.0.1:8080",
"http://0.0.0.0:8080",
"http://[::1]:8080",
"http://169.254.169.254/latest/meta-data/",
"http://metadata.google.internal/computeMetadata/v1/",
}
for _, url := range attackVectors {
t.Run(url, func(t *testing.T) {
reachable, _, err := TestURLConnectivity(url)
// All should be blocked
assert.Error(t, err, "SSRF attack vector should be blocked")
assert.False(t, reachable)
})
}
}
// TestTestURLConnectivity_HTTPSSupport verifies HTTPS support
func TestTestURLConnectivity_HTTPSSupport(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Note: This will likely fail due to self-signed cert in test server
// but it demonstrates HTTPS support
reachable, _, err := TestURLConnectivity(server.URL)
// May fail due to cert validation, but should not panic
if err != nil {
t.Logf("HTTPS test failed (expected with self-signed cert): %v", err)
} else {
assert.True(t, reachable)
}
}
// BenchmarkTestURLConnectivity benchmarks the connectivity test
func BenchmarkTestURLConnectivity(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = TestURLConnectivity(server.URL)
}
}
// BenchmarkIsPrivateIP benchmarks private IP checking
func BenchmarkIsPrivateIP(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = network.IsPrivateIP(ip)
}
}
// TestTestURLConnectivity_RedirectLimit_ProductionPath verifies the production
// CheckRedirect callback enforces a maximum of 2 redirects (lines 93-97).
// This is a critical security feature to prevent redirect-based attacks.
func TestTestURLConnectivity_RedirectLimit_ProductionPath(t *testing.T) {
redirectCount := 0
// Use mock transport to bypass SSRF protection and test redirect limit specifically
transport := &mockTransport{
handler: func(w http.ResponseWriter, r *http.Request) {
redirectCount++
if redirectCount <= 3 { // Try to redirect 3 times
http.Redirect(w, r, "http://example.com/next", http.StatusFound)
} else {
w.WriteHeader(http.StatusOK)
}
},
}
// Test with transport (will use CheckRedirect callback from production path)
reachable, latency, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
// Should fail due to redirect limit
assert.False(t, reachable)
assert.Error(t, err)
assert.Contains(t, err.Error(), "redirect", "error should mention redirects")
assert.Greater(t, latency, 0.0, "should have some latency")
}
// TestTestURLConnectivity_InvalidPortFormat tests error when URL has invalid port format.
// This would trigger errors in net.SplitHostPort during dialing (lines 19-21).
func TestTestURLConnectivity_InvalidPortFormat(t *testing.T) {
// URL with invalid port will fail at URL parsing stage
reachable, _, err := TestURLConnectivity("http://example.com:badport")
assert.False(t, reachable)
assert.Error(t, err)
// URL parsing will catch the invalid port before we even get to dialing
assert.Contains(t, err.Error(), "invalid port")
}
// TestTestURLConnectivity_EmptyDNSResult tests the empty DNS results
// error path (lines 29-31).
func TestTestURLConnectivity_EmptyDNSResult(t *testing.T) {
// Create a custom transport that simulates empty DNS result
transport := &mockTransport{
err: fmt.Errorf("DNS resolution failed: no IP addresses found for host"),
}
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
assert.False(t, reachable)
assert.Error(t, err)
assert.Contains(t, err.Error(), "connection failed")
}

View File

@@ -1,478 +0,0 @@
package utils
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/Wikid82/charon/backend/internal/models"
)
// setupTestDB creates an in-memory SQLite database for testing
func setupTestDB(t *testing.T) *gorm.DB {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(t, err, "failed to connect to test database")
// Auto-migrate the Setting model
err = db.AutoMigrate(&models.Setting{})
require.NoError(t, err, "failed to migrate database")
return db
}
// TestGetPublicURL_WithConfiguredURL verifies retrieval of configured public URL
func TestGetPublicURL_WithConfiguredURL(t *testing.T) {
db := setupTestDB(t)
// Insert a configured public URL
setting := models.Setting{
Key: "app.public_url",
Value: "https://example.com/",
}
err := db.Create(&setting).Error
require.NoError(t, err)
// Create test Gin context
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/test", http.NoBody)
c.Request = req
// Test GetPublicURL
publicURL := GetPublicURL(db, c)
// Should return configured URL with trailing slash removed
assert.Equal(t, "https://example.com", publicURL)
}
// TestGetPublicURL_WithTrailingSlash verifies trailing slash removal
func TestGetPublicURL_WithTrailingSlash(t *testing.T) {
db := setupTestDB(t)
// Insert URL with multiple trailing slashes
setting := models.Setting{
Key: "app.public_url",
Value: "https://example.com///",
}
err := db.Create(&setting).Error
require.NoError(t, err)
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/test", http.NoBody)
c.Request = req
publicURL := GetPublicURL(db, c)
// Should remove only the trailing slash (TrimSuffix removes one slash)
assert.Equal(t, "https://example.com//", publicURL)
}
// TestGetPublicURL_Fallback_HTTPSWithTLS verifies fallback to request URL with TLS
func TestGetPublicURL_Fallback_HTTPSWithTLS(t *testing.T) {
db := setupTestDB(t)
// No setting in DB - should fallback
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request with TLS
req := httptest.NewRequest(http.MethodGet, "https://myapp.com:8443/path", http.NoBody)
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
c.Request = req
publicURL := GetPublicURL(db, c)
// Should detect TLS and use https
assert.Equal(t, "https://myapp.com:8443", publicURL)
}
// TestGetPublicURL_Fallback_HTTP verifies fallback to HTTP when no TLS
func TestGetPublicURL_Fallback_HTTP(t *testing.T) {
db := setupTestDB(t)
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/test", http.NoBody)
c.Request = req
publicURL := GetPublicURL(db, c)
// Should use http scheme when no TLS
assert.Equal(t, "http://localhost:8080", publicURL)
}
// TestGetPublicURL_Fallback_XForwardedProto verifies X-Forwarded-Proto header handling
func TestGetPublicURL_Fallback_XForwardedProto(t *testing.T) {
db := setupTestDB(t)
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "http://internal-server:8080/test", http.NoBody)
req.Header.Set("X-Forwarded-Proto", "https")
c.Request = req
publicURL := GetPublicURL(db, c)
// Should respect X-Forwarded-Proto header
assert.Equal(t, "https://internal-server:8080", publicURL)
}
// TestGetPublicURL_EmptyValue verifies behavior with empty setting value
func TestGetPublicURL_EmptyValue(t *testing.T) {
db := setupTestDB(t)
// Insert setting with empty value
setting := models.Setting{
Key: "app.public_url",
Value: "",
}
err := db.Create(&setting).Error
require.NoError(t, err)
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "http://localhost:9000/test", http.NoBody)
c.Request = req
publicURL := GetPublicURL(db, c)
// Should fallback to request URL when value is empty
assert.Equal(t, "http://localhost:9000", publicURL)
}
// TestGetPublicURL_NoSettingInDB verifies behavior when setting doesn't exist
func TestGetPublicURL_NoSettingInDB(t *testing.T) {
db := setupTestDB(t)
// No setting created - should fallback
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "http://fallback-host.com/test", http.NoBody)
c.Request = req
publicURL := GetPublicURL(db, c)
// Should fallback to request host
assert.Equal(t, "http://fallback-host.com", publicURL)
}
// TestValidateURL_ValidHTTPS verifies validation of valid HTTPS URLs
func TestValidateURL_ValidHTTPS(t *testing.T) {
testCases := []struct {
name string
url string
normalized string
}{
{"HTTPS with trailing slash", "https://example.com/", "https://example.com"},
{"HTTPS without path", "https://example.com", "https://example.com"},
{"HTTPS with port", "https://example.com:8443", "https://example.com:8443"},
{"HTTPS with subdomain", "https://app.example.com", "https://app.example.com"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
normalized, warning, err := ValidateURL(tc.url)
assert.NoError(t, err)
assert.Equal(t, tc.normalized, normalized)
assert.Empty(t, warning, "HTTPS should not produce warning")
})
}
}
// TestValidateURL_ValidHTTP verifies validation of HTTP URLs with warning
func TestValidateURL_ValidHTTP(t *testing.T) {
testCases := []struct {
name string
url string
normalized string
}{
{"HTTP with trailing slash", "http://example.com/", "http://example.com"},
{"HTTP without path", "http://example.com", "http://example.com"},
{"HTTP with port", "http://example.com:8080", "http://example.com:8080"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
normalized, warning, err := ValidateURL(tc.url)
assert.NoError(t, err)
assert.Equal(t, tc.normalized, normalized)
assert.NotEmpty(t, warning, "HTTP should produce security warning")
assert.Contains(t, warning, "HTTP", "warning should mention HTTP")
assert.Contains(t, warning, "HTTPS", "warning should suggest HTTPS")
})
}
}
// TestValidateURL_InvalidScheme verifies rejection of non-HTTP/HTTPS schemes
func TestValidateURL_InvalidScheme(t *testing.T) {
testCases := []string{
"ftp://example.com",
"file:///etc/passwd",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>",
"ssh://user@host",
}
for _, url := range testCases {
t.Run(url, func(t *testing.T) {
_, _, err := ValidateURL(url)
assert.Error(t, err, "non-HTTP(S) scheme should be rejected")
})
}
}
// TestValidateURL_WithPath verifies rejection of URLs with paths
func TestValidateURL_WithPath(t *testing.T) {
testCases := []string{
"https://example.com/api/v1",
"https://example.com/admin",
"http://example.com/path/to/resource",
"https://example.com/index.html",
}
for _, url := range testCases {
t.Run(url, func(t *testing.T) {
_, _, err := ValidateURL(url)
assert.Error(t, err, "URL with path should be rejected")
})
}
}
// TestValidateURL_RootPathAllowed verifies "/" path is allowed
func TestValidateURL_RootPathAllowed(t *testing.T) {
testCases := []string{
"https://example.com/",
"http://example.com/",
}
for _, url := range testCases {
t.Run(url, func(t *testing.T) {
normalized, _, err := ValidateURL(url)
assert.NoError(t, err, "root path '/' should be allowed")
// Trailing slash should be removed
assert.NotContains(t, normalized[len(normalized)-1:], "/", "normalized URL should not end with slash")
})
}
}
// TestValidateURL_MalformedURL verifies handling of malformed URLs
func TestValidateURL_MalformedURL(t *testing.T) {
testCases := []struct {
url string
shouldFail bool
}{
{"not a url", true},
{"://missing-scheme", true},
{"http://", false}, // Valid URL with empty host - Parse accepts it
{"https://[invalid", true},
{"", true},
}
for _, tc := range testCases {
t.Run(tc.url, func(t *testing.T) {
_, _, err := ValidateURL(tc.url)
if tc.shouldFail {
assert.Error(t, err, "malformed URL should be rejected")
} else {
// Some URLs that look malformed are actually valid per RFC
assert.NoError(t, err)
}
})
}
}
// TestValidateURL_SpecialCharacters verifies handling of special characters
func TestValidateURL_SpecialCharacters(t *testing.T) {
testCases := []struct {
name string
url string
isValid bool
}{
{"Punycode domain", "https://xn--e1afmkfd.xn--p1ai", true},
{"Port with special chars", "https://example.com:8080", true},
{"Query string (no path component)", "https://example.com?query=1", true}, // Query strings have empty Path
{"Fragment (no path component)", "https://example.com#section", true}, // Fragments have empty Path
{"Userinfo", "https://user:pass@example.com", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, _, err := ValidateURL(tc.url)
if tc.isValid {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
// TestValidateURL_Normalization verifies URL normalization
func TestValidateURL_Normalization(t *testing.T) {
testCases := []struct {
input string
expected string
shouldFail bool
}{
{"https://EXAMPLE.COM", "https://EXAMPLE.COM", false}, // Case preserved
{"https://example.com/", "https://example.com", false}, // Trailing slash removed
{"https://example.com///", "", true}, // Multiple slashes = path component, should fail
{"http://example.com:80", "http://example.com:80", false}, // Port preserved
{"https://example.com:443", "https://example.com:443", false}, // Default HTTPS port preserved
}
for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
normalized, _, err := ValidateURL(tc.input)
if tc.shouldFail {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expected, normalized)
}
})
}
}
// TestGetBaseURL verifies base URL extraction from request
func TestGetBaseURL(t *testing.T) {
testCases := []struct {
name string
host string
hasTLS bool
xForwardedProto string
expected string
}{
{
name: "HTTPS with TLS",
host: "secure.example.com",
hasTLS: true,
expected: "https://secure.example.com",
},
{
name: "HTTP without TLS",
host: "insecure.example.com",
hasTLS: false,
expected: "http://insecure.example.com",
},
{
name: "X-Forwarded-Proto HTTPS",
host: "behind-proxy.com",
hasTLS: false,
xForwardedProto: "https",
expected: "https://behind-proxy.com",
},
{
name: "X-Forwarded-Proto HTTP",
host: "behind-proxy.com",
hasTLS: false,
xForwardedProto: "http",
expected: "http://behind-proxy.com",
},
{
name: "With port",
host: "example.com:8080",
hasTLS: false,
expected: "http://example.com:8080",
},
{
name: "IPv4 host",
host: "192.168.1.1:8080",
hasTLS: false,
expected: "http://192.168.1.1:8080",
},
{
name: "IPv6 host",
host: "[::1]:8080",
hasTLS: false,
expected: "http://[::1]:8080",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Build request URL
scheme := "http"
if tc.hasTLS {
scheme = "https"
}
req := httptest.NewRequest(http.MethodGet, scheme+"://"+tc.host+"/test", http.NoBody)
// Set TLS if needed
if tc.hasTLS {
req.TLS = &tls.ConnectionState{}
}
// Set X-Forwarded-Proto if specified
if tc.xForwardedProto != "" {
req.Header.Set("X-Forwarded-Proto", tc.xForwardedProto)
}
c.Request = req
baseURL := getBaseURL(c)
assert.Equal(t, tc.expected, baseURL)
})
}
}
// TestGetBaseURL_PrecedenceOrder verifies header precedence
func TestGetBaseURL_PrecedenceOrder(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Request with TLS but also X-Forwarded-Proto
req := httptest.NewRequest(http.MethodGet, "https://example.com/test", http.NoBody)
req.TLS = &tls.ConnectionState{}
req.Header.Set("X-Forwarded-Proto", "http") // Should be ignored when TLS is present
c.Request = req
baseURL := getBaseURL(c)
// TLS should take precedence over header
assert.Equal(t, "https://example.com", baseURL)
}
// TestGetBaseURL_EmptyHost verifies behavior with empty host
func TestGetBaseURL_EmptyHost(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "http:///test", http.NoBody)
req.Host = "" // Empty host
c.Request = req
baseURL := getBaseURL(c)
// Should still return valid URL with empty host
assert.Equal(t, "http://", baseURL)
}

View File

@@ -1,443 +0,0 @@
package utils
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/Wikid82/charon/backend/internal/metrics"
"github.com/Wikid82/charon/backend/internal/network"
"github.com/Wikid82/charon/backend/internal/security"
)
func resolveAllowedIP(ctx context.Context, host string, allowLocalhost bool) (net.IP, error) {
if host == "" {
return nil, fmt.Errorf("missing hostname")
}
// Fast-path: IP literal.
if ip := net.ParseIP(host); ip != nil {
if allowLocalhost && ip.IsLoopback() {
return ip, nil
}
if network.IsPrivateIP(ip) {
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip)
}
return ip, nil
}
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("DNS resolution failed: %w", err)
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IP addresses found for host")
}
var selected net.IP
for _, ip := range ips {
if allowLocalhost && ip.IP.IsLoopback() {
if selected == nil {
selected = ip.IP
}
continue
}
if network.IsPrivateIP(ip.IP) {
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP)
}
if selected == nil {
selected = ip.IP
}
}
if selected == nil {
return nil, fmt.Errorf("no allowed IP addresses found for host")
}
return selected, nil
}
// ssrfSafeDialer creates a custom dialer that validates IP addresses at connection time.
// This prevents DNS rebinding attacks by validating the IP just before connecting.
// Returns a DialContext function suitable for use in http.Transport.
func ssrfSafeDialer() func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, netw, addr string) (net.Conn, error) {
// Parse host and port from address
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("invalid address format: %w", err)
}
// Resolve DNS with context timeout
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("DNS resolution failed: %w", err)
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IP addresses found for host")
}
// Validate ALL resolved IPs - if any are private, reject immediately
// Using centralized network.IsPrivateIP for consistent SSRF protection
for _, ip := range ips {
if network.IsPrivateIP(ip.IP) {
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP)
}
}
// Connect to the first valid IP (prevents DNS rebinding)
dialer := &net.Dialer{
Timeout: 5 * time.Second,
}
return dialer.DialContext(ctx, netw, net.JoinHostPort(ips[0].IP.String(), port))
}
}
// NOTE: Redirect validation is implemented by validateRedirectTargetStrict.
// TestURLConnectivity performs a server-side connectivity test with SSRF protection.
// For testing purposes, an optional http.RoundTripper can be provided to bypass
// DNS resolution and network calls.
// Returns:
// - reachable: true if URL returned 2xx-3xx status
// - latency: round-trip time in milliseconds
// - error: validation or connectivity error
func TestURLConnectivity(rawURL string) (reachable bool, latency float64, err error) {
// NOTE: This wrapper preserves the exported API while enforcing
// deny-by-default SSRF-safe behavior.
//
// Do not add optional transports/options to the exported surface.
// Tests can exercise alternative paths via unexported helpers.
return testURLConnectivity(rawURL)
}
type urlConnectivityOptions struct {
transport http.RoundTripper
allowLocalhost bool
}
type urlConnectivityOption func(*urlConnectivityOptions)
//nolint:unused // Used in test files
func withTransportForTesting(rt http.RoundTripper) urlConnectivityOption {
return func(o *urlConnectivityOptions) {
o.transport = rt
}
}
//nolint:unused // Used in test files
func withAllowLocalhostForTesting() urlConnectivityOption {
return func(o *urlConnectivityOptions) {
o.allowLocalhost = true
}
}
// testURLConnectivity is the implementation behind TestURLConnectivity.
// It supports internal (package-only) hooks to keep unit tests deterministic
// without weakening production defaults.
func testURLConnectivity(rawURL string, opts ...urlConnectivityOption) (reachable bool, latency float64, err error) {
// Track start time for metrics
startTime := time.Now()
// Generate unique request ID for tracing
requestID := fmt.Sprintf("test-%d", time.Now().UnixNano())
options := urlConnectivityOptions{}
for _, opt := range opts {
opt(&options)
}
isTestMode := options.transport != nil
// Parse URL first to validate structure
parsed, err := url.Parse(rawURL)
if err != nil {
// ENHANCEMENT: Record validation failure metric
metrics.RecordURLValidation("error", "invalid_format")
// ENHANCEMENT: Audit log the failed validation
if !isTestMode {
security.LogURLTest(rawURL, requestID, "system", "", "error")
}
return false, 0, fmt.Errorf("invalid URL: %w", err)
}
// Validate scheme
if parsed.Scheme != "http" && parsed.Scheme != "https" {
// ENHANCEMENT: Record validation failure metric
metrics.RecordURLValidation("error", "unsupported_scheme")
// ENHANCEMENT: Audit log the failed validation
if !isTestMode {
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
}
return false, 0, fmt.Errorf("only http and https schemes are allowed")
}
// Reject URLs containing userinfo (username:password@host)
// This is checked again by security.ValidateExternalURL, but we keep it here
// to ensure consistent behavior across all call paths.
if parsed.User != nil {
metrics.RecordURLValidation("error", "userinfo_not_allowed")
if !isTestMode {
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
}
return false, 0, fmt.Errorf("urls with embedded credentials are not allowed")
}
// CRITICAL: Always validate the URL through centralized SSRF protection.
//
// Production defaults are deny-by-default:
// - Only http/https allowed
// - No localhost
// - No private/reserved IPs
//
// Tests may opt-in to localhost via withAllowLocalhostForTesting(), without
// weakening production behavior.
validationOpts := []security.ValidationOption{
security.WithAllowHTTP(),
}
if options.allowLocalhost {
validationOpts = append(validationOpts, security.WithAllowLocalhost())
}
validatedURL, err := security.ValidateExternalURL(rawURL, validationOpts...)
if err != nil {
// ENHANCEMENT: Record SSRF block metrics
// Determine the block reason from error message
errMsg := err.Error()
var blockReason string
switch {
case strings.Contains(errMsg, "private ip"):
blockReason = "private_ip"
metrics.RecordSSRFBlock("private", "system") // userID should come from context in production
// ENHANCEMENT: Audit log the SSRF block
security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "")
case strings.Contains(errMsg, "cloud metadata"):
blockReason = "metadata_endpoint"
metrics.RecordSSRFBlock("metadata", "system")
// ENHANCEMENT: Audit log the SSRF block
security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "")
case strings.Contains(errMsg, "dns resolution"):
blockReason = "dns_failed"
// ENHANCEMENT: Audit log the DNS failure
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
default:
blockReason = "validation_failed"
// ENHANCEMENT: Audit log the validation failure
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "blocked")
}
metrics.RecordURLValidation("blocked", blockReason)
// Transform error message for backward compatibility with existing tests
// The security package uses lowercase in error messages, but tests expect mixed case
errMsg = strings.Replace(errMsg, "dns resolution failed", "DNS resolution failed", 1)
errMsg = strings.ReplaceAll(errMsg, "private ip", "private IP")
// Cloud metadata endpoints are considered private IPs for test compatibility
if strings.Contains(errMsg, "cloud metadata endpoints") {
errMsg = strings.Replace(errMsg, "access to cloud metadata endpoints is blocked for security", "connection to private IP addresses is blocked for security", 1)
}
return false, 0, fmt.Errorf("security validation failed: %s", errMsg)
}
// ENHANCEMENT: Record successful validation
metrics.RecordURLValidation("allowed", "validated")
// ENHANCEMENT: Audit log successful validation (only in production to avoid test noise)
if !isTestMode {
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "allowed")
}
// Use validated URL for requests (breaks taint chain)
validatedRequestURL := validatedURL
const (
requestTimeout = 5 * time.Second
maxRedirects = 2
allowHTTPSUpgrade = true
)
transport := &http.Transport{
// Explicitly ignore proxy environment variables for SSRF-sensitive requests.
Proxy: nil,
DialContext: ssrfSafeDialer(),
MaxIdleConns: 1,
IdleConnTimeout: requestTimeout,
TLSHandshakeTimeout: requestTimeout,
ResponseHeaderTimeout: requestTimeout,
DisableKeepAlives: true,
}
if isTestMode {
// Test-only override: allows deterministic unit tests without real network.
transport = &http.Transport{
Proxy: nil,
DisableKeepAlives: true,
}
// If the provided transport is an http.RoundTripper that is not an *http.Transport,
// use it directly.
}
var rt http.RoundTripper = transport
if isTestMode {
rt = options.transport
}
client := &http.Client{
Timeout: requestTimeout,
Transport: rt,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return validateRedirectTargetStrict(req, via, maxRedirects, allowHTTPSUpgrade, options.allowLocalhost)
},
}
// Perform HTTP HEAD request with strict timeout
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel()
start := time.Now()
// Parse the validated URL to construct request from validated components
// This breaks the taint chain for static analysis by using parsed URL components
validatedParsed, err := url.Parse(validatedRequestURL)
if err != nil {
return false, 0, fmt.Errorf("failed to parse validated URL: %w", err)
}
// Normalize scheme to a constant value derived from an allowlisted set.
// This avoids propagating the original input string into request construction.
var safeScheme string
switch validatedParsed.Scheme {
case "http":
safeScheme = "http"
case "https":
safeScheme = "https"
default:
return false, 0, fmt.Errorf("security validation failed: unsupported scheme")
}
// If we connect to an IP-literal for HTTPS, ensure TLS SNI still uses the hostname.
if !isTestMode && safeScheme == "https" {
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: validatedParsed.Hostname(),
}
}
// Resolve to a concrete, allowed IP for the outbound request URL.
// We still preserve the original hostname via Host header and TLS SNI.
selectedIP, err := resolveAllowedIP(ctx, validatedParsed.Hostname(), options.allowLocalhost)
if err != nil {
return false, 0, fmt.Errorf("security validation failed: %s", err.Error())
}
port := validatedParsed.Port()
if port == "" {
if safeScheme == "https" {
port = "443"
} else {
port = "80"
}
} else {
p, convErr := strconv.Atoi(port)
if convErr != nil || p < 1 || p > 65535 {
return false, 0, fmt.Errorf("security validation failed: invalid port")
}
port = strconv.Itoa(p)
}
// Construct a request URL from SSRF-safe components.
// - Host is a resolved IP (selectedIP) to avoid hostname-based SSRF bypass
// - Path is fixed to "/" because this is a connectivity test
// nosemgrep: go.lang.security.audit.net.use-tls.use-tls
safeURL := &url.URL{
Scheme: safeScheme,
Host: net.JoinHostPort(selectedIP.String(), port),
Path: "/",
}
req, err := http.NewRequestWithContext(ctx, http.MethodHead, safeURL.String(), http.NoBody)
if err != nil {
return false, 0, fmt.Errorf("failed to create request: %w", err)
}
// Preserve the original hostname in Host header for virtual-hosting.
// This does not affect the destination IP (selectedIP), which is used in the URL.
req.Host = validatedParsed.Host
// Add custom User-Agent header
req.Header.Set("User-Agent", "Charon-Health-Check/1.0")
// ENHANCEMENT: Request Tracing Headers
// These headers help track and identify URL test requests in logs
req.Header.Set("X-Charon-Request-Type", "url-connectivity-test")
req.Header.Set("X-Request-ID", requestID) // Use consistent request ID for tracing
// SSRF Protection Summary:
// This HTTP request is protected against SSRF by multiple defense layers:
// 1. security.ValidateExternalURL() validates URL format, scheme, and performs
// DNS resolution with private IP blocking (RFC 1918, loopback, link-local, metadata)
// 2. ssrfSafeDialer() re-validates IPs at connection time (prevents DNS rebinding/TOCTOU)
// 3. validateRedirectTarget() validates all redirect URLs in production
// 4. safeURL is constructed from parsed/validated components (breaks taint chain)
// See: internal/security/url_validator.go, internal/network/safeclient.go
resp, err := client.Do(req) //nolint:bodyclose // Body closed via defer below
latency = time.Since(start).Seconds() * 1000 // Convert to milliseconds
// ENHANCEMENT: Record test duration metric (only in production to avoid test noise)
if !isTestMode {
durationSeconds := time.Since(startTime).Seconds()
metrics.RecordURLTestDuration(durationSeconds)
}
if err != nil {
return false, latency, fmt.Errorf("connection failed: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
logger.Log().WithError(err).Warn("Failed to close response body")
}
}()
// Accept 2xx and 3xx status codes as "reachable"
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
return true, latency, nil
}
return false, latency, fmt.Errorf("server returned status %d", resp.StatusCode)
}
// validateRedirectTargetStrict validates HTTP redirects with SSRF protection.
// It enforces:
// - a hard redirect limit
// - per-hop URL validation (scheme, userinfo, host, DNS, private/reserved IPs)
// - scheme-change policy (deny by default; optionally allow http->https upgrade)
func validateRedirectTargetStrict(req *http.Request, via []*http.Request, maxRedirects int, allowHTTPSUpgrade bool, allowLocalhost bool) error {
if len(via) >= maxRedirects {
return fmt.Errorf("too many redirects (max %d)", maxRedirects)
}
if len(via) > 0 {
prevScheme := via[len(via)-1].URL.Scheme
newScheme := req.URL.Scheme
if newScheme != prevScheme {
if !allowHTTPSUpgrade || prevScheme != "http" || newScheme != "https" {
return fmt.Errorf("redirect scheme change blocked: %s -> %s", prevScheme, newScheme)
}
}
}
validationOpts := []security.ValidationOption{security.WithAllowHTTP(), security.WithTimeout(3 * time.Second)}
if allowLocalhost {
validationOpts = append(validationOpts, security.WithAllowLocalhost())
}
_, err := security.ValidateExternalURL(req.URL.String(), validationOpts...)
if err != nil {
return fmt.Errorf("redirect target validation failed: %w", err)
}
return nil
}

View File

@@ -1,489 +0,0 @@
package utils
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestTestURLConnectivity_EnhancedSSRF tests additional SSRF attack vectors.
// ENHANCEMENT: Comprehensive SSRF testing per Supervisor review
func TestTestURLConnectivity_EnhancedSSRF(t *testing.T) {
tests := []struct {
name string
url string
shouldFail bool
errContains string
}{
// Cloud Metadata Endpoints
{
name: "AWS metadata endpoint",
url: "http://169.254.169.254/latest/meta-data/",
shouldFail: true,
errContains: "private IP",
},
{
name: "GCP metadata endpoint",
url: "http://metadata.google.internal/computeMetadata/v1/",
shouldFail: true,
errContains: "DNS resolution failed", // Will fail to resolve
},
{
name: "Azure metadata endpoint",
url: "http://169.254.169.254/metadata/instance",
shouldFail: true,
errContains: "private IP",
},
// RFC 1918 Private Networks
{
name: "Private 10.0.0.0/8",
url: "http://10.0.0.1/admin",
shouldFail: true,
errContains: "private IP",
},
{
name: "Private 172.16.0.0/12",
url: "http://172.16.0.1/admin",
shouldFail: true,
errContains: "private IP",
},
{
name: "Private 192.168.0.0/16",
url: "http://192.168.1.1/admin",
shouldFail: true,
errContains: "private IP",
},
// Loopback Addresses
{
name: "IPv4 loopback",
url: "http://127.0.0.1:6379/",
shouldFail: true,
errContains: "private IP",
},
{
name: "IPv6 loopback",
url: "http://[::1]:6379/",
shouldFail: true,
errContains: "private IP",
},
// Link-Local Addresses
{
name: "Link-local IPv4",
url: "http://169.254.1.1/",
shouldFail: true,
errContains: "private IP",
},
{
name: "Link-local IPv6",
url: "http://[fe80::1]/",
shouldFail: true,
errContains: "private IP",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reachable, _, err := TestURLConnectivity(tt.url)
if tt.shouldFail {
if err == nil {
t.Errorf("Expected test to fail, but it succeeded")
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing '%s', got: %s", tt.errContains, err.Error())
}
if reachable {
t.Errorf("Expected reachable=false, got true")
}
} else if err != nil {
t.Errorf("Expected test to succeed, but got error: %s", err.Error())
}
})
}
}
// TestTestURLConnectivity_RedirectValidation tests that redirects are properly validated.
// ENHANCEMENT: Critical test per Supervisor review - all redirects must be validated
func TestTestURLConnectivity_RedirectValidation(t *testing.T) {
// Create test servers
privateServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Private server"))
}))
defer privateServer.Close()
t.Run("Redirect to private IP should be blocked", func(t *testing.T) {
// Note: This test requires actual redirect validation to work
// The validateRedirectTarget function will validate the Location header
// For now, we skip this test as it requires complex mock setup
t.Skip("Redirect validation test requires complex HTTP client mocking")
})
t.Run("Too many redirects should be blocked", func(t *testing.T) {
// Create a server that redirects to itself multiple times
redirectCount := 0
var redirectServerURL string
redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
if redirectCount < 5 {
http.Redirect(w, r, redirectServerURL+fmt.Sprintf("/%d", redirectCount), http.StatusFound)
} else {
w.WriteHeader(http.StatusOK)
}
}))
defer redirectServer.Close()
redirectServerURL = redirectServer.URL
transport := redirectServer.Client().Transport
reachable, _, err := testURLConnectivity(redirectServerURL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
// Should fail due to too many redirects (max 2)
if err == nil {
t.Errorf("Expected too many redirects to fail, but it succeeded")
}
if !strings.Contains(err.Error(), "redirect") {
t.Errorf("Expected error about redirects, got: %s", err.Error())
}
if reachable {
t.Errorf("Expected reachable=false, got true")
}
})
}
// TestTestURLConnectivity_UnicodeHomograph tests Unicode homograph attack prevention.
// ENHANCEMENT: Tests for internationalized domain name attacks
func TestTestURLConnectivity_UnicodeHomograph(t *testing.T) {
tests := []struct {
name string
url string
}{
{
name: "Cyrillic homograph",
url: "https://gооgle.com", // Uses Cyrillic 'о' instead of Latin 'o'
},
{
name: "Mixed script attack",
url: "https://раypal.com", // Uses Cyrillic 'а' and 'y'
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// These should fail DNS resolution since they're not real domains
reachable, _, err := TestURLConnectivity(tt.url)
if err == nil {
t.Logf("Warning: homograph domain %s resolved - may indicate IDN issue", tt.url)
}
if reachable {
t.Errorf("Homograph domain %s appears reachable - security concern", tt.url)
}
})
}
}
// TestTestURLConnectivity_LongHostname tests extremely long hostname handling.
// ENHANCEMENT: DoS prevention via hostname length validation
func TestTestURLConnectivity_LongHostname(t *testing.T) {
// Create hostname exceeding RFC 1035 limit (253 chars)
longHostname := strings.Repeat("a", 254) + ".com"
url := "https://" + longHostname + "/path"
reachable, _, err := TestURLConnectivity(url)
if err == nil {
t.Errorf("Expected long hostname to be rejected, but it was accepted")
}
if reachable {
t.Errorf("Expected reachable=false for long hostname, got true")
}
}
// TestTestURLConnectivity_RequestTracingHeaders tests that tracing headers are added.
// ENHANCEMENT: Verifies request tracing per Supervisor review
func TestTestURLConnectivity_RequestTracingHeaders(t *testing.T) {
var capturedHeaders http.Header
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
transport := testServer.Client().Transport
_, _, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
// Verify headers were set
if capturedHeaders.Get("User-Agent") != "Charon-Health-Check/1.0" {
t.Errorf("Expected User-Agent header, got: %s", capturedHeaders.Get("User-Agent"))
}
if capturedHeaders.Get("X-Charon-Request-Type") != "url-connectivity-test" {
t.Errorf("Expected X-Charon-Request-Type header, got: %s", capturedHeaders.Get("X-Charon-Request-Type"))
}
if capturedHeaders.Get("X-Request-ID") == "" {
t.Errorf("Expected X-Request-ID header to be set")
}
if !strings.HasPrefix(capturedHeaders.Get("X-Request-ID"), "test-") {
t.Errorf("Expected X-Request-ID to start with 'test-', got: %s", capturedHeaders.Get("X-Request-ID"))
}
}
// TestTestURLConnectivity_MetricsIntegration tests that metrics are recorded.
// ENHANCEMENT: Validates metrics collection per Supervisor review
func TestTestURLConnectivity_MetricsIntegration(t *testing.T) {
// This test verifies that metrics functions are called
// Full metrics validation requires integration tests with Prometheus
t.Run("Valid URL records metrics", func(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
transport := testServer.Client().Transport
reachable, latency, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
if !reachable {
t.Errorf("Expected reachable=true, got false")
}
if latency <= 0 {
t.Errorf("Expected positive latency, got %f", latency)
}
// Metrics recorded: URLValidation (allowed), URLTestDuration
})
t.Run("Blocked URL records metrics", func(t *testing.T) {
reachable, _, err := TestURLConnectivity("http://127.0.0.1:6379/")
if err == nil {
t.Errorf("Expected private IP to be blocked")
}
if reachable {
t.Errorf("Expected reachable=false, got true")
}
// Metrics recorded: SSRFBlock, URLValidation (blocked)
})
t.Run("Invalid URL records metrics", func(t *testing.T) {
reachable, _, err := TestURLConnectivity("not-a-valid-url")
if err == nil {
t.Errorf("Expected invalid URL to fail")
}
if reachable {
t.Errorf("Expected reachable=false, got true")
}
// Metrics recorded: URLValidation (error, unsupported_scheme)
})
}
// TestValidateRedirectTarget tests the redirect validation function directly.
// ENHANCEMENT: Direct unit tests for redirect validation per Phase 2 requirements
func TestValidateRedirectTarget(t *testing.T) {
tests := []struct {
name string
url string
viaCount int
shouldErr bool
errContains string
viaURL string
allowLocal bool
}{
{
name: "Localhost redirect allowed",
url: "http://localhost/path",
viaCount: 0,
shouldErr: false,
allowLocal: true,
},
{
name: "127.0.0.1 redirect allowed",
url: "http://127.0.0.1:8080/path",
viaCount: 0,
shouldErr: false,
allowLocal: true,
},
{
name: "IPv6 loopback allowed",
url: "http://[::1]:8080/path",
viaCount: 0,
shouldErr: false,
allowLocal: true,
},
{
name: "Too many redirects",
url: "http://localhost/path",
viaCount: 2,
shouldErr: true,
errContains: "too many redirects",
allowLocal: true,
},
{
name: "Three redirects",
url: "http://localhost/path",
viaCount: 3,
shouldErr: true,
errContains: "too many redirects",
allowLocal: true,
},
{
name: "Scheme downgrade blocked (https -> http)",
url: "http://localhost/next",
viaURL: "https://localhost/start",
viaCount: 1,
shouldErr: true,
errContains: "scheme change blocked",
allowLocal: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a request for the redirect target
req, err := http.NewRequest("GET", tt.url, http.NoBody)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
// Create via slice (previous requests)
via := make([]*http.Request, 0, tt.viaCount)
if tt.viaCount > 0 {
viaURL := tt.viaURL
if viaURL == "" {
viaURL = "http://localhost/prev"
}
prevReq, prevErr := http.NewRequest("GET", viaURL, http.NoBody)
if prevErr != nil {
t.Fatalf("Failed to create via request: %v", prevErr)
}
via = append(via, prevReq)
for i := 1; i < tt.viaCount; i++ {
via = append(via, prevReq)
}
}
err = validateRedirectTargetStrict(req, via, 2, true, tt.allowLocal)
if tt.shouldErr {
if err == nil {
t.Errorf("Expected error, got nil")
} else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing '%s', got: %s", tt.errContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
})
}
}
// TestTestURLConnectivity_AuditLogging tests that audit logging is integrated.
// ENHANCEMENT: Validates audit logging integration per Phase 3 requirements
func TestTestURLConnectivity_AuditLogging(t *testing.T) {
// Note: In production, audit logs go to the configured logger.
// These tests verify the code paths execute without panic.
t.Run("Invalid URL format logs audit event", func(t *testing.T) {
// This should trigger audit logging for invalid format
reachable, _, err := TestURLConnectivity("://invalid")
if err == nil {
t.Errorf("Expected error for invalid URL")
}
if reachable {
t.Errorf("Expected reachable=false")
}
})
t.Run("Invalid scheme logs audit event", func(t *testing.T) {
// This should trigger audit logging for unsupported scheme
reachable, _, err := TestURLConnectivity("ftp://example.com/file")
if err == nil {
t.Errorf("Expected error for unsupported scheme")
}
if reachable {
t.Errorf("Expected reachable=false")
}
})
t.Run("Private IP logs SSRF block audit event", func(t *testing.T) {
// This should trigger SSRF block audit logging
reachable, _, err := TestURLConnectivity("http://10.0.0.1/admin")
if err == nil {
t.Errorf("Expected error for private IP")
}
if reachable {
t.Errorf("Expected reachable=false")
}
})
t.Run("Metadata endpoint logs SSRF block audit event", func(t *testing.T) {
// This should trigger metadata endpoint block audit logging
reachable, _, err := TestURLConnectivity("http://169.254.169.254/latest/meta-data/")
if err == nil {
t.Errorf("Expected error for metadata endpoint")
}
if reachable {
t.Errorf("Expected reachable=false")
}
})
t.Run("Valid URL with mock transport logs success", func(t *testing.T) {
// Create a test server
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
// Note: With mock transport, audit logging is skipped (isTestMode=true)
// This test verifies the code path doesn't panic
transport := testServer.Client().Transport
reachable, _, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
if !reachable {
t.Errorf("Expected reachable=true")
}
})
}
// TestTestURLConnectivity_RequestIDConsistency tests that request ID is consistent.
// ENHANCEMENT: Validates request tracing per Phase 3 requirements
func TestTestURLConnectivity_RequestIDConsistency(t *testing.T) {
var capturedRequestID string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedRequestID = r.Header.Get("X-Request-ID")
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
transport := testServer.Client().Transport
_, _, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if capturedRequestID == "" {
t.Error("X-Request-ID header was not set")
}
if !strings.HasPrefix(capturedRequestID, "test-") {
t.Errorf("X-Request-ID should start with 'test-', got: %s", capturedRequestID)
}
}

View File

@@ -1,320 +0,0 @@
package utils
import (
"context"
"errors"
"fmt"
"testing"
"time"
)
// TestResolveAllowedIP_EmptyHostname tests resolveAllowedIP with empty hostname.
func TestResolveAllowedIP_EmptyHostname(t *testing.T) {
ctx := context.Background()
_, err := resolveAllowedIP(ctx, "", false)
if err == nil {
t.Fatal("Expected error for empty hostname, got nil")
}
if err.Error() != "missing hostname" {
t.Errorf("Expected 'missing hostname', got: %v", err)
}
}
// TestResolveAllowedIP_LoopbackIPLiteral tests resolveAllowedIP with loopback IPs.
func TestResolveAllowedIP_LoopbackIPLiteral(t *testing.T) {
tests := []struct {
name string
ip string
allowLocalhost bool
shouldFail bool
}{
{
name: "127.0.0.1 without allowLocalhost",
ip: "127.0.0.1",
allowLocalhost: false,
shouldFail: true,
},
{
name: "127.0.0.1 with allowLocalhost",
ip: "127.0.0.1",
allowLocalhost: true,
shouldFail: false,
},
{
name: "::1 without allowLocalhost",
ip: "::1",
allowLocalhost: false,
shouldFail: true,
},
{
name: "::1 with allowLocalhost",
ip: "::1",
allowLocalhost: true,
shouldFail: false,
},
}
ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip, err := resolveAllowedIP(ctx, tt.ip, tt.allowLocalhost)
if tt.shouldFail {
if err == nil {
t.Errorf("Expected error for %s without allowLocalhost", tt.ip)
}
} else {
if err != nil {
t.Fatal("ErroF for allowed loopback", err)
}
if ip == nil {
t.Fatal("Expected non-nil IP")
}
}
})
}
}
// TestResolveAllowedIP_PrivateIPLiterals tests resolveAllowedIP blocks private IPs.
func TestResolveAllowedIP_PrivateIPLiterals(t *testing.T) {
privateIPs := []string{
"10.0.0.1",
"172.16.0.1",
"192.168.1.1",
"169.254.169.254", // AWS metadata
"fc00::1", // IPv6 unique local
"fe80::1", // IPv6 link-local
}
ctx := context.Background()
for _, ip := range privateIPs {
t.Run("IP_"+ip, func(t *testing.T) {
_, err := resolveAllowedIP(ctx, ip, false)
if err == nil {
t.Errorf("Expected error for private IP %s, got nil", ip)
}
if err != nil && err.Error() != fmt.Sprintf("access to private IP addresses is blocked (resolved to %s)", ip) {
// Check it contains the expected error substring
expectedMsg := "access to private IP addresses is blocked"
if !contains(err.Error(), expectedMsg) {
t.Errorf("Expected error containing '%s', got: %v", expectedMsg, err)
}
}
})
}
}
// TestResolveAllowedIP_PublicIPLiteral tests resolveAllowedIP allows public IPs.
func TestResolveAllowedIP_PublicIPLiteral(t *testing.T) {
publicIPs := []string{
"8.8.8.8",
"1.1.1.1",
"2001:4860:4860::8888",
}
ctx := context.Background()
for _, ipStr := range publicIPs {
t.Run("IP_"+ipStr, func(t *testing.T) {
ip, err := resolveAllowedIP(ctx, ipStr, false)
if err != nil {
t.Errorf("Expected no error for public IP %s, got: %v", ipStr, err)
}
if ip == nil {
t.Error("Expected non-nil IP for public address")
}
})
}
}
// TestResolveAllowedIP_Timeout tests DNS resolution timeout.
func TestResolveAllowedIP_Timeout(t *testing.T) {
// Create a context with very short timeout
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
// Any hostname should timeout
_, err := resolveAllowedIP(ctx, "example.com", false)
if err == nil {
t.Fatal("Expected timeout error, got nil")
}
// Should be a context deadline exceeded error
if !errors.Is(err, context.DeadlineExceeded) && !contains(err.Error(), "deadline") && !contains(err.Error(), "timeout") {
t.Logf("Expected timeout/deadline error, got: %v", err)
}
}
// TestResolveAllowedIP_NoIPsResolved tests when DNS returns no IPs.
// Note: This is difficult to test without a custom resolver, so we skip it
func TestResolveAllowedIP_NoIPsResolved(t *testing.T) {
t.Skip("Requires custom DNS resolver to return empty IP list")
}
// TestSSRFSafeDialer_PrivateIPResolution tests that ssrfSafeDialer blocks private IPs.
// Note: This requires network access or mocking, so we test the concept
func TestSSRFSafeDialer_Concept(t *testing.T) {
// The ssrfSafeDialer function should:
// 1. Resolve the hostname to IPs
// 2. Check ALL IPs against private ranges
// 3. Reject if ANY IP is private
// 4. Connect only to validated IPs
// We can't easily test this without network calls, but we document the behavior
t.Log("ssrfSafeDialer validates IPs at dial time to prevent DNS rebinding")
t.Log("All resolved IPs must pass private IP check before connection")
}
// TestSSRFSafeDialer_InvalidAddress tests ssrfSafeDialer with malformed addresses.
func TestSSRFSafeDialer_InvalidAddress(t *testing.T) {
ctx := context.Background()
dialer := ssrfSafeDialer()
tests := []struct {
name string
addr string
}{
{
name: "No port",
addr: "example.com",
},
{
name: "Invalid format",
addr: ":/invalid",
},
{
name: "Empty address",
addr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := dialer(ctx, "tcp", tt.addr)
if err == nil {
t.Errorf("Expected error for invalid address %s, got nil", tt.addr)
}
})
}
}
// TestSSRFSafeDialer_ContextCancellation tests context cancellation during dial.
func TestSSRFSafeDialer_ContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
dialer := ssrfSafeDialer()
_, err := dialer(ctx, "tcp", "example.com:80")
if err == nil {
t.Fatal("Expected context cancellation error, got nil")
}
// Should be context canceled error
if !errors.Is(err, context.Canceled) && !contains(err.Error(), "canceled") {
t.Logf("Expected context canceled error, got: %v", err)
}
}
// TestTestURLConnectivity_ErrorPaths tests error handling in testURLConnectivity.
func TestTestURLConnectivity_ErrorPaths(t *testing.T) {
tests := []struct {
name string
url string
errMatch string
}{
{
name: "Invalid URL format",
url: "://invalid",
errMatch: "invalid URL",
},
{
name: "Unsupported scheme FTP",
url: "ftp://example.com",
errMatch: "only http and https schemes are allowed",
},
{
name: "Embedded credentials",
url: "https://user:pass@example.com",
errMatch: "embedded credentials are not allowed",
},
{
name: "Private IP 10.x",
url: "http://10.0.0.1",
errMatch: "private IP",
},
{
name: "Private IP 192.168.x",
url: "http://192.168.1.1",
errMatch: "private IP",
},
{
name: "AWS metadata endpoint",
url: "http://169.254.169.254/latest/meta-data/",
errMatch: "private IP",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reachable, latency, err := TestURLConnectivity(tt.url)
if err == nil {
t.Fatalf("Expected error for %s, got nil (reachable=%v, latency=%v)", tt.url, reachable, latency)
}
if !contains(err.Error(), tt.errMatch) {
t.Errorf("Expected error containing '%s', got: %v", tt.errMatch, err)
}
if reachable {
t.Error("Expected reachable=false for error case")
}
})
}
}
// TestTestURLConnectivity_InvalidPort tests port validation in testURLConnectivity.
func TestTestURLConnectivity_InvalidPort(t *testing.T) {
tests := []struct {
name string
url string
}{
{
name: "Port out of range (too high)",
url: "https://example.com:99999",
},
{
name: "Port zero",
url: "https://example.com:0",
},
{
name: "Negative port",
url: "https://example.com:-1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, err := TestURLConnectivity(tt.url)
if err == nil {
t.Errorf("Expected error for invalid port in %s", tt.url)
}
})
}
}
// TestValidateRedirectTargetStrict tests are in url_testing_test.go using proper http types
// Helper function already defined in security tests
func contains(s, substr string) bool {
return len(s) >= len(substr) && containsSubstring(s, substr)
}
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}