chore: clean .gitignore cache
This commit is contained in:
@@ -1,320 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestResolveAllowedIP_EmptyHostname tests resolveAllowedIP with empty hostname.
|
||||
func TestResolveAllowedIP_EmptyHostname(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := resolveAllowedIP(ctx, "", false)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty hostname, got nil")
|
||||
}
|
||||
if err.Error() != "missing hostname" {
|
||||
t.Errorf("Expected 'missing hostname', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_LoopbackIPLiteral tests resolveAllowedIP with loopback IPs.
|
||||
func TestResolveAllowedIP_LoopbackIPLiteral(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
allowLocalhost bool
|
||||
shouldFail bool
|
||||
}{
|
||||
{
|
||||
name: "127.0.0.1 without allowLocalhost",
|
||||
ip: "127.0.0.1",
|
||||
allowLocalhost: false,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with allowLocalhost",
|
||||
ip: "127.0.0.1",
|
||||
allowLocalhost: true,
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
name: "::1 without allowLocalhost",
|
||||
ip: "::1",
|
||||
allowLocalhost: false,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "::1 with allowLocalhost",
|
||||
ip: "::1",
|
||||
allowLocalhost: true,
|
||||
shouldFail: false,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip, err := resolveAllowedIP(ctx, tt.ip, tt.allowLocalhost)
|
||||
|
||||
if tt.shouldFail {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for %s without allowLocalhost", tt.ip)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal("ErroF for allowed loopback", err)
|
||||
}
|
||||
if ip == nil {
|
||||
t.Fatal("Expected non-nil IP")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_PrivateIPLiterals tests resolveAllowedIP blocks private IPs.
|
||||
func TestResolveAllowedIP_PrivateIPLiterals(t *testing.T) {
|
||||
privateIPs := []string{
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"192.168.1.1",
|
||||
"169.254.169.254", // AWS metadata
|
||||
"fc00::1", // IPv6 unique local
|
||||
"fe80::1", // IPv6 link-local
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, ip := range privateIPs {
|
||||
t.Run("IP_"+ip, func(t *testing.T) {
|
||||
_, err := resolveAllowedIP(ctx, ip, false)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for private IP %s, got nil", ip)
|
||||
}
|
||||
if err != nil && err.Error() != fmt.Sprintf("access to private IP addresses is blocked (resolved to %s)", ip) {
|
||||
// Check it contains the expected error substring
|
||||
expectedMsg := "access to private IP addresses is blocked"
|
||||
if !contains(err.Error(), expectedMsg) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", expectedMsg, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_PublicIPLiteral tests resolveAllowedIP allows public IPs.
|
||||
func TestResolveAllowedIP_PublicIPLiteral(t *testing.T) {
|
||||
publicIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"2001:4860:4860::8888",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, ipStr := range publicIPs {
|
||||
t.Run("IP_"+ipStr, func(t *testing.T) {
|
||||
ip, err := resolveAllowedIP(ctx, ipStr, false)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for public IP %s, got: %v", ipStr, err)
|
||||
}
|
||||
if ip == nil {
|
||||
t.Error("Expected non-nil IP for public address")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_Timeout tests DNS resolution timeout.
|
||||
func TestResolveAllowedIP_Timeout(t *testing.T) {
|
||||
// Create a context with very short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
|
||||
defer cancel()
|
||||
|
||||
// Any hostname should timeout
|
||||
_, err := resolveAllowedIP(ctx, "example.com", false)
|
||||
if err == nil {
|
||||
t.Fatal("Expected timeout error, got nil")
|
||||
}
|
||||
|
||||
// Should be a context deadline exceeded error
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !contains(err.Error(), "deadline") && !contains(err.Error(), "timeout") {
|
||||
t.Logf("Expected timeout/deadline error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAllowedIP_NoIPsResolved tests when DNS returns no IPs.
|
||||
// Note: This is difficult to test without a custom resolver, so we skip it
|
||||
func TestResolveAllowedIP_NoIPsResolved(t *testing.T) {
|
||||
t.Skip("Requires custom DNS resolver to return empty IP list")
|
||||
}
|
||||
|
||||
// TestSSRFSafeDialer_PrivateIPResolution tests that ssrfSafeDialer blocks private IPs.
|
||||
// Note: This requires network access or mocking, so we test the concept
|
||||
func TestSSRFSafeDialer_Concept(t *testing.T) {
|
||||
// The ssrfSafeDialer function should:
|
||||
// 1. Resolve the hostname to IPs
|
||||
// 2. Check ALL IPs against private ranges
|
||||
// 3. Reject if ANY IP is private
|
||||
// 4. Connect only to validated IPs
|
||||
|
||||
// We can't easily test this without network calls, but we document the behavior
|
||||
t.Log("ssrfSafeDialer validates IPs at dial time to prevent DNS rebinding")
|
||||
t.Log("All resolved IPs must pass private IP check before connection")
|
||||
}
|
||||
|
||||
// TestSSRFSafeDialer_InvalidAddress tests ssrfSafeDialer with malformed addresses.
|
||||
func TestSSRFSafeDialer_InvalidAddress(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dialer := ssrfSafeDialer()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
}{
|
||||
{
|
||||
name: "No port",
|
||||
addr: "example.com",
|
||||
},
|
||||
{
|
||||
name: "Invalid format",
|
||||
addr: ":/invalid",
|
||||
},
|
||||
{
|
||||
name: "Empty address",
|
||||
addr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := dialer(ctx, "tcp", tt.addr)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid address %s, got nil", tt.addr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSRFSafeDialer_ContextCancellation tests context cancellation during dial.
|
||||
func TestSSRFSafeDialer_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
dialer := ssrfSafeDialer()
|
||||
_, err := dialer(ctx, "tcp", "example.com:80")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected context cancellation error, got nil")
|
||||
}
|
||||
|
||||
// Should be context canceled error
|
||||
if !errors.Is(err, context.Canceled) && !contains(err.Error(), "canceled") {
|
||||
t.Logf("Expected context canceled error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_ErrorPaths tests error handling in testURLConnectivity.
|
||||
func TestTestURLConnectivity_ErrorPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
errMatch string
|
||||
}{
|
||||
{
|
||||
name: "Invalid URL format",
|
||||
url: "://invalid",
|
||||
errMatch: "invalid URL",
|
||||
},
|
||||
{
|
||||
name: "Unsupported scheme FTP",
|
||||
url: "ftp://example.com",
|
||||
errMatch: "only http and https schemes are allowed",
|
||||
},
|
||||
{
|
||||
name: "Embedded credentials",
|
||||
url: "https://user:pass@example.com",
|
||||
errMatch: "embedded credentials are not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 10.x",
|
||||
url: "http://10.0.0.1",
|
||||
errMatch: "private IP",
|
||||
},
|
||||
{
|
||||
name: "Private IP 192.168.x",
|
||||
url: "http://192.168.1.1",
|
||||
errMatch: "private IP",
|
||||
},
|
||||
{
|
||||
name: "AWS metadata endpoint",
|
||||
url: "http://169.254.169.254/latest/meta-data/",
|
||||
errMatch: "private IP",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reachable, latency, err := TestURLConnectivity(tt.url)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error for %s, got nil (reachable=%v, latency=%v)", tt.url, reachable, latency)
|
||||
}
|
||||
|
||||
if !contains(err.Error(), tt.errMatch) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", tt.errMatch, err)
|
||||
}
|
||||
|
||||
if reachable {
|
||||
t.Error("Expected reachable=false for error case")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTestURLConnectivity_InvalidPort tests port validation in testURLConnectivity.
|
||||
func TestTestURLConnectivity_InvalidPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{
|
||||
name: "Port out of range (too high)",
|
||||
url: "https://example.com:99999",
|
||||
},
|
||||
{
|
||||
name: "Port zero",
|
||||
url: "https://example.com:0",
|
||||
},
|
||||
{
|
||||
name: "Negative port",
|
||||
url: "https://example.com:-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _, err := TestURLConnectivity(tt.url)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid port in %s", tt.url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRedirectTargetStrict tests are in url_testing_test.go using proper http types
|
||||
|
||||
// Helper function already defined in security tests
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && containsSubstring(s, substr)
|
||||
}
|
||||
|
||||
func containsSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user