chore: clean .gitignore cache
This commit is contained in:
@@ -1,52 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/network"
|
||||
)
|
||||
|
||||
// IsPrivateIP checks if the given host string is a private IPv4 address.
|
||||
// Returns false for hostnames, invalid IPs, or public IP addresses.
|
||||
//
|
||||
// Deprecated: This function only checks IPv4. For comprehensive SSRF protection,
|
||||
// use network.IsPrivateIP() directly which handles IPv4, IPv6, and IPv4-mapped IPv6.
|
||||
func IsPrivateIP(host string) bool {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Ensure it's IPv4 (for backward compatibility)
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use centralized network.IsPrivateIP for consistent checking
|
||||
return network.IsPrivateIP(ip)
|
||||
}
|
||||
|
||||
// IsDockerBridgeIP checks if the given host string is likely a Docker bridge network IP.
|
||||
// Docker typically uses 172.17.x.x for the default bridge and 172.18-31.x.x for user-defined networks.
|
||||
// Returns false for hostnames, invalid IPs, or non-Docker IP addresses.
|
||||
func IsDockerBridgeIP(host string) bool {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Ensure it's IPv4
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Docker bridge network CIDR range: 172.16.0.0/12
|
||||
_, dockerNetwork, err := net.ParseCIDR("172.16.0.0/12")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return dockerNetwork.Contains(ip4)
|
||||
}
|
||||
@@ -1,294 +0,0 @@
|
||||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
// Private IP ranges - Class A (10.0.0.0/8)
|
||||
{"10.0.0.1 is private", "10.0.0.1", true},
|
||||
{"10.255.255.255 is private", "10.255.255.255", true},
|
||||
{"10.10.10.10 is private", "10.10.10.10", true},
|
||||
|
||||
// Private IP ranges - Class B (172.16.0.0/12)
|
||||
{"172.16.0.1 is private", "172.16.0.1", true},
|
||||
{"172.31.255.255 is private", "172.31.255.255", true},
|
||||
{"172.20.0.1 is private", "172.20.0.1", true},
|
||||
|
||||
// Private IP ranges - Class C (192.168.0.0/16)
|
||||
{"192.168.1.1 is private", "192.168.1.1", true},
|
||||
{"192.168.0.1 is private", "192.168.0.1", true},
|
||||
{"192.168.255.255 is private", "192.168.255.255", true},
|
||||
|
||||
// Docker bridge IPs (subset of 172.16.0.0/12)
|
||||
{"172.17.0.2 is private", "172.17.0.2", true},
|
||||
{"172.18.0.5 is private", "172.18.0.5", true},
|
||||
|
||||
// Public IPs - should return false
|
||||
{"8.8.8.8 is public", "8.8.8.8", false},
|
||||
{"1.1.1.1 is public", "1.1.1.1", false},
|
||||
{"142.250.80.14 is public", "142.250.80.14", false},
|
||||
{"203.0.113.50 is public", "203.0.113.50", false},
|
||||
|
||||
// Edge cases for 172.x range (outside 172.16-31)
|
||||
{"172.15.0.1 is public", "172.15.0.1", false},
|
||||
{"172.32.0.1 is public", "172.32.0.1", false},
|
||||
|
||||
// Hostnames - should return false
|
||||
{"nginx hostname", "nginx", false},
|
||||
{"my-app hostname", "my-app", false},
|
||||
{"app.local hostname", "app.local", false},
|
||||
{"example.com hostname", "example.com", false},
|
||||
{"my-container.internal hostname", "my-container.internal", false},
|
||||
|
||||
// Invalid inputs - should return false
|
||||
{"empty string", "", false},
|
||||
{"malformed IP", "192.168.1", false},
|
||||
{"too many octets", "192.168.1.1.1", false},
|
||||
{"negative octet", "192.168.-1.1", false},
|
||||
{"octet out of range", "192.168.256.1", false},
|
||||
{"letters in IP", "192.168.a.1", false},
|
||||
{"IPv6 address", "::1", false},
|
||||
{"IPv6 full address", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", false},
|
||||
|
||||
// Localhost and special addresses - these are now blocked for SSRF protection
|
||||
{"localhost 127.0.0.1", "127.0.0.1", true},
|
||||
{"0.0.0.0", "0.0.0.0", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsPrivateIP(tt.host)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsPrivateIP(%q) = %v, want %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDockerBridgeIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
// Docker default bridge (172.17.x.x)
|
||||
{"172.17.0.1 is Docker bridge", "172.17.0.1", true},
|
||||
{"172.17.0.2 is Docker bridge", "172.17.0.2", true},
|
||||
{"172.17.255.255 is Docker bridge", "172.17.255.255", true},
|
||||
|
||||
// Docker user-defined networks (172.18-31.x.x)
|
||||
{"172.18.0.1 is Docker bridge", "172.18.0.1", true},
|
||||
{"172.18.0.5 is Docker bridge", "172.18.0.5", true},
|
||||
{"172.20.0.1 is Docker bridge", "172.20.0.1", true},
|
||||
{"172.31.0.1 is Docker bridge", "172.31.0.1", true},
|
||||
{"172.31.255.255 is Docker bridge", "172.31.255.255", true},
|
||||
|
||||
// Also matches 172.16.x.x (part of 172.16.0.0/12)
|
||||
{"172.16.0.1 is in Docker range", "172.16.0.1", true},
|
||||
|
||||
// Private IPs NOT in Docker bridge range
|
||||
{"10.0.0.1 is not Docker bridge", "10.0.0.1", false},
|
||||
{"192.168.1.1 is not Docker bridge", "192.168.1.1", false},
|
||||
|
||||
// Public IPs - should return false
|
||||
{"8.8.8.8 is public", "8.8.8.8", false},
|
||||
{"1.1.1.1 is public", "1.1.1.1", false},
|
||||
|
||||
// Edge cases for 172.x range (outside 172.16-31)
|
||||
{"172.15.0.1 is outside Docker range", "172.15.0.1", false},
|
||||
{"172.32.0.1 is outside Docker range", "172.32.0.1", false},
|
||||
|
||||
// Hostnames - should return false
|
||||
{"nginx hostname", "nginx", false},
|
||||
{"my-app hostname", "my-app", false},
|
||||
{"container-name hostname", "container-name", false},
|
||||
|
||||
// Invalid inputs - should return false
|
||||
{"empty string", "", false},
|
||||
{"malformed IP", "172.17.0", false},
|
||||
{"too many octets", "172.17.0.2.1", false},
|
||||
{"letters in IP", "172.17.a.1", false},
|
||||
{"IPv6 address", "::1", false},
|
||||
|
||||
// Special addresses
|
||||
{"localhost 127.0.0.1", "127.0.0.1", false},
|
||||
{"0.0.0.0", "0.0.0.0", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsDockerBridgeIP(tt.host)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsDockerBridgeIP(%q) = %v, want %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateIP_IPv4Mapped tests IPv4-mapped IPv6 addresses
|
||||
func TestIsPrivateIP_IPv4Mapped(t *testing.T) {
|
||||
// IPv4-mapped IPv6 addresses should be handled correctly
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
// net.ParseIP converts IPv4-mapped IPv6 to IPv4
|
||||
{"::ffff:10.0.0.1 mapped", "::ffff:10.0.0.1", true},
|
||||
{"::ffff:192.168.1.1 mapped", "::ffff:192.168.1.1", true},
|
||||
{"::ffff:8.8.8.8 mapped", "::ffff:8.8.8.8", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsPrivateIP(tt.host)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsPrivateIP(%q) = %v, want %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============== Phase 3.3: Additional IP Helpers Tests ==============
|
||||
|
||||
func TestIsPrivateIP_CIDRParseError(t *testing.T) {
|
||||
// Temporarily modify the private IP ranges to include an invalid CIDR
|
||||
// This tests graceful handling of CIDR parse errors
|
||||
|
||||
// Since we can't modify the package-level variable, we test the function behavior
|
||||
// with edge cases that might trigger parsing issues
|
||||
|
||||
// Test with various invalid IP formats (should return false gracefully)
|
||||
invalidInputs := []string{
|
||||
"10.0.0.1/8", // CIDR notation (not a raw IP)
|
||||
"10.0.0.256", // Invalid octet
|
||||
"999.999.999.999", // Out of range
|
||||
"10.0.0", // Incomplete
|
||||
"not-an-ip", // Hostname
|
||||
"", // Empty
|
||||
"10.0.0.1.1", // Too many octets
|
||||
}
|
||||
|
||||
for _, input := range invalidInputs {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
result := IsPrivateIP(input)
|
||||
// All invalid inputs should return false (not panic)
|
||||
if result {
|
||||
t.Errorf("IsPrivateIP(%q) = true, want false for invalid input", input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDockerBridgeIP_CIDRParseError(t *testing.T) {
|
||||
// Test graceful handling of invalid inputs
|
||||
invalidInputs := []string{
|
||||
"172.17.0.1/16", // CIDR notation
|
||||
"172.17.0.256", // Invalid octet
|
||||
"999.999.999.999", // Out of range
|
||||
"172.17", // Incomplete
|
||||
"not-an-ip", // Hostname
|
||||
"", // Empty
|
||||
}
|
||||
|
||||
for _, input := range invalidInputs {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
result := IsDockerBridgeIP(input)
|
||||
// All invalid inputs should return false (not panic)
|
||||
if result {
|
||||
t.Errorf("IsDockerBridgeIP(%q) = true, want false for invalid input", input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrivateIP_IPv6Comprehensive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
// IPv6 Loopback
|
||||
{"IPv6 loopback", "::1", false}, // Current implementation treats loopback as non-private
|
||||
{"IPv6 loopback expanded", "0000:0000:0000:0000:0000:0000:0000:0001", false},
|
||||
|
||||
// IPv6 Link-Local (fe80::/10)
|
||||
{"IPv6 link-local", "fe80::1", false},
|
||||
{"IPv6 link-local 2", "fe80::abcd:ef01:2345:6789", false},
|
||||
|
||||
// IPv6 Unique Local (fc00::/7)
|
||||
{"IPv6 unique local fc00", "fc00::1", false},
|
||||
{"IPv6 unique local fd00", "fd00::1", false},
|
||||
{"IPv6 unique local fdff", "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", false},
|
||||
|
||||
// IPv6 Public addresses
|
||||
{"IPv6 public Google DNS", "2001:4860:4860::8888", false},
|
||||
{"IPv6 public Cloudflare", "2606:4700:4700::1111", false},
|
||||
|
||||
// IPv6 mapped IPv4
|
||||
{"IPv6 mapped private", "::ffff:10.0.0.1", true},
|
||||
{"IPv6 mapped public", "::ffff:8.8.8.8", false},
|
||||
|
||||
// Invalid IPv6
|
||||
{"Invalid IPv6", "gggg::1", false},
|
||||
{"Incomplete IPv6", "2001::", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsPrivateIP(tt.host)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsPrivateIP(%q) = %v, want %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDockerBridgeIP_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
// Boundaries of 172.16.0.0/12 range
|
||||
{"Lower boundary - 1", "172.15.255.255", false}, // Just outside
|
||||
{"Lower boundary", "172.16.0.0", true}, // Start of range
|
||||
{"Lower boundary + 1", "172.16.0.1", true},
|
||||
|
||||
{"Upper boundary - 1", "172.31.255.254", true},
|
||||
{"Upper boundary", "172.31.255.255", true}, // End of range
|
||||
{"Upper boundary + 1", "172.32.0.0", false}, // Just outside
|
||||
{"Upper boundary + 2", "172.32.0.1", false},
|
||||
|
||||
// Docker default bridge (172.17.0.0/16)
|
||||
{"Docker default bridge start", "172.17.0.0", true},
|
||||
{"Docker default bridge gateway", "172.17.0.1", true},
|
||||
{"Docker default bridge host", "172.17.0.2", true},
|
||||
{"Docker default bridge end", "172.17.255.255", true},
|
||||
|
||||
// Docker user-defined networks
|
||||
{"User network 1", "172.18.0.1", true},
|
||||
{"User network 2", "172.19.0.1", true},
|
||||
{"User network 30", "172.30.0.1", true},
|
||||
{"User network 31", "172.31.0.1", true},
|
||||
|
||||
// Edge of 172.x range
|
||||
{"172.0.0.1", "172.0.0.1", false}, // Below range
|
||||
{"172.15.0.1", "172.15.0.1", false}, // Below range
|
||||
{"172.32.0.1", "172.32.0.1", false}, // Above range
|
||||
{"172.255.255.255", "172.255.255.255", false}, // Above range
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsDockerBridgeIP(tt.host)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsDockerBridgeIP(%q) = %v, want %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
)
|
||||
|
||||
// GetConfiguredPublicURL returns the configured, normalized public URL.
|
||||
//
|
||||
// Security note:
|
||||
// This function intentionally never derives URLs from request data (Host/X-Forwarded-*),
|
||||
// so it is safe to use for embedding external links (e.g., invite emails).
|
||||
func GetConfiguredPublicURL(db *gorm.DB) (string, bool) {
|
||||
var setting models.Setting
|
||||
if err := db.Where("key = ?", "app.public_url").First(&setting).Error; err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
normalized, err := normalizeConfiguredPublicURL(setting.Value)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return normalized, true
|
||||
}
|
||||
|
||||
// GetPublicURL retrieves the configured public URL or falls back to request host.
|
||||
// This should be used for all user-facing URLs (emails, invite links).
|
||||
func GetPublicURL(db *gorm.DB, c *gin.Context) string {
|
||||
var setting models.Setting
|
||||
if err := db.Where("key = ?", "app.public_url").First(&setting).Error; err == nil {
|
||||
if setting.Value != "" {
|
||||
return strings.TrimSuffix(setting.Value, "/")
|
||||
}
|
||||
}
|
||||
// Fallback to request-derived URL
|
||||
return getBaseURL(c)
|
||||
}
|
||||
|
||||
func normalizeConfiguredPublicURL(raw string) (string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", errors.New("public URL is empty")
|
||||
}
|
||||
if strings.ContainsAny(raw, "\r\n") {
|
||||
return "", errors.New("public URL contains invalid characters")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", errors.New("public URL must use http or https")
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
return "", errors.New("public URL must include a host")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return "", errors.New("public URL must not include userinfo")
|
||||
}
|
||||
if parsed.RawQuery != "" || parsed.Fragment != "" {
|
||||
return "", errors.New("public URL must not include query or fragment")
|
||||
}
|
||||
if parsed.Path != "" && parsed.Path != "/" {
|
||||
return "", errors.New("public URL must not include a path")
|
||||
}
|
||||
if parsed.Opaque != "" {
|
||||
return "", errors.New("public URL must not be opaque")
|
||||
}
|
||||
|
||||
normalized := (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host}).String()
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// getBaseURL extracts the base URL from the request.
|
||||
func getBaseURL(c *gin.Context) string {
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check for X-Forwarded-Proto header
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
return scheme + "://" + c.Request.Host
|
||||
}
|
||||
|
||||
// ValidateURL validates that a URL is properly formatted for use as an application URL.
|
||||
// Returns error message if invalid, empty string if valid.
|
||||
func ValidateURL(rawURL string) (normalized, warning string, err error) {
|
||||
// Parse URL
|
||||
parsed, parseErr := url.Parse(rawURL)
|
||||
if parseErr != nil {
|
||||
return "", "", parseErr
|
||||
}
|
||||
|
||||
// Validate scheme
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", "", &url.Error{
|
||||
Op: "parse",
|
||||
URL: rawURL,
|
||||
Err: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if HTTP
|
||||
if parsed.Scheme == "http" {
|
||||
warning = "Using HTTP is not recommended. Consider using HTTPS for security."
|
||||
}
|
||||
|
||||
// Reject URLs with path components beyond "/"
|
||||
if parsed.Path != "" && parsed.Path != "/" {
|
||||
return "", "", &url.Error{
|
||||
Op: "validate",
|
||||
URL: rawURL,
|
||||
Err: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize URL (remove trailing slash, keep scheme and host)
|
||||
normalized = strings.TrimSuffix(rawURL, "/")
|
||||
|
||||
return normalized, warning, nil
|
||||
}
|
||||
@@ -1,426 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/network"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockTransport is a custom http.RoundTripper for testing that bypasses network calls
|
||||
type mockTransport struct {
|
||||
statusCode int
|
||||
headers http.Header
|
||||
body string
|
||||
err error
|
||||
handler http.HandlerFunc // For dynamic responses
|
||||
}
|
||||
|
||||
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
// Use handler if provided (for dynamic responses like redirects)
|
||||
if m.handler != nil {
|
||||
w := httptest.NewRecorder()
|
||||
m.handler(w, req)
|
||||
return w.Result(), nil
|
||||
}
|
||||
|
||||
// Static response
|
||||
resp := &http.Response{
|
||||
StatusCode: m.statusCode,
|
||||
Header: m.headers,
|
||||
Body: io.NopCloser(strings.NewReader(m.body)),
|
||||
Request: req,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_Success verifies that valid public URLs are reachable
|
||||
func TestTestURLConnectivity_Success(t *testing.T) {
|
||||
transport := &mockTransport{
|
||||
statusCode: http.StatusOK,
|
||||
headers: http.Header{"Content-Type": []string{"text/html"}},
|
||||
body: "",
|
||||
}
|
||||
|
||||
reachable, latency, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, reachable)
|
||||
assert.Greater(t, latency, 0.0, "latency should be positive")
|
||||
assert.Less(t, latency, 5000.0, "latency should be reasonable (< 5s)")
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_Redirect verifies redirect handling
|
||||
func TestTestURLConnectivity_Redirect(t *testing.T) {
|
||||
redirectCount := 0
|
||||
transport := &mockTransport{
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
redirectCount++
|
||||
// Only redirect once, then return OK
|
||||
if redirectCount == 1 {
|
||||
w.Header().Set("Location", "http://example.com/final")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
}
|
||||
|
||||
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, reachable)
|
||||
assert.LessOrEqual(t, redirectCount, 3, "should follow max 2 redirects")
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_TooManyRedirects verifies redirect limit enforcement
|
||||
func TestTestURLConnectivity_TooManyRedirects(t *testing.T) {
|
||||
transport := &mockTransport{
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", "/redirect")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
},
|
||||
}
|
||||
|
||||
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.False(t, reachable)
|
||||
assert.Contains(t, err.Error(), "redirect", "error should mention redirects")
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_StatusCodes verifies handling of different HTTP status codes
|
||||
func TestTestURLConnectivity_StatusCodes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expected bool
|
||||
}{
|
||||
{"200 OK", http.StatusOK, true},
|
||||
{"201 Created", http.StatusCreated, true},
|
||||
{"204 No Content", http.StatusNoContent, true},
|
||||
{"301 Moved Permanently", http.StatusMovedPermanently, true},
|
||||
{"302 Found", http.StatusFound, true},
|
||||
{"400 Bad Request", http.StatusBadRequest, false},
|
||||
{"401 Unauthorized", http.StatusUnauthorized, false},
|
||||
{"403 Forbidden", http.StatusForbidden, false},
|
||||
{"404 Not Found", http.StatusNotFound, false},
|
||||
{"500 Internal Server Error", http.StatusInternalServerError, false},
|
||||
{"503 Service Unavailable", http.StatusServiceUnavailable, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
transport := &mockTransport{
|
||||
statusCode: tc.statusCode,
|
||||
}
|
||||
|
||||
reachable, latency, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
if tc.expected {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, reachable)
|
||||
assert.Greater(t, latency, 0.0)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
assert.False(t, reachable)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("status %d", tc.statusCode))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_InvalidURL verifies invalid URL handling
|
||||
func TestTestURLConnectivity_InvalidURL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"Empty URL", ""},
|
||||
{"Invalid scheme", "ftp://example.com"},
|
||||
{"Malformed URL", "http://[invalid"},
|
||||
{"No scheme", "example.com"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// No transport needed - these fail at URL parsing
|
||||
reachable, _, err := TestURLConnectivity(tc.url)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.False(t, reachable)
|
||||
// Latency varies depending on error type
|
||||
// Some errors may still measure time before failing
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_DNSFailure verifies DNS resolution error handling
|
||||
func TestTestURLConnectivity_DNSFailure(t *testing.T) {
|
||||
// Without transport, this will try real DNS and should fail
|
||||
reachable, _, err := TestURLConnectivity("http://nonexistent-domain-12345.invalid")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.False(t, reachable)
|
||||
assert.Contains(t, err.Error(), "DNS resolution failed", "error should mention DNS failure")
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_Timeout verifies timeout enforcement
|
||||
func TestTestURLConnectivity_Timeout(t *testing.T) {
|
||||
transport := &mockTransport{
|
||||
err: fmt.Errorf("context deadline exceeded"),
|
||||
}
|
||||
|
||||
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.False(t, reachable)
|
||||
assert.Contains(t, err.Error(), "connection failed", "error should mention connection failure")
|
||||
}
|
||||
|
||||
// TestIsPrivateIP_PrivateIPv4Ranges verifies blocking of private IPv4 ranges
|
||||
func TestIsPrivateIP_PrivateIPv4Ranges(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// RFC 1918 Private Networks
|
||||
{"10.0.0.0/8 start", "10.0.0.1", true},
|
||||
{"10.0.0.0/8 mid", "10.128.0.1", true},
|
||||
{"10.0.0.0/8 end", "10.255.255.254", true},
|
||||
{"172.16.0.0/12 start", "172.16.0.1", true},
|
||||
{"172.16.0.0/12 mid", "172.20.0.1", true},
|
||||
{"172.16.0.0/12 end", "172.31.255.254", true},
|
||||
{"192.168.0.0/16 start", "192.168.0.1", true},
|
||||
{"192.168.0.0/16 end", "192.168.255.254", true},
|
||||
|
||||
// Loopback
|
||||
{"127.0.0.1 localhost", "127.0.0.1", true},
|
||||
{"127.0.0.0/8 start", "127.0.0.0", true},
|
||||
{"127.0.0.0/8 end", "127.255.255.255", true},
|
||||
|
||||
// Link-Local (includes AWS/GCP metadata)
|
||||
{"169.254.0.0/16 start", "169.254.0.1", true},
|
||||
{"169.254.169.254 AWS metadata", "169.254.169.254", true},
|
||||
{"169.254.0.0/16 end", "169.254.255.254", true},
|
||||
|
||||
// Reserved ranges
|
||||
{"0.0.0.0/8", "0.0.0.1", true},
|
||||
{"240.0.0.0/4", "240.0.0.1", true},
|
||||
{"255.255.255.255 broadcast", "255.255.255.255", true},
|
||||
|
||||
// Public IPs (should NOT be blocked)
|
||||
{"8.8.8.8 Google DNS", "8.8.8.8", false},
|
||||
{"1.1.1.1 Cloudflare DNS", "1.1.1.1", false},
|
||||
{"93.184.216.34 example.com", "93.184.216.34", false},
|
||||
{"151.101.1.140 GitHub", "151.101.1.140", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
require.NotNil(t, ip, "IP should parse successfully")
|
||||
|
||||
result := network.IsPrivateIP(ip)
|
||||
assert.Equal(t, tc.expected, result,
|
||||
"IP %s should be private=%v", tc.ip, tc.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateIP_PrivateIPv6Ranges verifies blocking of private IPv6 ranges
|
||||
func TestIsPrivateIP_PrivateIPv6Ranges(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// IPv6 Loopback
|
||||
{"::1 loopback", "::1", true},
|
||||
|
||||
// IPv6 Link-Local
|
||||
{"fe80::/10 start", "fe80::1", true},
|
||||
{"fe80::/10 mid", "fe80:1234::5678", true},
|
||||
|
||||
// IPv6 Unique Local (RFC 4193)
|
||||
{"fc00::/7 start", "fc00::1", true},
|
||||
{"fc00::/7 mid", "fd12:3456:789a::1", true},
|
||||
|
||||
// Public IPv6 (should NOT be blocked)
|
||||
{"2001:4860:4860::8888 Google DNS", "2001:4860:4860::8888", false},
|
||||
{"2606:4700:4700::1111 Cloudflare DNS", "2606:4700:4700::1111", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
require.NotNil(t, ip, "IP should parse successfully")
|
||||
|
||||
result := network.IsPrivateIP(ip)
|
||||
assert.Equal(t, tc.expected, result,
|
||||
"IP %s should be private=%v", tc.ip, tc.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_PrivateIP_Blocked verifies SSRF protection
|
||||
func TestTestURLConnectivity_PrivateIP_Blocked(t *testing.T) {
|
||||
// Note: This test will fail if run on a system that actually resolves
|
||||
// these hostnames to private IPs. In a production test environment,
|
||||
// you might want to mock DNS resolution.
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"localhost", "http://localhost"},
|
||||
{"127.0.0.1", "http://127.0.0.1"},
|
||||
{"Private IP 10.x", "http://10.0.0.1"},
|
||||
{"Private IP 192.168.x", "http://192.168.1.1"},
|
||||
{"AWS metadata", "http://169.254.169.254"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reachable, _, err := TestURLConnectivity(tc.url)
|
||||
|
||||
// Should fail with private IP error
|
||||
assert.Error(t, err)
|
||||
assert.False(t, reachable)
|
||||
assert.Contains(t, err.Error(), "private IP", "error should mention private IP blocking")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_SSRF_Protection_Comprehensive performs comprehensive SSRF tests
|
||||
func TestTestURLConnectivity_SSRF_Protection_Comprehensive(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping comprehensive SSRF test in short mode")
|
||||
}
|
||||
|
||||
// Test various SSRF attack vectors
|
||||
attackVectors := []string{
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://0.0.0.0:8080",
|
||||
"http://[::1]:8080",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
"http://metadata.google.internal/computeMetadata/v1/",
|
||||
}
|
||||
|
||||
for _, url := range attackVectors {
|
||||
t.Run(url, func(t *testing.T) {
|
||||
reachable, _, err := TestURLConnectivity(url)
|
||||
|
||||
// All should be blocked
|
||||
assert.Error(t, err, "SSRF attack vector should be blocked")
|
||||
assert.False(t, reachable)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_HTTPSSupport verifies HTTPS support
|
||||
func TestTestURLConnectivity_HTTPSSupport(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Note: This will likely fail due to self-signed cert in test server
|
||||
// but it demonstrates HTTPS support
|
||||
reachable, _, err := TestURLConnectivity(server.URL)
|
||||
|
||||
// May fail due to cert validation, but should not panic
|
||||
if err != nil {
|
||||
t.Logf("HTTPS test failed (expected with self-signed cert): %v", err)
|
||||
} else {
|
||||
assert.True(t, reachable)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTestURLConnectivity benchmarks the connectivity test
|
||||
func BenchmarkTestURLConnectivity(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = TestURLConnectivity(server.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkIsPrivateIP benchmarks private IP checking
|
||||
func BenchmarkIsPrivateIP(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = network.IsPrivateIP(ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_RedirectLimit_ProductionPath verifies the production
|
||||
// CheckRedirect callback enforces a maximum of 2 redirects (lines 93-97).
|
||||
// This is a critical security feature to prevent redirect-based attacks.
|
||||
func TestTestURLConnectivity_RedirectLimit_ProductionPath(t *testing.T) {
|
||||
redirectCount := 0
|
||||
// Use mock transport to bypass SSRF protection and test redirect limit specifically
|
||||
transport := &mockTransport{
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
redirectCount++
|
||||
if redirectCount <= 3 { // Try to redirect 3 times
|
||||
http.Redirect(w, r, "http://example.com/next", http.StatusFound)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Test with transport (will use CheckRedirect callback from production path)
|
||||
reachable, latency, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
// Should fail due to redirect limit
|
||||
assert.False(t, reachable)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "redirect", "error should mention redirects")
|
||||
assert.Greater(t, latency, 0.0, "should have some latency")
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_InvalidPortFormat tests error when URL has invalid port format.
|
||||
// This would trigger errors in net.SplitHostPort during dialing (lines 19-21).
|
||||
func TestTestURLConnectivity_InvalidPortFormat(t *testing.T) {
|
||||
// URL with invalid port will fail at URL parsing stage
|
||||
reachable, _, err := TestURLConnectivity("http://example.com:badport")
|
||||
|
||||
assert.False(t, reachable)
|
||||
assert.Error(t, err)
|
||||
// URL parsing will catch the invalid port before we even get to dialing
|
||||
assert.Contains(t, err.Error(), "invalid port")
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_EmptyDNSResult tests the empty DNS results
|
||||
// error path (lines 29-31).
|
||||
func TestTestURLConnectivity_EmptyDNSResult(t *testing.T) {
|
||||
// Create a custom transport that simulates empty DNS result
|
||||
transport := &mockTransport{
|
||||
err: fmt.Errorf("DNS resolution failed: no IP addresses found for host"),
|
||||
}
|
||||
|
||||
reachable, _, err := testURLConnectivity("http://localhost", withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
assert.False(t, reachable)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "connection failed")
|
||||
}
|
||||
@@ -1,478 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/models"
|
||||
)
|
||||
|
||||
// setupTestDB creates an in-memory SQLite database for testing
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
require.NoError(t, err, "failed to connect to test database")
|
||||
|
||||
// Auto-migrate the Setting model
|
||||
err = db.AutoMigrate(&models.Setting{})
|
||||
require.NoError(t, err, "failed to migrate database")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// TestGetPublicURL_WithConfiguredURL verifies retrieval of configured public URL
|
||||
func TestGetPublicURL_WithConfiguredURL(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
// Insert a configured public URL
|
||||
setting := models.Setting{
|
||||
Key: "app.public_url",
|
||||
Value: "https://example.com/",
|
||||
}
|
||||
err := db.Create(&setting).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test Gin context
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/test", http.NoBody)
|
||||
c.Request = req
|
||||
|
||||
// Test GetPublicURL
|
||||
publicURL := GetPublicURL(db, c)
|
||||
|
||||
// Should return configured URL with trailing slash removed
|
||||
assert.Equal(t, "https://example.com", publicURL)
|
||||
}
|
||||
|
||||
// TestGetPublicURL_WithTrailingSlash verifies trailing slash removal
|
||||
func TestGetPublicURL_WithTrailingSlash(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
// Insert URL with multiple trailing slashes
|
||||
setting := models.Setting{
|
||||
Key: "app.public_url",
|
||||
Value: "https://example.com///",
|
||||
}
|
||||
err := db.Create(&setting).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/test", http.NoBody)
|
||||
c.Request = req
|
||||
|
||||
publicURL := GetPublicURL(db, c)
|
||||
|
||||
// Should remove only the trailing slash (TrimSuffix removes one slash)
|
||||
assert.Equal(t, "https://example.com//", publicURL)
|
||||
}
|
||||
|
||||
// TestGetPublicURL_Fallback_HTTPSWithTLS verifies fallback to request URL with TLS
|
||||
func TestGetPublicURL_Fallback_HTTPSWithTLS(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
// No setting in DB - should fallback
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create request with TLS
|
||||
req := httptest.NewRequest(http.MethodGet, "https://myapp.com:8443/path", http.NoBody)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
c.Request = req
|
||||
|
||||
publicURL := GetPublicURL(db, c)
|
||||
|
||||
// Should detect TLS and use https
|
||||
assert.Equal(t, "https://myapp.com:8443", publicURL)
|
||||
}
|
||||
|
||||
// TestGetPublicURL_Fallback_HTTP verifies fallback to HTTP when no TLS
|
||||
func TestGetPublicURL_Fallback_HTTP(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/test", http.NoBody)
|
||||
c.Request = req
|
||||
|
||||
publicURL := GetPublicURL(db, c)
|
||||
|
||||
// Should use http scheme when no TLS
|
||||
assert.Equal(t, "http://localhost:8080", publicURL)
|
||||
}
|
||||
|
||||
// TestGetPublicURL_Fallback_XForwardedProto verifies X-Forwarded-Proto header handling
|
||||
func TestGetPublicURL_Fallback_XForwardedProto(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://internal-server:8080/test", http.NoBody)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
c.Request = req
|
||||
|
||||
publicURL := GetPublicURL(db, c)
|
||||
|
||||
// Should respect X-Forwarded-Proto header
|
||||
assert.Equal(t, "https://internal-server:8080", publicURL)
|
||||
}
|
||||
|
||||
// TestGetPublicURL_EmptyValue verifies behavior with empty setting value
|
||||
func TestGetPublicURL_EmptyValue(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
// Insert setting with empty value
|
||||
setting := models.Setting{
|
||||
Key: "app.public_url",
|
||||
Value: "",
|
||||
}
|
||||
err := db.Create(&setting).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:9000/test", http.NoBody)
|
||||
c.Request = req
|
||||
|
||||
publicURL := GetPublicURL(db, c)
|
||||
|
||||
// Should fallback to request URL when value is empty
|
||||
assert.Equal(t, "http://localhost:9000", publicURL)
|
||||
}
|
||||
|
||||
// TestGetPublicURL_NoSettingInDB verifies behavior when setting doesn't exist
|
||||
func TestGetPublicURL_NoSettingInDB(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
// No setting created - should fallback
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://fallback-host.com/test", http.NoBody)
|
||||
c.Request = req
|
||||
|
||||
publicURL := GetPublicURL(db, c)
|
||||
|
||||
// Should fallback to request host
|
||||
assert.Equal(t, "http://fallback-host.com", publicURL)
|
||||
}
|
||||
|
||||
// TestValidateURL_ValidHTTPS verifies validation of valid HTTPS URLs
|
||||
func TestValidateURL_ValidHTTPS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
normalized string
|
||||
}{
|
||||
{"HTTPS with trailing slash", "https://example.com/", "https://example.com"},
|
||||
{"HTTPS without path", "https://example.com", "https://example.com"},
|
||||
{"HTTPS with port", "https://example.com:8443", "https://example.com:8443"},
|
||||
{"HTTPS with subdomain", "https://app.example.com", "https://app.example.com"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
normalized, warning, err := ValidateURL(tc.url)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.normalized, normalized)
|
||||
assert.Empty(t, warning, "HTTPS should not produce warning")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateURL_ValidHTTP verifies validation of HTTP URLs with warning
|
||||
func TestValidateURL_ValidHTTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
normalized string
|
||||
}{
|
||||
{"HTTP with trailing slash", "http://example.com/", "http://example.com"},
|
||||
{"HTTP without path", "http://example.com", "http://example.com"},
|
||||
{"HTTP with port", "http://example.com:8080", "http://example.com:8080"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
normalized, warning, err := ValidateURL(tc.url)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.normalized, normalized)
|
||||
assert.NotEmpty(t, warning, "HTTP should produce security warning")
|
||||
assert.Contains(t, warning, "HTTP", "warning should mention HTTP")
|
||||
assert.Contains(t, warning, "HTTPS", "warning should suggest HTTPS")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateURL_InvalidScheme verifies rejection of non-HTTP/HTTPS schemes
|
||||
func TestValidateURL_InvalidScheme(t *testing.T) {
|
||||
testCases := []string{
|
||||
"ftp://example.com",
|
||||
"file:///etc/passwd",
|
||||
"javascript:alert(1)",
|
||||
"data:text/html,<script>alert(1)</script>",
|
||||
"ssh://user@host",
|
||||
}
|
||||
|
||||
for _, url := range testCases {
|
||||
t.Run(url, func(t *testing.T) {
|
||||
_, _, err := ValidateURL(url)
|
||||
|
||||
assert.Error(t, err, "non-HTTP(S) scheme should be rejected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateURL_WithPath verifies rejection of URLs with paths
|
||||
func TestValidateURL_WithPath(t *testing.T) {
|
||||
testCases := []string{
|
||||
"https://example.com/api/v1",
|
||||
"https://example.com/admin",
|
||||
"http://example.com/path/to/resource",
|
||||
"https://example.com/index.html",
|
||||
}
|
||||
|
||||
for _, url := range testCases {
|
||||
t.Run(url, func(t *testing.T) {
|
||||
_, _, err := ValidateURL(url)
|
||||
|
||||
assert.Error(t, err, "URL with path should be rejected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateURL_RootPathAllowed verifies "/" path is allowed
|
||||
func TestValidateURL_RootPathAllowed(t *testing.T) {
|
||||
testCases := []string{
|
||||
"https://example.com/",
|
||||
"http://example.com/",
|
||||
}
|
||||
|
||||
for _, url := range testCases {
|
||||
t.Run(url, func(t *testing.T) {
|
||||
normalized, _, err := ValidateURL(url)
|
||||
|
||||
assert.NoError(t, err, "root path '/' should be allowed")
|
||||
// Trailing slash should be removed
|
||||
assert.NotContains(t, normalized[len(normalized)-1:], "/", "normalized URL should not end with slash")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateURL_MalformedURL verifies handling of malformed URLs
|
||||
func TestValidateURL_MalformedURL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
url string
|
||||
shouldFail bool
|
||||
}{
|
||||
{"not a url", true},
|
||||
{"://missing-scheme", true},
|
||||
{"http://", false}, // Valid URL with empty host - Parse accepts it
|
||||
{"https://[invalid", true},
|
||||
{"", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.url, func(t *testing.T) {
|
||||
_, _, err := ValidateURL(tc.url)
|
||||
|
||||
if tc.shouldFail {
|
||||
assert.Error(t, err, "malformed URL should be rejected")
|
||||
} else {
|
||||
// Some URLs that look malformed are actually valid per RFC
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateURL_SpecialCharacters verifies handling of special characters
|
||||
func TestValidateURL_SpecialCharacters(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
isValid bool
|
||||
}{
|
||||
{"Punycode domain", "https://xn--e1afmkfd.xn--p1ai", true},
|
||||
{"Port with special chars", "https://example.com:8080", true},
|
||||
{"Query string (no path component)", "https://example.com?query=1", true}, // Query strings have empty Path
|
||||
{"Fragment (no path component)", "https://example.com#section", true}, // Fragments have empty Path
|
||||
{"Userinfo", "https://user:pass@example.com", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, _, err := ValidateURL(tc.url)
|
||||
|
||||
if tc.isValid {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateURL_Normalization verifies URL normalization
|
||||
func TestValidateURL_Normalization(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expected string
|
||||
shouldFail bool
|
||||
}{
|
||||
{"https://EXAMPLE.COM", "https://EXAMPLE.COM", false}, // Case preserved
|
||||
{"https://example.com/", "https://example.com", false}, // Trailing slash removed
|
||||
{"https://example.com///", "", true}, // Multiple slashes = path component, should fail
|
||||
{"http://example.com:80", "http://example.com:80", false}, // Port preserved
|
||||
{"https://example.com:443", "https://example.com:443", false}, // Default HTTPS port preserved
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
normalized, _, err := ValidateURL(tc.input)
|
||||
|
||||
if tc.shouldFail {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, normalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetBaseURL verifies base URL extraction from request
|
||||
func TestGetBaseURL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
hasTLS bool
|
||||
xForwardedProto string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "HTTPS with TLS",
|
||||
host: "secure.example.com",
|
||||
hasTLS: true,
|
||||
expected: "https://secure.example.com",
|
||||
},
|
||||
{
|
||||
name: "HTTP without TLS",
|
||||
host: "insecure.example.com",
|
||||
hasTLS: false,
|
||||
expected: "http://insecure.example.com",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto HTTPS",
|
||||
host: "behind-proxy.com",
|
||||
hasTLS: false,
|
||||
xForwardedProto: "https",
|
||||
expected: "https://behind-proxy.com",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto HTTP",
|
||||
host: "behind-proxy.com",
|
||||
hasTLS: false,
|
||||
xForwardedProto: "http",
|
||||
expected: "http://behind-proxy.com",
|
||||
},
|
||||
{
|
||||
name: "With port",
|
||||
host: "example.com:8080",
|
||||
hasTLS: false,
|
||||
expected: "http://example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "IPv4 host",
|
||||
host: "192.168.1.1:8080",
|
||||
hasTLS: false,
|
||||
expected: "http://192.168.1.1:8080",
|
||||
},
|
||||
{
|
||||
name: "IPv6 host",
|
||||
host: "[::1]:8080",
|
||||
hasTLS: false,
|
||||
expected: "http://[::1]:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Build request URL
|
||||
scheme := "http"
|
||||
if tc.hasTLS {
|
||||
scheme = "https"
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, scheme+"://"+tc.host+"/test", http.NoBody)
|
||||
|
||||
// Set TLS if needed
|
||||
if tc.hasTLS {
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
|
||||
// Set X-Forwarded-Proto if specified
|
||||
if tc.xForwardedProto != "" {
|
||||
req.Header.Set("X-Forwarded-Proto", tc.xForwardedProto)
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
|
||||
baseURL := getBaseURL(c)
|
||||
|
||||
assert.Equal(t, tc.expected, baseURL)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetBaseURL_PrecedenceOrder verifies header precedence
|
||||
func TestGetBaseURL_PrecedenceOrder(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Request with TLS but also X-Forwarded-Proto
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/test", http.NoBody)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req.Header.Set("X-Forwarded-Proto", "http") // Should be ignored when TLS is present
|
||||
c.Request = req
|
||||
|
||||
baseURL := getBaseURL(c)
|
||||
|
||||
// TLS should take precedence over header
|
||||
assert.Equal(t, "https://example.com", baseURL)
|
||||
}
|
||||
|
||||
// TestGetBaseURL_EmptyHost verifies behavior with empty host
|
||||
func TestGetBaseURL_EmptyHost(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "http:///test", http.NoBody)
|
||||
req.Host = "" // Empty host
|
||||
c.Request = req
|
||||
|
||||
baseURL := getBaseURL(c)
|
||||
|
||||
// Should still return valid URL with empty host
|
||||
assert.Equal(t, "http://", baseURL)
|
||||
}
|
||||
@@ -1,443 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/Wikid82/charon/backend/internal/metrics"
|
||||
"github.com/Wikid82/charon/backend/internal/network"
|
||||
"github.com/Wikid82/charon/backend/internal/security"
|
||||
)
|
||||
|
||||
func resolveAllowedIP(ctx context.Context, host string, allowLocalhost bool) (net.IP, error) {
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("missing hostname")
|
||||
}
|
||||
|
||||
// Fast-path: IP literal.
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if allowLocalhost && ip.IsLoopback() {
|
||||
return ip, nil
|
||||
}
|
||||
if network.IsPrivateIP(ip) {
|
||||
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip)
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS resolution failed: %w", err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no IP addresses found for host")
|
||||
}
|
||||
|
||||
var selected net.IP
|
||||
for _, ip := range ips {
|
||||
if allowLocalhost && ip.IP.IsLoopback() {
|
||||
if selected == nil {
|
||||
selected = ip.IP
|
||||
}
|
||||
continue
|
||||
}
|
||||
if network.IsPrivateIP(ip.IP) {
|
||||
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP)
|
||||
}
|
||||
if selected == nil {
|
||||
selected = ip.IP
|
||||
}
|
||||
}
|
||||
|
||||
if selected == nil {
|
||||
return nil, fmt.Errorf("no allowed IP addresses found for host")
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// ssrfSafeDialer creates a custom dialer that validates IP addresses at connection time.
|
||||
// This prevents DNS rebinding attacks by validating the IP just before connecting.
|
||||
// Returns a DialContext function suitable for use in http.Transport.
|
||||
func ssrfSafeDialer() func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
// Parse host and port from address
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address format: %w", err)
|
||||
}
|
||||
|
||||
// Resolve DNS with context timeout
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS resolution failed: %w", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no IP addresses found for host")
|
||||
}
|
||||
|
||||
// Validate ALL resolved IPs - if any are private, reject immediately
|
||||
// Using centralized network.IsPrivateIP for consistent SSRF protection
|
||||
for _, ip := range ips {
|
||||
if network.IsPrivateIP(ip.IP) {
|
||||
return nil, fmt.Errorf("access to private IP addresses is blocked (resolved to %s)", ip.IP)
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to the first valid IP (prevents DNS rebinding)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
return dialer.DialContext(ctx, netw, net.JoinHostPort(ips[0].IP.String(), port))
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Redirect validation is implemented by validateRedirectTargetStrict.
|
||||
|
||||
// TestURLConnectivity performs a server-side connectivity test with SSRF protection.
|
||||
// For testing purposes, an optional http.RoundTripper can be provided to bypass
|
||||
// DNS resolution and network calls.
|
||||
// Returns:
|
||||
// - reachable: true if URL returned 2xx-3xx status
|
||||
// - latency: round-trip time in milliseconds
|
||||
// - error: validation or connectivity error
|
||||
func TestURLConnectivity(rawURL string) (reachable bool, latency float64, err error) {
|
||||
// NOTE: This wrapper preserves the exported API while enforcing
|
||||
// deny-by-default SSRF-safe behavior.
|
||||
//
|
||||
// Do not add optional transports/options to the exported surface.
|
||||
// Tests can exercise alternative paths via unexported helpers.
|
||||
return testURLConnectivity(rawURL)
|
||||
}
|
||||
|
||||
type urlConnectivityOptions struct {
|
||||
transport http.RoundTripper
|
||||
allowLocalhost bool
|
||||
}
|
||||
|
||||
type urlConnectivityOption func(*urlConnectivityOptions)
|
||||
|
||||
//nolint:unused // Used in test files
|
||||
func withTransportForTesting(rt http.RoundTripper) urlConnectivityOption {
|
||||
return func(o *urlConnectivityOptions) {
|
||||
o.transport = rt
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:unused // Used in test files
|
||||
func withAllowLocalhostForTesting() urlConnectivityOption {
|
||||
return func(o *urlConnectivityOptions) {
|
||||
o.allowLocalhost = true
|
||||
}
|
||||
}
|
||||
|
||||
// testURLConnectivity is the implementation behind TestURLConnectivity.
|
||||
// It supports internal (package-only) hooks to keep unit tests deterministic
|
||||
// without weakening production defaults.
|
||||
func testURLConnectivity(rawURL string, opts ...urlConnectivityOption) (reachable bool, latency float64, err error) {
|
||||
// Track start time for metrics
|
||||
startTime := time.Now()
|
||||
|
||||
// Generate unique request ID for tracing
|
||||
requestID := fmt.Sprintf("test-%d", time.Now().UnixNano())
|
||||
|
||||
options := urlConnectivityOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
isTestMode := options.transport != nil
|
||||
|
||||
// Parse URL first to validate structure
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
// ENHANCEMENT: Record validation failure metric
|
||||
metrics.RecordURLValidation("error", "invalid_format")
|
||||
// ENHANCEMENT: Audit log the failed validation
|
||||
if !isTestMode {
|
||||
security.LogURLTest(rawURL, requestID, "system", "", "error")
|
||||
}
|
||||
return false, 0, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// Validate scheme
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
// ENHANCEMENT: Record validation failure metric
|
||||
metrics.RecordURLValidation("error", "unsupported_scheme")
|
||||
// ENHANCEMENT: Audit log the failed validation
|
||||
if !isTestMode {
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
|
||||
}
|
||||
return false, 0, fmt.Errorf("only http and https schemes are allowed")
|
||||
}
|
||||
|
||||
// Reject URLs containing userinfo (username:password@host)
|
||||
// This is checked again by security.ValidateExternalURL, but we keep it here
|
||||
// to ensure consistent behavior across all call paths.
|
||||
if parsed.User != nil {
|
||||
metrics.RecordURLValidation("error", "userinfo_not_allowed")
|
||||
if !isTestMode {
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
|
||||
}
|
||||
return false, 0, fmt.Errorf("urls with embedded credentials are not allowed")
|
||||
}
|
||||
|
||||
// CRITICAL: Always validate the URL through centralized SSRF protection.
|
||||
//
|
||||
// Production defaults are deny-by-default:
|
||||
// - Only http/https allowed
|
||||
// - No localhost
|
||||
// - No private/reserved IPs
|
||||
//
|
||||
// Tests may opt-in to localhost via withAllowLocalhostForTesting(), without
|
||||
// weakening production behavior.
|
||||
validationOpts := []security.ValidationOption{
|
||||
security.WithAllowHTTP(),
|
||||
}
|
||||
if options.allowLocalhost {
|
||||
validationOpts = append(validationOpts, security.WithAllowLocalhost())
|
||||
}
|
||||
|
||||
validatedURL, err := security.ValidateExternalURL(rawURL, validationOpts...)
|
||||
if err != nil {
|
||||
// ENHANCEMENT: Record SSRF block metrics
|
||||
// Determine the block reason from error message
|
||||
errMsg := err.Error()
|
||||
var blockReason string
|
||||
switch {
|
||||
case strings.Contains(errMsg, "private ip"):
|
||||
blockReason = "private_ip"
|
||||
metrics.RecordSSRFBlock("private", "system") // userID should come from context in production
|
||||
// ENHANCEMENT: Audit log the SSRF block
|
||||
security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "")
|
||||
case strings.Contains(errMsg, "cloud metadata"):
|
||||
blockReason = "metadata_endpoint"
|
||||
metrics.RecordSSRFBlock("metadata", "system")
|
||||
// ENHANCEMENT: Audit log the SSRF block
|
||||
security.LogSSRFBlock(parsed.Hostname(), nil, blockReason, "system", "")
|
||||
case strings.Contains(errMsg, "dns resolution"):
|
||||
blockReason = "dns_failed"
|
||||
// ENHANCEMENT: Audit log the DNS failure
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "error")
|
||||
default:
|
||||
blockReason = "validation_failed"
|
||||
// ENHANCEMENT: Audit log the validation failure
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "blocked")
|
||||
}
|
||||
metrics.RecordURLValidation("blocked", blockReason)
|
||||
|
||||
// Transform error message for backward compatibility with existing tests
|
||||
// The security package uses lowercase in error messages, but tests expect mixed case
|
||||
errMsg = strings.Replace(errMsg, "dns resolution failed", "DNS resolution failed", 1)
|
||||
errMsg = strings.ReplaceAll(errMsg, "private ip", "private IP")
|
||||
// Cloud metadata endpoints are considered private IPs for test compatibility
|
||||
if strings.Contains(errMsg, "cloud metadata endpoints") {
|
||||
errMsg = strings.Replace(errMsg, "access to cloud metadata endpoints is blocked for security", "connection to private IP addresses is blocked for security", 1)
|
||||
}
|
||||
return false, 0, fmt.Errorf("security validation failed: %s", errMsg)
|
||||
}
|
||||
// ENHANCEMENT: Record successful validation
|
||||
metrics.RecordURLValidation("allowed", "validated")
|
||||
// ENHANCEMENT: Audit log successful validation (only in production to avoid test noise)
|
||||
if !isTestMode {
|
||||
security.LogURLTest(parsed.Hostname(), requestID, "system", "", "allowed")
|
||||
}
|
||||
|
||||
// Use validated URL for requests (breaks taint chain)
|
||||
validatedRequestURL := validatedURL
|
||||
|
||||
const (
|
||||
requestTimeout = 5 * time.Second
|
||||
maxRedirects = 2
|
||||
allowHTTPSUpgrade = true
|
||||
)
|
||||
|
||||
transport := &http.Transport{
|
||||
// Explicitly ignore proxy environment variables for SSRF-sensitive requests.
|
||||
Proxy: nil,
|
||||
DialContext: ssrfSafeDialer(),
|
||||
MaxIdleConns: 1,
|
||||
IdleConnTimeout: requestTimeout,
|
||||
TLSHandshakeTimeout: requestTimeout,
|
||||
ResponseHeaderTimeout: requestTimeout,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
|
||||
if isTestMode {
|
||||
// Test-only override: allows deterministic unit tests without real network.
|
||||
transport = &http.Transport{
|
||||
Proxy: nil,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
// If the provided transport is an http.RoundTripper that is not an *http.Transport,
|
||||
// use it directly.
|
||||
}
|
||||
|
||||
var rt http.RoundTripper = transport
|
||||
if isTestMode {
|
||||
rt = options.transport
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: requestTimeout,
|
||||
Transport: rt,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return validateRedirectTargetStrict(req, via, maxRedirects, allowHTTPSUpgrade, options.allowLocalhost)
|
||||
},
|
||||
}
|
||||
|
||||
// Perform HTTP HEAD request with strict timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
|
||||
// Parse the validated URL to construct request from validated components
|
||||
// This breaks the taint chain for static analysis by using parsed URL components
|
||||
validatedParsed, err := url.Parse(validatedRequestURL)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to parse validated URL: %w", err)
|
||||
}
|
||||
|
||||
// Normalize scheme to a constant value derived from an allowlisted set.
|
||||
// This avoids propagating the original input string into request construction.
|
||||
var safeScheme string
|
||||
switch validatedParsed.Scheme {
|
||||
case "http":
|
||||
safeScheme = "http"
|
||||
case "https":
|
||||
safeScheme = "https"
|
||||
default:
|
||||
return false, 0, fmt.Errorf("security validation failed: unsupported scheme")
|
||||
}
|
||||
|
||||
// If we connect to an IP-literal for HTTPS, ensure TLS SNI still uses the hostname.
|
||||
if !isTestMode && safeScheme == "https" {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: validatedParsed.Hostname(),
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve to a concrete, allowed IP for the outbound request URL.
|
||||
// We still preserve the original hostname via Host header and TLS SNI.
|
||||
selectedIP, err := resolveAllowedIP(ctx, validatedParsed.Hostname(), options.allowLocalhost)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("security validation failed: %s", err.Error())
|
||||
}
|
||||
|
||||
port := validatedParsed.Port()
|
||||
if port == "" {
|
||||
if safeScheme == "https" {
|
||||
port = "443"
|
||||
} else {
|
||||
port = "80"
|
||||
}
|
||||
} else {
|
||||
p, convErr := strconv.Atoi(port)
|
||||
if convErr != nil || p < 1 || p > 65535 {
|
||||
return false, 0, fmt.Errorf("security validation failed: invalid port")
|
||||
}
|
||||
port = strconv.Itoa(p)
|
||||
}
|
||||
|
||||
// Construct a request URL from SSRF-safe components.
|
||||
// - Host is a resolved IP (selectedIP) to avoid hostname-based SSRF bypass
|
||||
// - Path is fixed to "/" because this is a connectivity test
|
||||
// nosemgrep: go.lang.security.audit.net.use-tls.use-tls
|
||||
safeURL := &url.URL{
|
||||
Scheme: safeScheme,
|
||||
Host: net.JoinHostPort(selectedIP.String(), port),
|
||||
Path: "/",
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, safeURL.String(), http.NoBody)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Preserve the original hostname in Host header for virtual-hosting.
|
||||
// This does not affect the destination IP (selectedIP), which is used in the URL.
|
||||
req.Host = validatedParsed.Host
|
||||
|
||||
// Add custom User-Agent header
|
||||
req.Header.Set("User-Agent", "Charon-Health-Check/1.0")
|
||||
|
||||
// ENHANCEMENT: Request Tracing Headers
|
||||
// These headers help track and identify URL test requests in logs
|
||||
req.Header.Set("X-Charon-Request-Type", "url-connectivity-test")
|
||||
req.Header.Set("X-Request-ID", requestID) // Use consistent request ID for tracing
|
||||
|
||||
// SSRF Protection Summary:
|
||||
// This HTTP request is protected against SSRF by multiple defense layers:
|
||||
// 1. security.ValidateExternalURL() validates URL format, scheme, and performs
|
||||
// DNS resolution with private IP blocking (RFC 1918, loopback, link-local, metadata)
|
||||
// 2. ssrfSafeDialer() re-validates IPs at connection time (prevents DNS rebinding/TOCTOU)
|
||||
// 3. validateRedirectTarget() validates all redirect URLs in production
|
||||
// 4. safeURL is constructed from parsed/validated components (breaks taint chain)
|
||||
// See: internal/security/url_validator.go, internal/network/safeclient.go
|
||||
resp, err := client.Do(req) //nolint:bodyclose // Body closed via defer below
|
||||
latency = time.Since(start).Seconds() * 1000 // Convert to milliseconds
|
||||
|
||||
// ENHANCEMENT: Record test duration metric (only in production to avoid test noise)
|
||||
if !isTestMode {
|
||||
durationSeconds := time.Since(startTime).Seconds()
|
||||
metrics.RecordURLTestDuration(durationSeconds)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, latency, fmt.Errorf("connection failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
logger.Log().WithError(err).Warn("Failed to close response body")
|
||||
}
|
||||
}()
|
||||
|
||||
// Accept 2xx and 3xx status codes as "reachable"
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
|
||||
return true, latency, nil
|
||||
}
|
||||
|
||||
return false, latency, fmt.Errorf("server returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// validateRedirectTargetStrict validates HTTP redirects with SSRF protection.
|
||||
// It enforces:
|
||||
// - a hard redirect limit
|
||||
// - per-hop URL validation (scheme, userinfo, host, DNS, private/reserved IPs)
|
||||
// - scheme-change policy (deny by default; optionally allow http->https upgrade)
|
||||
func validateRedirectTargetStrict(req *http.Request, via []*http.Request, maxRedirects int, allowHTTPSUpgrade bool, allowLocalhost bool) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("too many redirects (max %d)", maxRedirects)
|
||||
}
|
||||
|
||||
if len(via) > 0 {
|
||||
prevScheme := via[len(via)-1].URL.Scheme
|
||||
newScheme := req.URL.Scheme
|
||||
if newScheme != prevScheme {
|
||||
if !allowHTTPSUpgrade || prevScheme != "http" || newScheme != "https" {
|
||||
return fmt.Errorf("redirect scheme change blocked: %s -> %s", prevScheme, newScheme)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validationOpts := []security.ValidationOption{security.WithAllowHTTP(), security.WithTimeout(3 * time.Second)}
|
||||
if allowLocalhost {
|
||||
validationOpts = append(validationOpts, security.WithAllowLocalhost())
|
||||
}
|
||||
|
||||
_, err := security.ValidateExternalURL(req.URL.String(), validationOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("redirect target validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,489 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTestURLConnectivity_EnhancedSSRF tests additional SSRF attack vectors.
|
||||
// ENHANCEMENT: Comprehensive SSRF testing per Supervisor review
|
||||
func TestTestURLConnectivity_EnhancedSSRF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
shouldFail bool
|
||||
errContains string
|
||||
}{
|
||||
// Cloud Metadata Endpoints
|
||||
{
|
||||
name: "AWS metadata endpoint",
|
||||
url: "http://169.254.169.254/latest/meta-data/",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
{
|
||||
name: "GCP metadata endpoint",
|
||||
url: "http://metadata.google.internal/computeMetadata/v1/",
|
||||
shouldFail: true,
|
||||
errContains: "DNS resolution failed", // Will fail to resolve
|
||||
},
|
||||
{
|
||||
name: "Azure metadata endpoint",
|
||||
url: "http://169.254.169.254/metadata/instance",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
|
||||
// RFC 1918 Private Networks
|
||||
{
|
||||
name: "Private 10.0.0.0/8",
|
||||
url: "http://10.0.0.1/admin",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
{
|
||||
name: "Private 172.16.0.0/12",
|
||||
url: "http://172.16.0.1/admin",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
{
|
||||
name: "Private 192.168.0.0/16",
|
||||
url: "http://192.168.1.1/admin",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
|
||||
// Loopback Addresses
|
||||
{
|
||||
name: "IPv4 loopback",
|
||||
url: "http://127.0.0.1:6379/",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
url: "http://[::1]:6379/",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
|
||||
// Link-Local Addresses
|
||||
{
|
||||
name: "Link-local IPv4",
|
||||
url: "http://169.254.1.1/",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
{
|
||||
name: "Link-local IPv6",
|
||||
url: "http://[fe80::1]/",
|
||||
shouldFail: true,
|
||||
errContains: "private IP",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reachable, _, err := TestURLConnectivity(tt.url)
|
||||
|
||||
if tt.shouldFail {
|
||||
if err == nil {
|
||||
t.Errorf("Expected test to fail, but it succeeded")
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Expected error containing '%s', got: %s", tt.errContains, err.Error())
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false, got true")
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("Expected test to succeed, but got error: %s", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_RedirectValidation tests that redirects are properly validated.
|
||||
// ENHANCEMENT: Critical test per Supervisor review - all redirects must be validated
|
||||
func TestTestURLConnectivity_RedirectValidation(t *testing.T) {
|
||||
// Create test servers
|
||||
privateServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Private server"))
|
||||
}))
|
||||
defer privateServer.Close()
|
||||
|
||||
t.Run("Redirect to private IP should be blocked", func(t *testing.T) {
|
||||
// Note: This test requires actual redirect validation to work
|
||||
// The validateRedirectTarget function will validate the Location header
|
||||
// For now, we skip this test as it requires complex mock setup
|
||||
t.Skip("Redirect validation test requires complex HTTP client mocking")
|
||||
})
|
||||
|
||||
t.Run("Too many redirects should be blocked", func(t *testing.T) {
|
||||
// Create a server that redirects to itself multiple times
|
||||
redirectCount := 0
|
||||
var redirectServerURL string
|
||||
redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
redirectCount++
|
||||
if redirectCount < 5 {
|
||||
http.Redirect(w, r, redirectServerURL+fmt.Sprintf("/%d", redirectCount), http.StatusFound)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}))
|
||||
defer redirectServer.Close()
|
||||
redirectServerURL = redirectServer.URL
|
||||
|
||||
transport := redirectServer.Client().Transport
|
||||
reachable, _, err := testURLConnectivity(redirectServerURL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
// Should fail due to too many redirects (max 2)
|
||||
if err == nil {
|
||||
t.Errorf("Expected too many redirects to fail, but it succeeded")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "redirect") {
|
||||
t.Errorf("Expected error about redirects, got: %s", err.Error())
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false, got true")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_UnicodeHomograph tests Unicode homograph attack prevention.
|
||||
// ENHANCEMENT: Tests for internationalized domain name attacks
|
||||
func TestTestURLConnectivity_UnicodeHomograph(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{
|
||||
name: "Cyrillic homograph",
|
||||
url: "https://gооgle.com", // Uses Cyrillic 'о' instead of Latin 'o'
|
||||
},
|
||||
{
|
||||
name: "Mixed script attack",
|
||||
url: "https://раypal.com", // Uses Cyrillic 'а' and 'y'
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// These should fail DNS resolution since they're not real domains
|
||||
reachable, _, err := TestURLConnectivity(tt.url)
|
||||
if err == nil {
|
||||
t.Logf("Warning: homograph domain %s resolved - may indicate IDN issue", tt.url)
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Homograph domain %s appears reachable - security concern", tt.url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_LongHostname tests extremely long hostname handling.
|
||||
// ENHANCEMENT: DoS prevention via hostname length validation
|
||||
func TestTestURLConnectivity_LongHostname(t *testing.T) {
|
||||
// Create hostname exceeding RFC 1035 limit (253 chars)
|
||||
longHostname := strings.Repeat("a", 254) + ".com"
|
||||
url := "https://" + longHostname + "/path"
|
||||
|
||||
reachable, _, err := TestURLConnectivity(url)
|
||||
if err == nil {
|
||||
t.Errorf("Expected long hostname to be rejected, but it was accepted")
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false for long hostname, got true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_RequestTracingHeaders tests that tracing headers are added.
|
||||
// ENHANCEMENT: Verifies request tracing per Supervisor review
|
||||
func TestTestURLConnectivity_RequestTracingHeaders(t *testing.T) {
|
||||
var capturedHeaders http.Header
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
transport := testServer.Client().Transport
|
||||
_, _, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// Verify headers were set
|
||||
if capturedHeaders.Get("User-Agent") != "Charon-Health-Check/1.0" {
|
||||
t.Errorf("Expected User-Agent header, got: %s", capturedHeaders.Get("User-Agent"))
|
||||
}
|
||||
if capturedHeaders.Get("X-Charon-Request-Type") != "url-connectivity-test" {
|
||||
t.Errorf("Expected X-Charon-Request-Type header, got: %s", capturedHeaders.Get("X-Charon-Request-Type"))
|
||||
}
|
||||
if capturedHeaders.Get("X-Request-ID") == "" {
|
||||
t.Errorf("Expected X-Request-ID header to be set")
|
||||
}
|
||||
if !strings.HasPrefix(capturedHeaders.Get("X-Request-ID"), "test-") {
|
||||
t.Errorf("Expected X-Request-ID to start with 'test-', got: %s", capturedHeaders.Get("X-Request-ID"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_MetricsIntegration tests that metrics are recorded.
|
||||
// ENHANCEMENT: Validates metrics collection per Supervisor review
|
||||
func TestTestURLConnectivity_MetricsIntegration(t *testing.T) {
|
||||
// This test verifies that metrics functions are called
|
||||
// Full metrics validation requires integration tests with Prometheus
|
||||
|
||||
t.Run("Valid URL records metrics", func(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
transport := testServer.Client().Transport
|
||||
reachable, latency, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %s", err)
|
||||
}
|
||||
if !reachable {
|
||||
t.Errorf("Expected reachable=true, got false")
|
||||
}
|
||||
if latency <= 0 {
|
||||
t.Errorf("Expected positive latency, got %f", latency)
|
||||
}
|
||||
// Metrics recorded: URLValidation (allowed), URLTestDuration
|
||||
})
|
||||
|
||||
t.Run("Blocked URL records metrics", func(t *testing.T) {
|
||||
reachable, _, err := TestURLConnectivity("http://127.0.0.1:6379/")
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected private IP to be blocked")
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false, got true")
|
||||
}
|
||||
// Metrics recorded: SSRFBlock, URLValidation (blocked)
|
||||
})
|
||||
|
||||
t.Run("Invalid URL records metrics", func(t *testing.T) {
|
||||
reachable, _, err := TestURLConnectivity("not-a-valid-url")
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected invalid URL to fail")
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false, got true")
|
||||
}
|
||||
// Metrics recorded: URLValidation (error, unsupported_scheme)
|
||||
})
|
||||
}
|
||||
|
||||
// TestValidateRedirectTarget tests the redirect validation function directly.
|
||||
// ENHANCEMENT: Direct unit tests for redirect validation per Phase 2 requirements
|
||||
func TestValidateRedirectTarget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
viaCount int
|
||||
shouldErr bool
|
||||
errContains string
|
||||
viaURL string
|
||||
allowLocal bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost redirect allowed",
|
||||
url: "http://localhost/path",
|
||||
viaCount: 0,
|
||||
shouldErr: false,
|
||||
allowLocal: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 redirect allowed",
|
||||
url: "http://127.0.0.1:8080/path",
|
||||
viaCount: 0,
|
||||
shouldErr: false,
|
||||
allowLocal: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback allowed",
|
||||
url: "http://[::1]:8080/path",
|
||||
viaCount: 0,
|
||||
shouldErr: false,
|
||||
allowLocal: true,
|
||||
},
|
||||
{
|
||||
name: "Too many redirects",
|
||||
url: "http://localhost/path",
|
||||
viaCount: 2,
|
||||
shouldErr: true,
|
||||
errContains: "too many redirects",
|
||||
allowLocal: true,
|
||||
},
|
||||
{
|
||||
name: "Three redirects",
|
||||
url: "http://localhost/path",
|
||||
viaCount: 3,
|
||||
shouldErr: true,
|
||||
errContains: "too many redirects",
|
||||
allowLocal: true,
|
||||
},
|
||||
{
|
||||
name: "Scheme downgrade blocked (https -> http)",
|
||||
url: "http://localhost/next",
|
||||
viaURL: "https://localhost/start",
|
||||
viaCount: 1,
|
||||
shouldErr: true,
|
||||
errContains: "scheme change blocked",
|
||||
allowLocal: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a request for the redirect target
|
||||
req, err := http.NewRequest("GET", tt.url, http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Create via slice (previous requests)
|
||||
via := make([]*http.Request, 0, tt.viaCount)
|
||||
if tt.viaCount > 0 {
|
||||
viaURL := tt.viaURL
|
||||
if viaURL == "" {
|
||||
viaURL = "http://localhost/prev"
|
||||
}
|
||||
prevReq, prevErr := http.NewRequest("GET", viaURL, http.NoBody)
|
||||
if prevErr != nil {
|
||||
t.Fatalf("Failed to create via request: %v", prevErr)
|
||||
}
|
||||
via = append(via, prevReq)
|
||||
for i := 1; i < tt.viaCount; i++ {
|
||||
via = append(via, prevReq)
|
||||
}
|
||||
}
|
||||
|
||||
err = validateRedirectTargetStrict(req, via, 2, true, tt.allowLocal)
|
||||
|
||||
if tt.shouldErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
} else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Expected error containing '%s', got: %s", tt.errContains, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_AuditLogging tests that audit logging is integrated.
|
||||
// ENHANCEMENT: Validates audit logging integration per Phase 3 requirements
|
||||
func TestTestURLConnectivity_AuditLogging(t *testing.T) {
|
||||
// Note: In production, audit logs go to the configured logger.
|
||||
// These tests verify the code paths execute without panic.
|
||||
|
||||
t.Run("Invalid URL format logs audit event", func(t *testing.T) {
|
||||
// This should trigger audit logging for invalid format
|
||||
reachable, _, err := TestURLConnectivity("://invalid")
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid URL")
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid scheme logs audit event", func(t *testing.T) {
|
||||
// This should trigger audit logging for unsupported scheme
|
||||
reachable, _, err := TestURLConnectivity("ftp://example.com/file")
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for unsupported scheme")
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Private IP logs SSRF block audit event", func(t *testing.T) {
|
||||
// This should trigger SSRF block audit logging
|
||||
reachable, _, err := TestURLConnectivity("http://10.0.0.1/admin")
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for private IP")
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Metadata endpoint logs SSRF block audit event", func(t *testing.T) {
|
||||
// This should trigger metadata endpoint block audit logging
|
||||
reachable, _, err := TestURLConnectivity("http://169.254.169.254/latest/meta-data/")
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for metadata endpoint")
|
||||
}
|
||||
if reachable {
|
||||
t.Errorf("Expected reachable=false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid URL with mock transport logs success", func(t *testing.T) {
|
||||
// Create a test server
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
// Note: With mock transport, audit logging is skipped (isTestMode=true)
|
||||
// This test verifies the code path doesn't panic
|
||||
transport := testServer.Client().Transport
|
||||
reachable, _, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %s", err)
|
||||
}
|
||||
if !reachable {
|
||||
t.Errorf("Expected reachable=true")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_RequestIDConsistency tests that request ID is consistent.
|
||||
// ENHANCEMENT: Validates request tracing per Phase 3 requirements
|
||||
func TestTestURLConnectivity_RequestIDConsistency(t *testing.T) {
|
||||
var capturedRequestID string
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedRequestID = r.Header.Get("X-Request-ID")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
transport := testServer.Client().Transport
|
||||
_, _, err := testURLConnectivity(testServer.URL, withAllowLocalhostForTesting(), withTransportForTesting(transport))
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %s", err)
|
||||
}
|
||||
|
||||
if capturedRequestID == "" {
|
||||
t.Error("X-Request-ID header was not set")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(capturedRequestID, "test-") {
|
||||
t.Errorf("X-Request-ID should start with 'test-', got: %s", capturedRequestID)
|
||||
}
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestResolveAllowedIP_EmptyHostname tests resolveAllowedIP with empty hostname.
|
||||
func TestResolveAllowedIP_EmptyHostname(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := resolveAllowedIP(ctx, "", false)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty hostname, got nil")
|
||||
}
|
||||
if err.Error() != "missing hostname" {
|
||||
t.Errorf("Expected 'missing hostname', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_LoopbackIPLiteral tests resolveAllowedIP with loopback IPs.
|
||||
func TestResolveAllowedIP_LoopbackIPLiteral(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
allowLocalhost bool
|
||||
shouldFail bool
|
||||
}{
|
||||
{
|
||||
name: "127.0.0.1 without allowLocalhost",
|
||||
ip: "127.0.0.1",
|
||||
allowLocalhost: false,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with allowLocalhost",
|
||||
ip: "127.0.0.1",
|
||||
allowLocalhost: true,
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
name: "::1 without allowLocalhost",
|
||||
ip: "::1",
|
||||
allowLocalhost: false,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "::1 with allowLocalhost",
|
||||
ip: "::1",
|
||||
allowLocalhost: true,
|
||||
shouldFail: false,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip, err := resolveAllowedIP(ctx, tt.ip, tt.allowLocalhost)
|
||||
|
||||
if tt.shouldFail {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for %s without allowLocalhost", tt.ip)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal("ErroF for allowed loopback", err)
|
||||
}
|
||||
if ip == nil {
|
||||
t.Fatal("Expected non-nil IP")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_PrivateIPLiterals tests resolveAllowedIP blocks private IPs.
|
||||
func TestResolveAllowedIP_PrivateIPLiterals(t *testing.T) {
|
||||
privateIPs := []string{
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"192.168.1.1",
|
||||
"169.254.169.254", // AWS metadata
|
||||
"fc00::1", // IPv6 unique local
|
||||
"fe80::1", // IPv6 link-local
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, ip := range privateIPs {
|
||||
t.Run("IP_"+ip, func(t *testing.T) {
|
||||
_, err := resolveAllowedIP(ctx, ip, false)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for private IP %s, got nil", ip)
|
||||
}
|
||||
if err != nil && err.Error() != fmt.Sprintf("access to private IP addresses is blocked (resolved to %s)", ip) {
|
||||
// Check it contains the expected error substring
|
||||
expectedMsg := "access to private IP addresses is blocked"
|
||||
if !contains(err.Error(), expectedMsg) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", expectedMsg, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_PublicIPLiteral tests resolveAllowedIP allows public IPs.
|
||||
func TestResolveAllowedIP_PublicIPLiteral(t *testing.T) {
|
||||
publicIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"2001:4860:4860::8888",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, ipStr := range publicIPs {
|
||||
t.Run("IP_"+ipStr, func(t *testing.T) {
|
||||
ip, err := resolveAllowedIP(ctx, ipStr, false)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for public IP %s, got: %v", ipStr, err)
|
||||
}
|
||||
if ip == nil {
|
||||
t.Error("Expected non-nil IP for public address")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_Timeout tests DNS resolution timeout.
|
||||
func TestResolveAllowedIP_Timeout(t *testing.T) {
|
||||
// Create a context with very short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
|
||||
defer cancel()
|
||||
|
||||
// Any hostname should timeout
|
||||
_, err := resolveAllowedIP(ctx, "example.com", false)
|
||||
if err == nil {
|
||||
t.Fatal("Expected timeout error, got nil")
|
||||
}
|
||||
|
||||
// Should be a context deadline exceeded error
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !contains(err.Error(), "deadline") && !contains(err.Error(), "timeout") {
|
||||
t.Logf("Expected timeout/deadline error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_NoIPsResolved tests when DNS returns no IPs.
|
||||
// Note: This is difficult to test without a custom resolver, so we skip it
|
||||
func TestResolveAllowedIP_NoIPsResolved(t *testing.T) {
|
||||
t.Skip("Requires custom DNS resolver to return empty IP list")
|
||||
}
|
||||
|
||||
// TestSSRFSafeDialer_PrivateIPResolution tests that ssrfSafeDialer blocks private IPs.
|
||||
// Note: This requires network access or mocking, so we test the concept
|
||||
func TestSSRFSafeDialer_Concept(t *testing.T) {
|
||||
// The ssrfSafeDialer function should:
|
||||
// 1. Resolve the hostname to IPs
|
||||
// 2. Check ALL IPs against private ranges
|
||||
// 3. Reject if ANY IP is private
|
||||
// 4. Connect only to validated IPs
|
||||
|
||||
// We can't easily test this without network calls, but we document the behavior
|
||||
t.Log("ssrfSafeDialer validates IPs at dial time to prevent DNS rebinding")
|
||||
t.Log("All resolved IPs must pass private IP check before connection")
|
||||
}
|
||||
|
||||
// TestSSRFSafeDialer_InvalidAddress tests ssrfSafeDialer with malformed addresses.
|
||||
func TestSSRFSafeDialer_InvalidAddress(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dialer := ssrfSafeDialer()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
}{
|
||||
{
|
||||
name: "No port",
|
||||
addr: "example.com",
|
||||
},
|
||||
{
|
||||
name: "Invalid format",
|
||||
addr: ":/invalid",
|
||||
},
|
||||
{
|
||||
name: "Empty address",
|
||||
addr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := dialer(ctx, "tcp", tt.addr)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid address %s, got nil", tt.addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSRFSafeDialer_ContextCancellation tests context cancellation during dial.
|
||||
func TestSSRFSafeDialer_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
dialer := ssrfSafeDialer()
|
||||
_, err := dialer(ctx, "tcp", "example.com:80")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected context cancellation error, got nil")
|
||||
}
|
||||
|
||||
// Should be context canceled error
|
||||
if !errors.Is(err, context.Canceled) && !contains(err.Error(), "canceled") {
|
||||
t.Logf("Expected context canceled error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_ErrorPaths tests error handling in testURLConnectivity.
|
||||
func TestTestURLConnectivity_ErrorPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
errMatch string
|
||||
}{
|
||||
{
|
||||
name: "Invalid URL format",
|
||||
url: "://invalid",
|
||||
errMatch: "invalid URL",
|
||||
},
|
||||
{
|
||||
name: "Unsupported scheme FTP",
|
||||
url: "ftp://example.com",
|
||||
errMatch: "only http and https schemes are allowed",
|
||||
},
|
||||
{
|
||||
name: "Embedded credentials",
|
||||
url: "https://user:pass@example.com",
|
||||
errMatch: "embedded credentials are not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 10.x",
|
||||
url: "http://10.0.0.1",
|
||||
errMatch: "private IP",
|
||||
},
|
||||
{
|
||||
name: "Private IP 192.168.x",
|
||||
url: "http://192.168.1.1",
|
||||
errMatch: "private IP",
|
||||
},
|
||||
{
|
||||
name: "AWS metadata endpoint",
|
||||
url: "http://169.254.169.254/latest/meta-data/",
|
||||
errMatch: "private IP",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reachable, latency, err := TestURLConnectivity(tt.url)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for %s, got nil (reachable=%v, latency=%v)", tt.url, reachable, latency)
|
||||
}
|
||||
|
||||
if !contains(err.Error(), tt.errMatch) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", tt.errMatch, err)
|
||||
}
|
||||
|
||||
if reachable {
|
||||
t.Error("Expected reachable=false for error case")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_InvalidPort tests port validation in testURLConnectivity.
|
||||
func TestTestURLConnectivity_InvalidPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{
|
||||
name: "Port out of range (too high)",
|
||||
url: "https://example.com:99999",
|
||||
},
|
||||
{
|
||||
name: "Port zero",
|
||||
url: "https://example.com:0",
|
||||
},
|
||||
{
|
||||
name: "Negative port",
|
||||
url: "https://example.com:-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _, err := TestURLConnectivity(tt.url)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid port in %s", tt.url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRedirectTargetStrict tests are in url_testing_test.go using proper http types
|
||||
|
||||
// Helper function already defined in security tests
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && containsSubstring(s, substr)
|
||||
}
|
||||
|
||||
func containsSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user