540 lines
14 KiB
Go
540 lines
14 KiB
Go
package notifications
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
crand "crypto/rand"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
neturl "net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Wikid82/charon/backend/internal/network"
|
|
"github.com/Wikid82/charon/backend/internal/security"
|
|
)
|
|
|
|
// Size limits, in bytes, applied to outbound notification traffic.
const (
	// MaxNotifyRequestBodyBytes is the largest request payload accepted (256 KiB).
	MaxNotifyRequestBodyBytes = 256 << 10
	// MaxNotifyResponseBodyBytes is the largest response body accepted (1 MiB).
	MaxNotifyResponseBodyBytes = 1 << 20
)
|
|
|
|
// RetryPolicy bounds how often a delivery is attempted and how long the
// wrapper backs off between attempts.
type RetryPolicy struct {
	// MaxAttempts is the total number of attempts, the first one included.
	MaxAttempts int
	// BaseDelay is the backoff before the first retry; MaxDelay is the cap
	// applied to the backoff before jitter is added.
	BaseDelay, MaxDelay time.Duration
}
|
|
|
|
// HTTPWrapperRequest is the input to HTTPWrapper.Send.
type HTTPWrapperRequest struct {
	// URL is the destination; Send validates it before any request is issued.
	URL string
	// Headers are caller-supplied headers; Send filters them through
	// sanitizeOutboundHeaders before applying them to the request.
	Headers map[string]string
	// Body is the raw payload used as the POST body.
	Body []byte
}
|
|
|
|
// HTTPWrapperResult is what HTTPWrapper.Send returns on success.
type HTTPWrapperResult struct {
	// StatusCode is the HTTP status of the response that ended the send.
	StatusCode int
	// ResponseBody is the response body as returned by readCappedResponseBody.
	ResponseBody []byte
	// Attempts is the 1-based number of the attempt that produced the response.
	Attempts int
}
|
|
|
|
// HTTPWrapper sends notification payloads over HTTP with destination
// validation, redirect guarding, and bounded retries.
type HTTPWrapper struct {
	// retryPolicy drives the attempt loop in Send and the backoff in waitBeforeRetry.
	retryPolicy RetryPolicy
	// allowHTTP, when true, adds the allow-HTTP and allow-localhost validation
	// options and permits loopback destinations for local hosts.
	allowHTTP bool
	// maxRedirects is handed to httpClientFactory when Send builds its client.
	maxRedirects int
	// httpClientFactory builds the client used by Send; it is called once per Send.
	httpClientFactory func(allowHTTP bool, maxRedirects int) *http.Client
	// sleep performs the backoff wait; waitBeforeRetry falls back to time.Sleep when nil.
	sleep func(time.Duration)
	// jitterNanos receives an upper bound in nanoseconds and returns the jitter to
	// add to the backoff; waitBeforeRetry uses a crypto/rand default when nil.
	jitterNanos func(int64) int64
}
|
|
|
|
func NewNotifyHTTPWrapper() *HTTPWrapper {
|
|
return &HTTPWrapper{
|
|
retryPolicy: RetryPolicy{
|
|
MaxAttempts: 3,
|
|
BaseDelay: 200 * time.Millisecond,
|
|
MaxDelay: 2 * time.Second,
|
|
},
|
|
allowHTTP: allowNotifyHTTPOverride(),
|
|
maxRedirects: notifyMaxRedirects(),
|
|
httpClientFactory: func(allowHTTP bool, maxRedirects int) *http.Client {
|
|
opts := []network.Option{network.WithTimeout(10 * time.Second), network.WithMaxRedirects(maxRedirects)}
|
|
if allowHTTP {
|
|
opts = append(opts, network.WithAllowLocalhost())
|
|
}
|
|
return network.NewSafeHTTPClient(opts...)
|
|
},
|
|
sleep: time.Sleep,
|
|
}
|
|
}
|
|
|
|
func (w *HTTPWrapper) Send(ctx context.Context, request HTTPWrapperRequest) (*HTTPWrapperResult, error) {
|
|
if len(request.Body) > MaxNotifyRequestBodyBytes {
|
|
return nil, fmt.Errorf("request payload exceeds maximum size")
|
|
}
|
|
|
|
validatedURL, err := w.validateURL(request.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
parsedValidatedURL, err := neturl.Parse(validatedURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
validationOptions := []security.ValidationOption{}
|
|
if w.allowHTTP {
|
|
validationOptions = append(validationOptions, security.WithAllowHTTP(), security.WithAllowLocalhost())
|
|
}
|
|
|
|
safeURL, safeURLErr := security.ValidateExternalURL(parsedValidatedURL.String(), validationOptions...)
|
|
if safeURLErr != nil {
|
|
return nil, fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
safeParsedURL, safeParseErr := neturl.Parse(safeURL)
|
|
if safeParseErr != nil {
|
|
return nil, fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
if err := w.guardDestination(safeParsedURL); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
safeRequestURL, hostHeader, safeRequestErr := w.buildSafeRequestURL(safeParsedURL)
|
|
if safeRequestErr != nil {
|
|
return nil, safeRequestErr
|
|
}
|
|
|
|
headers := sanitizeOutboundHeaders(request.Headers)
|
|
client := w.httpClientFactory(w.allowHTTP, w.maxRedirects)
|
|
w.applyRedirectGuard(client)
|
|
|
|
var lastErr error
|
|
for attempt := 1; attempt <= w.retryPolicy.MaxAttempts; attempt++ {
|
|
httpReq, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, safeRequestURL.String(), bytes.NewReader(request.Body))
|
|
if reqErr != nil {
|
|
return nil, fmt.Errorf("create outbound request: %w", reqErr)
|
|
}
|
|
|
|
httpReq.Host = hostHeader
|
|
|
|
for key, value := range headers {
|
|
httpReq.Header.Set(key, value)
|
|
}
|
|
|
|
if httpReq.Header.Get("Content-Type") == "" {
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
resp, doErr := executeNotifyRequest(client, httpReq)
|
|
if doErr != nil {
|
|
lastErr = doErr
|
|
if attempt < w.retryPolicy.MaxAttempts && shouldRetry(nil, doErr) {
|
|
w.waitBeforeRetry(attempt)
|
|
continue
|
|
}
|
|
return nil, fmt.Errorf("outbound request failed: %s", sanitizeTransportErrorReason(doErr))
|
|
}
|
|
|
|
body, bodyErr := readCappedResponseBody(resp.Body)
|
|
closeErr := resp.Body.Close()
|
|
if bodyErr != nil {
|
|
return nil, bodyErr
|
|
}
|
|
if closeErr != nil {
|
|
return nil, fmt.Errorf("close response body: %w", closeErr)
|
|
}
|
|
|
|
if shouldRetry(resp, nil) && attempt < w.retryPolicy.MaxAttempts {
|
|
w.waitBeforeRetry(attempt)
|
|
continue
|
|
}
|
|
|
|
if resp.StatusCode >= http.StatusBadRequest {
|
|
if hint := extractProviderErrorHint(body); hint != "" {
|
|
return nil, fmt.Errorf("provider returned status %d: %s", resp.StatusCode, hint)
|
|
}
|
|
return nil, fmt.Errorf("provider returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
return &HTTPWrapperResult{
|
|
StatusCode: resp.StatusCode,
|
|
ResponseBody: body,
|
|
Attempts: attempt,
|
|
}, nil
|
|
}
|
|
|
|
if lastErr != nil {
|
|
return nil, fmt.Errorf("provider request failed after retries: %s", sanitizeTransportErrorReason(lastErr))
|
|
}
|
|
|
|
return nil, fmt.Errorf("provider request failed")
|
|
}
|
|
|
|
// sanitizeTransportErrorReason maps a transport error onto a short, generic
// reason string by looking for known substrings in its lowercased message.
// Rules are tried in order and the first match wins; a nil error or an
// unrecognized message yields "connection failed".
func sanitizeTransportErrorReason(err error) string {
	const fallback = "connection failed"
	if err == nil {
		return fallback
	}

	message := strings.ToLower(strings.TrimSpace(err.Error()))

	// Order matters: e.g. a TLS handshake timeout is reported as a timeout.
	rules := []struct {
		reason  string
		markers []string
	}{
		{"dns lookup failed", []string{"no such host"}},
		{"connection refused", []string{"connection refused"}},
		{"network unreachable", []string{"no route to host", "network is unreachable"}},
		{"request timed out", []string{"timeout", "deadline exceeded"}},
		{"tls handshake failed", []string{"tls", "certificate", "x509"}},
	}

	for _, rule := range rules {
		for _, marker := range rule.markers {
			if strings.Contains(message, marker) {
				return rule.reason
			}
		}
	}

	return fallback
}
|
|
|
|
func (w *HTTPWrapper) applyRedirectGuard(client *http.Client) {
|
|
if client == nil {
|
|
return
|
|
}
|
|
|
|
originalCheckRedirect := client.CheckRedirect
|
|
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
|
if originalCheckRedirect != nil {
|
|
if err := originalCheckRedirect(req, via); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return w.guardOutboundRequestURL(req)
|
|
}
|
|
}
|
|
|
|
func (w *HTTPWrapper) validateURL(rawURL string) (string, error) {
|
|
parsedURL, err := neturl.Parse(rawURL)
|
|
if err != nil {
|
|
return "", fmt.Errorf("invalid destination URL")
|
|
}
|
|
|
|
if hasDisallowedQueryAuthKey(parsedURL.Query()) {
|
|
return "", fmt.Errorf("destination URL query authentication is not allowed")
|
|
}
|
|
|
|
options := []security.ValidationOption{}
|
|
if w.allowHTTP {
|
|
options = append(options, security.WithAllowHTTP(), security.WithAllowLocalhost())
|
|
}
|
|
|
|
validatedURL, err := security.ValidateExternalURL(rawURL, options...)
|
|
if err != nil {
|
|
return "", fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
return validatedURL, nil
|
|
}
|
|
|
|
// hasDisallowedQueryAuthKey reports whether any query parameter name, after
// trimming whitespace and lowercasing, is one of "token", "auth", "apikey" or
// "api_key".
func hasDisallowedQueryAuthKey(query neturl.Values) bool {
	for name := range query {
		folded := strings.ToLower(strings.TrimSpace(name))
		if folded == "token" || folded == "auth" || folded == "apikey" || folded == "api_key" {
			return true
		}
	}
	return false
}
|
|
|
|
func (w *HTTPWrapper) guardOutboundRequestURL(httpReq *http.Request) error {
|
|
if httpReq == nil || httpReq.URL == nil {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
reqURL := httpReq.URL.String()
|
|
validatedURL, err := w.validateURL(reqURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
parsedValidatedURL, err := neturl.Parse(validatedURL)
|
|
if err != nil {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
return w.guardDestination(parsedValidatedURL)
|
|
}
|
|
|
|
func (w *HTTPWrapper) guardDestination(destinationURL *neturl.URL) error {
|
|
if destinationURL == nil {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
if destinationURL.User != nil || destinationURL.Fragment != "" {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
hostname := strings.TrimSpace(destinationURL.Hostname())
|
|
if hostname == "" {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
if parsedIP := net.ParseIP(hostname); parsedIP != nil {
|
|
if !w.isAllowedDestinationIP(hostname, parsedIP) {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
resolvedIPs, err := net.LookupIP(hostname)
|
|
if err != nil || len(resolvedIPs) == 0 {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
for _, resolvedIP := range resolvedIPs {
|
|
if !w.isAllowedDestinationIP(hostname, resolvedIP) {
|
|
return fmt.Errorf("destination URL validation failed")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *HTTPWrapper) isAllowedDestinationIP(hostname string, ip net.IP) bool {
|
|
if ip == nil {
|
|
return false
|
|
}
|
|
|
|
if ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
|
return false
|
|
}
|
|
|
|
if ip.IsLoopback() {
|
|
return w.allowHTTP && isLocalDestinationHost(hostname)
|
|
}
|
|
|
|
if network.IsPrivateIP(ip) {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (w *HTTPWrapper) buildSafeRequestURL(destinationURL *neturl.URL) (*neturl.URL, string, error) {
|
|
if destinationURL == nil {
|
|
return nil, "", fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
hostname := strings.TrimSpace(destinationURL.Hostname())
|
|
if hostname == "" {
|
|
return nil, "", fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
// Validate destination IPs are allowed (defense-in-depth alongside safeDialer).
|
|
_, err := w.resolveAllowedDestinationIP(hostname)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Preserve the original hostname in the URL so Go's TLS layer derives the
|
|
// correct ServerName for SNI and certificate verification. The safeDialer
|
|
// resolves DNS, validates IPs against SSRF rules, and connects to a
|
|
// validated IP at dial time, so protection is maintained without
|
|
// IP-pinning in the URL.
|
|
safeRequestURL := &neturl.URL{
|
|
Scheme: destinationURL.Scheme,
|
|
Host: destinationURL.Host,
|
|
Path: destinationURL.EscapedPath(),
|
|
RawQuery: destinationURL.RawQuery,
|
|
}
|
|
|
|
if safeRequestURL.Path == "" {
|
|
safeRequestURL.Path = "/"
|
|
}
|
|
|
|
return safeRequestURL, destinationURL.Host, nil
|
|
}
|
|
|
|
func (w *HTTPWrapper) resolveAllowedDestinationIP(hostname string) (net.IP, error) {
|
|
if parsedIP := net.ParseIP(hostname); parsedIP != nil {
|
|
if !w.isAllowedDestinationIP(hostname, parsedIP) {
|
|
return nil, fmt.Errorf("destination URL validation failed")
|
|
}
|
|
return parsedIP, nil
|
|
}
|
|
|
|
resolvedIPs, err := net.LookupIP(hostname)
|
|
if err != nil || len(resolvedIPs) == 0 {
|
|
return nil, fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
for _, resolvedIP := range resolvedIPs {
|
|
if w.isAllowedDestinationIP(hostname, resolvedIP) {
|
|
return resolvedIP, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("destination URL validation failed")
|
|
}
|
|
|
|
// isLocalDestinationHost reports whether host, after trimming whitespace, is
// "localhost" (case-insensitive) or a loopback IP literal.
func isLocalDestinationHost(host string) bool {
	candidate := strings.TrimSpace(host)
	if ip := net.ParseIP(candidate); ip != nil {
		return ip.IsLoopback()
	}
	return strings.EqualFold(candidate, "localhost")
}
|
|
|
|
// shouldRetry reports whether a failed attempt is worth repeating.
//
// With a non-nil err, resp is ignored: caller cancellation is never retried;
// otherwise the error is retryable when its message mentions "timeout" or
// "connection", or when it is (or wraps) a net.Error. With a nil err, a
// response is retryable when its status is 429 or at least 500. A nil resp
// with a nil err is not retryable.
func shouldRetry(resp *http.Response, err error) bool {
	if err != nil {
		// Errors returned by http.Client are *url.Error values, which satisfy
		// net.Error, so the errors.As check below would otherwise treat a
		// cancelled request as retryable.
		if errors.Is(err, context.Canceled) {
			return false
		}

		errText := strings.ToLower(err.Error())
		if strings.Contains(errText, "timeout") || strings.Contains(errText, "connection") {
			return true
		}

		var netErr net.Error
		return errors.As(err, &netErr)
	}

	if resp == nil {
		return false
	}

	return resp.StatusCode == http.StatusTooManyRequests ||
		resp.StatusCode >= http.StatusInternalServerError
}
|
|
|
|
// extractProviderErrorHint attempts to extract a short, human-readable error description
// from a JSON error response body. Only well-known fields are extracted to avoid
// accidentally surfacing sensitive or overlong content from arbitrary providers.
//
// The fields "description", "message", "error" and "error_description" are
// tried in that order; the first one holding a non-blank string wins. The
// value is trimmed and, when longer than 100 bytes, cut at the last rune
// boundary at or before byte 100 and suffixed with "...". An empty string is
// returned when the body is empty, is not a JSON object, or has no usable field.
func extractProviderErrorHint(body []byte) string {
	if len(body) == 0 {
		return ""
	}

	var errResp map[string]any
	if err := json.Unmarshal(body, &errResp); err != nil {
		return ""
	}

	const maxHintBytes = 100

	for _, key := range []string{"description", "message", "error", "error_description"} {
		raw, ok := errResp[key].(string)
		if !ok {
			continue
		}

		// Trim before measuring so surrounding whitespace neither counts
		// toward the limit nor triggers a spurious ellipsis.
		hint := strings.TrimSpace(raw)
		if hint == "" {
			continue
		}

		if len(hint) > maxHintBytes {
			// Ranging over a string yields rune start offsets; keep the last
			// one within the limit so a multi-byte character is never split.
			cut := 0
			for offset := range hint {
				if offset > maxHintBytes {
					break
				}
				cut = offset
			}
			hint = hint[:cut] + "..."
		}

		return hint
	}

	return ""
}
|
|
|
|
func readCappedResponseBody(body io.Reader) ([]byte, error) {
|
|
limited := io.LimitReader(body, MaxNotifyResponseBodyBytes+1)
|
|
content, err := io.ReadAll(limited)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response body: %w", err)
|
|
}
|
|
|
|
if len(content) > MaxNotifyResponseBodyBytes {
|
|
return nil, fmt.Errorf("response payload exceeds maximum size")
|
|
}
|
|
|
|
return content, nil
|
|
}
|
|
|
|
// sanitizeOutboundHeaders returns a new map holding only the allow-listed
// headers (Content-Type, User-Agent, X-Request-Id, X-Gotify-Key). Names are
// matched case-insensitively after trimming, stored in canonical form, and
// values are trimmed. The result is never nil.
func sanitizeOutboundHeaders(headers map[string]string) map[string]string {
	kept := make(map[string]string)
	for name, value := range headers {
		lowered := strings.ToLower(strings.TrimSpace(name))
		switch lowered {
		case "content-type", "user-agent", "x-request-id", "x-gotify-key":
			kept[http.CanonicalHeaderKey(lowered)] = strings.TrimSpace(value)
		}
	}
	return kept
}
|
|
|
|
func (w *HTTPWrapper) waitBeforeRetry(attempt int) {
|
|
delay := w.retryPolicy.BaseDelay << (attempt - 1)
|
|
if delay > w.retryPolicy.MaxDelay {
|
|
delay = w.retryPolicy.MaxDelay
|
|
}
|
|
|
|
jitterFn := w.jitterNanos
|
|
if jitterFn == nil {
|
|
jitterFn = func(max int64) int64 {
|
|
if max <= 0 {
|
|
return 0
|
|
}
|
|
n, err := crand.Int(crand.Reader, big.NewInt(max))
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return n.Int64()
|
|
}
|
|
}
|
|
|
|
jitter := time.Duration(jitterFn(int64(delay) / 2))
|
|
sleepFn := w.sleep
|
|
if sleepFn == nil {
|
|
sleepFn = time.Sleep
|
|
}
|
|
sleepFn(delay + jitter)
|
|
}
|
|
|
|
// allowNotifyHTTPOverride reports whether plain-HTTP/localhost destinations
// are permitted. It is always true when the running binary's name ends in
// ".test". Otherwise it requires CHARON_NOTIFY_ALLOW_HTTP to equal "true"
// (case-insensitive) and CHARON_ENV to be "development" or "test".
func allowNotifyHTTPOverride() bool {
	if strings.HasSuffix(os.Args[0], ".test") {
		return true
	}

	flag := strings.TrimSpace(os.Getenv("CHARON_NOTIFY_ALLOW_HTTP"))
	if !strings.EqualFold(flag, "true") {
		return false
	}

	switch strings.ToLower(strings.TrimSpace(os.Getenv("CHARON_ENV"))) {
	case "development", "test":
		return true
	default:
		return false
	}
}
|
|
|
|
// notifyMaxRedirects reads CHARON_NOTIFY_MAX_REDIRECTS and returns it clamped
// to the range 0..5. An unset, blank, non-numeric or negative value yields 0.
func notifyMaxRedirects() int {
	const redirectCeiling = 5

	setting := strings.TrimSpace(os.Getenv("CHARON_NOTIFY_MAX_REDIRECTS"))

	// An empty setting fails to parse, so it takes the same path as garbage.
	parsed, convErr := strconv.Atoi(setting)
	switch {
	case convErr != nil, parsed < 0:
		return 0
	case parsed > redirectCeiling:
		return redirectCeiling
	default:
		return parsed
	}
}
|