From cd8f5f9608c6a6ffb0f836a4332e3434a7cc9708 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Sun, 15 Feb 2026 20:11:03 +0000 Subject: [PATCH] fix: add parsing functions for nullable uint fields and forward port validation in proxy host updates --- .../api/handlers/proxy_host_handler.go | 145 ++++++++----- .../proxy_host_handler_update_test.go | 193 ++++++++++++++++++ 2 files changed, 291 insertions(+), 47 deletions(-) diff --git a/backend/internal/api/handlers/proxy_host_handler.go b/backend/internal/api/handlers/proxy_host_handler.go index f5556da6..cf6858ea 100644 --- a/backend/internal/api/handlers/proxy_host_handler.go +++ b/backend/internal/api/handlers/proxy_host_handler.go @@ -3,9 +3,11 @@ package handlers import ( "encoding/json" "fmt" + "math" "net" "net/http" "strconv" + "strings" "time" "github.com/gin-gonic/gin" @@ -149,6 +151,72 @@ func safeFloat64ToUint(f float64) (uint, bool) { return uint(f), true } +func parseNullableUintField(value any, fieldName string) (*uint, bool, error) { + if value == nil { + return nil, true, nil + } + + switch v := value.(type) { + case float64: + if id, ok := safeFloat64ToUint(v); ok { + return &id, true, nil + } + return nil, true, fmt.Errorf("invalid %s: unable to convert value %v of type %T to uint", fieldName, value, value) + case int: + if id, ok := safeIntToUint(v); ok { + return &id, true, nil + } + return nil, true, fmt.Errorf("invalid %s: unable to convert value %v of type %T to uint", fieldName, value, value) + case string: + trimmed := strings.TrimSpace(v) + if trimmed == "" { + return nil, true, nil + } + n, err := strconv.ParseUint(trimmed, 10, 32) + if err != nil { + return nil, true, fmt.Errorf("invalid %s: unable to convert value %v of type %T to uint", fieldName, value, value) + } + id := uint(n) + return &id, true, nil + default: + return nil, true, fmt.Errorf("invalid %s: unable to convert value %v of type %T to uint", fieldName, value, value) + } +} + +func parseForwardPortField(value any) (int, error) { + switch v := value.(type) { + case float64: + if v != math.Trunc(v) { + return 0, fmt.Errorf("invalid forward_port: must be an integer") + } + port := int(v) + if port < 1 || port > 65535 { + return 0, fmt.Errorf("invalid forward_port: must be between 1 and 65535") + } + return port, nil + case int: + if v < 1 || v > 65535 { + return 0, fmt.Errorf("invalid forward_port: must be between 1 and 65535") + } + return v, nil + case string: + trimmed := strings.TrimSpace(v) + if trimmed == "" { + return 0, fmt.Errorf("invalid forward_port: must be between 1 and 65535") + } + port, err := strconv.Atoi(trimmed) + if err != nil { + return 0, fmt.Errorf("invalid forward_port: must be an integer") + } + if port < 1 || port > 65535 { + return 0, fmt.Errorf("invalid forward_port: must be between 1 and 65535") + } + return port, nil + default: + return 0, fmt.Errorf("invalid forward_port: unsupported type %T", value) + } +} + // NewProxyHostHandler creates a new proxy host handler. func NewProxyHostHandler(db *gorm.DB, caddyManager *caddy.Manager, ns *services.NotificationService, uptimeService *services.UptimeService) *ProxyHostHandler { return &ProxyHostHandler{ @@ -292,25 +360,21 @@ func (h *ProxyHostHandler) Update(c *gin.Context) { host.Name = v } if v, ok := payload["domain_names"].(string); ok { - host.DomainNames = v + host.DomainNames = strings.TrimSpace(v) } if v, ok := payload["forward_scheme"].(string); ok { host.ForwardScheme = v } if v, ok := payload["forward_host"].(string); ok { - host.ForwardHost = v + host.ForwardHost = strings.TrimSpace(v) } if v, ok := payload["forward_port"]; ok { - switch t := v.(type) { - case float64: - host.ForwardPort = int(t) - case int: - host.ForwardPort = t - case string: - if p, err := strconv.Atoi(t); err == nil { - host.ForwardPort = p - } + port, parseErr := parseForwardPortField(v) + if parseErr != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": parseErr.Error()}) + return } + host.ForwardPort = port } if v, ok := payload["ssl_forced"].(bool); ok { host.SSLForced = v @@ -358,46 +422,33 @@ func (h *ProxyHostHandler) Update(c *gin.Context) { // Nullable foreign keys if v, ok := payload["certificate_id"]; ok { - if v == nil { - host.CertificateID = nil - } else { - switch t := v.(type) { - case float64: - if id, ok := safeFloat64ToUint(t); ok { - host.CertificateID = &id - } - case int: - if id, ok := safeIntToUint(t); ok { - host.CertificateID = &id - } - case string: - if n, err := strconv.ParseUint(t, 10, 32); err == nil { - id := uint(n) - host.CertificateID = &id - } - } + parsedID, _, parseErr := parseNullableUintField(v, "certificate_id") + if parseErr != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": parseErr.Error()}) + return } + host.CertificateID = parsedID } if v, ok := payload["access_list_id"]; ok { - if v == nil { - host.AccessListID = nil - } else { - switch t := v.(type) { - case float64: - if id, ok := safeFloat64ToUint(t); ok { - host.AccessListID = &id - } - case int: - if id, ok := safeIntToUint(t); ok { - host.AccessListID = &id - } - case string: - if n, err := strconv.ParseUint(t, 10, 32); err == nil { - id := uint(n) - host.AccessListID = &id - } - } + parsedID, _, parseErr := parseNullableUintField(v, "access_list_id") + if parseErr != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": parseErr.Error()}) + return } + host.AccessListID = parsedID + } + + if v, ok := payload["dns_provider_id"]; ok { + parsedID, _, parseErr := parseNullableUintField(v, "dns_provider_id") + if parseErr != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": parseErr.Error()}) + return + } + host.DNSProviderID = parsedID + } + + if v, ok := payload["use_dns_challenge"].(bool); ok { + host.UseDNSChallenge = v } // Security Header Profile: update only if provided diff --git a/backend/internal/api/handlers/proxy_host_handler_update_test.go b/backend/internal/api/handlers/proxy_host_handler_update_test.go index ceeba53c..698d8bd0 100644 --- a/backend/internal/api/handlers/proxy_host_handler_update_test.go +++ b/backend/internal/api/handlers/proxy_host_handler_update_test.go @@ -295,6 +295,120 @@ func TestProxyHostUpdate_WAFDisabled(t *testing.T) { assert.True(t, updated.WAFDisabled) } +func TestProxyHostUpdate_DNSChallengeFieldsPersist(t *testing.T) { + t.Parallel() + router, db := setupUpdateTestRouter(t) + + host := models.ProxyHost{ + UUID: uuid.NewString(), + Name: "DNS Challenge Host", + DomainNames: "dns-challenge.example.com", + ForwardScheme: "http", + ForwardHost: "localhost", + ForwardPort: 8080, + Enabled: true, + UseDNSChallenge: false, + DNSProviderID: nil, + } + require.NoError(t, db.Create(&host).Error) + + updateBody := map[string]any{ + "domain_names": "dns-challenge.example.com", + "forward_host": "localhost", + "forward_port": 8080, + "dns_provider_id": "7", + "use_dns_challenge": true, + } + body, _ := json.Marshal(updateBody) + + req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, http.StatusOK, resp.Code) + + var updated models.ProxyHost + require.NoError(t, db.First(&updated, "uuid = ?", host.UUID).Error) + require.NotNil(t, updated.DNSProviderID) + assert.Equal(t, uint(7), *updated.DNSProviderID) + assert.True(t, updated.UseDNSChallenge) +} + +func TestProxyHostUpdate_DNSChallengeRequiresProvider(t *testing.T) { + t.Parallel() + router, db := setupUpdateTestRouter(t) + + host := createTestProxyHost(t, db, "dns-validation") + + updateBody := map[string]any{ + "domain_names": "dns-validation.test.com", + "forward_host": "localhost", + "forward_port": 8080, + "dns_provider_id": nil, + "use_dns_challenge": true, + } + body, _ := json.Marshal(updateBody) + + req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, http.StatusBadRequest, resp.Code) + + var updated models.ProxyHost + require.NoError(t, db.First(&updated, "uuid = ?", host.UUID).Error) + assert.False(t, updated.UseDNSChallenge) + assert.Nil(t, updated.DNSProviderID) +} + +func TestProxyHostUpdate_InvalidForwardPortRejected(t *testing.T) { + t.Parallel() + router, db := setupUpdateTestRouter(t) + + host := createTestProxyHost(t, db, "invalid-forward-port") + + updateBody := map[string]any{ + "forward_port": 70000, + } + body, _ := json.Marshal(updateBody) + + req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, http.StatusBadRequest, resp.Code) + + var updated models.ProxyHost + require.NoError(t, db.First(&updated, "uuid = ?", host.UUID).Error) + assert.Equal(t, 8080, updated.ForwardPort) +} + +func TestProxyHostUpdate_InvalidCertificateIDRejected(t *testing.T) { + t.Parallel() + router, db := setupUpdateTestRouter(t) + + host := createTestProxyHost(t, db, "invalid-certificate-id") + + updateBody := map[string]any{ + "certificate_id": true, + } + body, _ := json.Marshal(updateBody) + + req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, http.StatusBadRequest, resp.Code) + + var result map[string]any + require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &result)) + assert.Contains(t, result["error"], "invalid certificate_id") +} + func TestProxyHostUpdate_RejectsEmptyDomainNamesAndPreservesOriginal(t *testing.T) { t.Parallel() router, db := setupUpdateTestRouter(t) @@ -649,3 +763,82 @@ func TestBulkUpdateSecurityHeaders_DBError_NonNotFound(t *testing.T) { // The handler should return 500 when DB operations fail require.Equal(t, http.StatusInternalServerError, resp.Code) } + +func TestParseNullableUintField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value any + wantID *uint + wantErr bool + errContain string + }{ + {name: "nil", value: nil, wantID: nil, wantErr: false}, + {name: "float64", value: 5.0, wantID: func() *uint { v := uint(5); return &v }(), wantErr: false}, + {name: "int", value: 9, wantID: func() *uint { v := uint(9); return &v }(), wantErr: false}, + {name: "string", value: "12", wantID: func() *uint { v := uint(12); return &v }(), wantErr: false}, + {name: "blank string", value: " ", wantID: nil, wantErr: false}, + {name: "negative float", value: -1.0, wantErr: true, errContain: "invalid test_field"}, + {name: "invalid string", value: "nope", wantErr: true, errContain: "invalid test_field"}, + {name: "unsupported", value: true, wantErr: true, errContain: "invalid test_field"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + id, _, err := parseNullableUintField(tt.value, "test_field") + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContain) + return + } + + require.NoError(t, err) + if tt.wantID == nil { + assert.Nil(t, id) + return + } + require.NotNil(t, id) + assert.Equal(t, *tt.wantID, *id) + }) + } +} + +func TestParseForwardPortField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value any + wantPort int + wantErr bool + errContain string + }{ + {name: "float integer", value: 8080.0, wantPort: 8080, wantErr: false}, + {name: "float decimal", value: 8080.5, wantErr: true, errContain: "must be an integer"}, + {name: "int", value: 3000, wantPort: 3000, wantErr: false}, + {name: "int low", value: 0, wantErr: true, errContain: "between 1 and 65535"}, + {name: "string", value: "443", wantPort: 443, wantErr: false}, + {name: "string blank", value: " ", wantErr: true, errContain: "between 1 and 65535"}, + {name: "string invalid", value: "abc", wantErr: true, errContain: "must be an integer"}, + {name: "unsupported", value: true, wantErr: true, errContain: "unsupported type"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + port, err := parseForwardPortField(tt.value) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContain) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantPort, port) + }) + } +}