fix: add parsing functions for nullable uint fields and forward port validation in proxy host updates

This commit is contained in:
GitHub Actions
2026-02-15 20:11:03 +00:00
parent f4fafde161
commit cd8f5f9608
2 changed files with 291 additions and 47 deletions
@@ -295,6 +295,120 @@ func TestProxyHostUpdate_WAFDisabled(t *testing.T) {
assert.True(t, updated.WAFDisabled)
}
func TestProxyHostUpdate_DNSChallengeFieldsPersist(t *testing.T) {
t.Parallel()
router, db := setupUpdateTestRouter(t)
host := models.ProxyHost{
UUID: uuid.NewString(),
Name: "DNS Challenge Host",
DomainNames: "dns-challenge.example.com",
ForwardScheme: "http",
ForwardHost: "localhost",
ForwardPort: 8080,
Enabled: true,
UseDNSChallenge: false,
DNSProviderID: nil,
}
require.NoError(t, db.Create(&host).Error)
updateBody := map[string]any{
"domain_names": "dns-challenge.example.com",
"forward_host": "localhost",
"forward_port": 8080,
"dns_provider_id": "7",
"use_dns_challenge": true,
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code)
var updated models.ProxyHost
require.NoError(t, db.First(&updated, "uuid = ?", host.UUID).Error)
require.NotNil(t, updated.DNSProviderID)
assert.Equal(t, uint(7), *updated.DNSProviderID)
assert.True(t, updated.UseDNSChallenge)
}
func TestProxyHostUpdate_DNSChallengeRequiresProvider(t *testing.T) {
t.Parallel()
router, db := setupUpdateTestRouter(t)
host := createTestProxyHost(t, db, "dns-validation")
updateBody := map[string]any{
"domain_names": "dns-validation.test.com",
"forward_host": "localhost",
"forward_port": 8080,
"dns_provider_id": nil,
"use_dns_challenge": true,
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
var updated models.ProxyHost
require.NoError(t, db.First(&updated, "uuid = ?", host.UUID).Error)
assert.False(t, updated.UseDNSChallenge)
assert.Nil(t, updated.DNSProviderID)
}
func TestProxyHostUpdate_InvalidForwardPortRejected(t *testing.T) {
t.Parallel()
router, db := setupUpdateTestRouter(t)
host := createTestProxyHost(t, db, "invalid-forward-port")
updateBody := map[string]any{
"forward_port": 70000,
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
var updated models.ProxyHost
require.NoError(t, db.First(&updated, "uuid = ?", host.UUID).Error)
assert.Equal(t, 8080, updated.ForwardPort)
}
func TestProxyHostUpdate_InvalidCertificateIDRejected(t *testing.T) {
t.Parallel()
router, db := setupUpdateTestRouter(t)
host := createTestProxyHost(t, db, "invalid-certificate-id")
updateBody := map[string]any{
"certificate_id": true,
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest(http.MethodPut, "/api/v1/proxy-hosts/"+host.UUID, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
var result map[string]any
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &result))
assert.Contains(t, result["error"], "invalid certificate_id")
}
func TestProxyHostUpdate_RejectsEmptyDomainNamesAndPreservesOriginal(t *testing.T) {
t.Parallel()
router, db := setupUpdateTestRouter(t)
@@ -649,3 +763,82 @@ func TestBulkUpdateSecurityHeaders_DBError_NonNotFound(t *testing.T) {
// The handler should return 500 when DB operations fail
require.Equal(t, http.StatusInternalServerError, resp.Code)
}
func TestParseNullableUintField(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value any
wantID *uint
wantErr bool
errContain string
}{
{name: "nil", value: nil, wantID: nil, wantErr: false},
{name: "float64", value: 5.0, wantID: func() *uint { v := uint(5); return &v }(), wantErr: false},
{name: "int", value: 9, wantID: func() *uint { v := uint(9); return &v }(), wantErr: false},
{name: "string", value: "12", wantID: func() *uint { v := uint(12); return &v }(), wantErr: false},
{name: "blank string", value: " ", wantID: nil, wantErr: false},
{name: "negative float", value: -1.0, wantErr: true, errContain: "invalid test_field"},
{name: "invalid string", value: "nope", wantErr: true, errContain: "invalid test_field"},
{name: "unsupported", value: true, wantErr: true, errContain: "invalid test_field"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
id, _, err := parseNullableUintField(tt.value, "test_field")
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContain)
return
}
require.NoError(t, err)
if tt.wantID == nil {
assert.Nil(t, id)
return
}
require.NotNil(t, id)
assert.Equal(t, *tt.wantID, *id)
})
}
}
func TestParseForwardPortField(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value any
wantPort int
wantErr bool
errContain string
}{
{name: "float integer", value: 8080.0, wantPort: 8080, wantErr: false},
{name: "float decimal", value: 8080.5, wantErr: true, errContain: "must be an integer"},
{name: "int", value: 3000, wantPort: 3000, wantErr: false},
{name: "int low", value: 0, wantErr: true, errContain: "between 1 and 65535"},
{name: "string", value: "443", wantPort: 443, wantErr: false},
{name: "string blank", value: " ", wantErr: true, errContain: "between 1 and 65535"},
{name: "string invalid", value: "abc", wantErr: true, errContain: "must be an integer"},
{name: "unsupported", value: true, wantErr: true, errContain: "unsupported type"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
port, err := parseForwardPortField(tt.value)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContain)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantPort, port)
})
}
}