hotfix(api): add UUID support to access list endpoints
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -27,6 +28,23 @@ func (h *AccessListHandler) SetGeoIPService(geoipSvc *services.GeoIPService) {
|
||||
h.service.SetGeoIPService(geoipSvc)
|
||||
}
|
||||
|
||||
// resolveAccessList resolves an access list by either numeric ID or UUID.
|
||||
// It first attempts to parse as uint (backward compatibility), then tries UUID.
|
||||
func (h *AccessListHandler) resolveAccessList(idOrUUID string) (*models.AccessList, error) {
|
||||
// Try parsing as numeric ID first (backward compatibility)
|
||||
if id, err := strconv.ParseUint(idOrUUID, 10, 32); err == nil {
|
||||
return h.service.GetByID(uint(id))
|
||||
}
|
||||
|
||||
// Empty string check
|
||||
if idOrUUID == "" {
|
||||
return nil, fmt.Errorf("invalid ID or UUID")
|
||||
}
|
||||
|
||||
// Try as UUID
|
||||
return h.service.GetByUUID(idOrUUID)
|
||||
}
|
||||
|
||||
// Create handles POST /api/v1/access-lists
|
||||
func (h *AccessListHandler) Create(c *gin.Context) {
|
||||
var acl models.AccessList
|
||||
@@ -55,19 +73,13 @@ func (h *AccessListHandler) List(c *gin.Context) {
|
||||
|
||||
// Get handles GET /api/v1/access-lists/:id
|
||||
func (h *AccessListHandler) Get(c *gin.Context) {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid ID"})
|
||||
return
|
||||
}
|
||||
|
||||
acl, err := h.service.GetByID(uint(id))
|
||||
acl, err := h.resolveAccessList(c.Param("id"))
|
||||
if err != nil {
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -76,9 +88,14 @@ func (h *AccessListHandler) Get(c *gin.Context) {
|
||||
|
||||
// Update handles PUT /api/v1/access-lists/:id
|
||||
func (h *AccessListHandler) Update(c *gin.Context) {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
// Resolve access list first to get the internal ID
|
||||
acl, err := h.resolveAccessList(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid ID"})
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,7 +105,7 @@ func (h *AccessListHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Update(uint(id), &updates); err != nil {
|
||||
if err := h.service.Update(acl.ID, &updates); err != nil {
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
@@ -98,19 +115,24 @@ func (h *AccessListHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Fetch updated record
|
||||
acl, _ := h.service.GetByID(uint(id))
|
||||
c.JSON(http.StatusOK, acl)
|
||||
updatedAcl, _ := h.service.GetByID(acl.ID)
|
||||
c.JSON(http.StatusOK, updatedAcl)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /api/v1/access-lists/:id
|
||||
func (h *AccessListHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
// Resolve access list first to get the internal ID
|
||||
acl, err := h.resolveAccessList(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid ID"})
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Delete(uint(id)); err != nil {
|
||||
if err := h.service.Delete(acl.ID); err != nil {
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
@@ -128,9 +150,14 @@ func (h *AccessListHandler) Delete(c *gin.Context) {
|
||||
|
||||
// TestIP handles POST /api/v1/access-lists/:id/test
|
||||
func (h *AccessListHandler) TestIP(c *gin.Context) {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
// Resolve access list first to get the internal ID
|
||||
acl, err := h.resolveAccessList(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid ID"})
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -142,12 +169,8 @@ func (h *AccessListHandler) TestIP(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
allowed, reason, err := h.service.TestIP(uint(id), req.IPAddress)
|
||||
allowed, reason, err := h.service.TestIP(acl.ID, req.IPAddress)
|
||||
if err != nil {
|
||||
if err == services.ErrAccessListNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "access list not found"})
|
||||
return
|
||||
}
|
||||
if err == services.ErrInvalidIPAddress {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid IP address"})
|
||||
return
|
||||
|
||||
@@ -41,23 +41,25 @@ func TestAccessListHandler_SetGeoIPService_Nil(t *testing.T) {
|
||||
func TestAccessListHandler_Get_InvalidID(t *testing.T) {
|
||||
router, _ := setupAccessListTestRouter(t)
|
||||
|
||||
// "invalid" is treated as a potential UUID, which doesn't exist, so 404 is returned
|
||||
req := httptest.NewRequest(http.MethodGet, "/access-lists/invalid", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestAccessListHandler_Update_InvalidID(t *testing.T) {
|
||||
router, _ := setupAccessListTestRouter(t)
|
||||
|
||||
// "invalid" is treated as a potential UUID, which doesn't exist, so 404 is returned
|
||||
body := []byte(`{"name":"Test","type":"whitelist"}`)
|
||||
req := httptest.NewRequest(http.MethodPut, "/access-lists/invalid", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestAccessListHandler_Update_InvalidJSON(t *testing.T) {
|
||||
@@ -78,23 +80,25 @@ func TestAccessListHandler_Update_InvalidJSON(t *testing.T) {
|
||||
func TestAccessListHandler_Delete_InvalidID(t *testing.T) {
|
||||
router, _ := setupAccessListTestRouter(t)
|
||||
|
||||
// "invalid" is treated as a potential UUID, which doesn't exist, so 404 is returned
|
||||
req := httptest.NewRequest(http.MethodDelete, "/access-lists/invalid", http.NoBody)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestAccessListHandler_TestIP_InvalidID(t *testing.T) {
|
||||
router, _ := setupAccessListTestRouter(t)
|
||||
|
||||
// "invalid" is treated as a potential UUID, which doesn't exist, so 404 is returned
|
||||
body := []byte(`{"ip_address":"192.168.1.1"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/access-lists/invalid/test", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestAccessListHandler_TestIP_MissingIPAddress(t *testing.T) {
|
||||
|
||||
@@ -160,15 +160,25 @@ func TestAccessListHandler_Get(t *testing.T) {
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "get existing ACL",
|
||||
name: "get existing ACL by numeric ID",
|
||||
id: "1",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "get non-existent ACL",
|
||||
name: "get existing ACL by UUID",
|
||||
id: "test-uuid",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "get non-existent ACL by numeric ID",
|
||||
id: "9999",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "get non-existent ACL by UUID",
|
||||
id: "non-existent-uuid",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -209,7 +219,7 @@ func TestAccessListHandler_Update(t *testing.T) {
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "update successfully",
|
||||
name: "update by numeric ID successfully",
|
||||
id: "1",
|
||||
payload: map[string]any{
|
||||
"name": "Updated Name",
|
||||
@@ -221,7 +231,19 @@ func TestAccessListHandler_Update(t *testing.T) {
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "update non-existent ACL",
|
||||
name: "update by UUID successfully",
|
||||
id: "test-uuid",
|
||||
payload: map[string]any{
|
||||
"name": "Updated via UUID",
|
||||
"description": "UUID update description",
|
||||
"enabled": true,
|
||||
"type": "whitelist",
|
||||
"ip_rules": `[]`,
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "update non-existent ACL by numeric ID",
|
||||
id: "9999",
|
||||
payload: map[string]any{
|
||||
"name": "Test",
|
||||
@@ -230,6 +252,16 @@ func TestAccessListHandler_Update(t *testing.T) {
|
||||
},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "update non-existent ACL by UUID",
|
||||
id: "non-existent-uuid",
|
||||
payload: map[string]any{
|
||||
"name": "Test",
|
||||
"type": "whitelist",
|
||||
"ip_rules": `[]`,
|
||||
},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -270,6 +302,15 @@ func TestAccessListHandler_Delete(t *testing.T) {
|
||||
}
|
||||
db.Create(&acl)
|
||||
|
||||
// Create ACL that will be deleted by UUID
|
||||
aclByUUID := models.AccessList{
|
||||
UUID: "delete-by-uuid",
|
||||
Name: "Delete By UUID ACL",
|
||||
Type: "whitelist",
|
||||
Enabled: true,
|
||||
}
|
||||
db.Create(&aclByUUID)
|
||||
|
||||
// Create ACL in use
|
||||
aclInUse := models.AccessList{
|
||||
UUID: "in-use-uuid",
|
||||
@@ -295,20 +336,30 @@ func TestAccessListHandler_Delete(t *testing.T) {
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "delete successfully",
|
||||
name: "delete by numeric ID successfully",
|
||||
id: "1",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "delete by UUID successfully",
|
||||
id: "delete-by-uuid",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "fail to delete ACL in use",
|
||||
id: "2",
|
||||
id: "3",
|
||||
wantStatus: http.StatusConflict,
|
||||
},
|
||||
{
|
||||
name: "delete non-existent ACL",
|
||||
name: "delete non-existent ACL by numeric ID",
|
||||
id: "9999",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "delete non-existent ACL by UUID",
|
||||
id: "non-existent-uuid",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -343,8 +394,14 @@ func TestAccessListHandler_TestIP(t *testing.T) {
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "test IP in whitelist",
|
||||
id: "1", // Use numeric ID
|
||||
name: "test IP in whitelist by numeric ID",
|
||||
id: "1",
|
||||
payload: map[string]string{"ip_address": "192.168.1.100"},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "test IP in whitelist by UUID",
|
||||
id: "test-uuid",
|
||||
payload: map[string]string{"ip_address": "192.168.1.100"},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
@@ -361,11 +418,17 @@ func TestAccessListHandler_TestIP(t *testing.T) {
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "test non-existent ACL",
|
||||
name: "test non-existent ACL by numeric ID",
|
||||
id: "9999",
|
||||
payload: map[string]string{"ip_address": "192.168.1.100"},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "test non-existent ACL by UUID",
|
||||
id: "non-existent-uuid",
|
||||
payload: map[string]string{"ip_address": "192.168.1.100"},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -413,3 +476,67 @@ func TestAccessListHandler_GetTemplates(t *testing.T) {
|
||||
assert.Contains(t, template, "type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessListHandler_resolveAccessList(t *testing.T) {
|
||||
_, db := setupAccessListTestRouter(t)
|
||||
|
||||
handler := NewAccessListHandler(db)
|
||||
|
||||
// Create test ACL with known UUID
|
||||
acl := models.AccessList{
|
||||
UUID: "resolve-test-uuid",
|
||||
Name: "Resolve Test ACL",
|
||||
Type: "whitelist",
|
||||
Enabled: true,
|
||||
}
|
||||
db.Create(&acl)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
idOrUUID string
|
||||
wantErr bool
|
||||
wantName string
|
||||
}{
|
||||
{
|
||||
name: "resolve by numeric ID",
|
||||
idOrUUID: "1",
|
||||
wantErr: false,
|
||||
wantName: "Resolve Test ACL",
|
||||
},
|
||||
{
|
||||
name: "resolve by UUID",
|
||||
idOrUUID: "resolve-test-uuid",
|
||||
wantErr: false,
|
||||
wantName: "Resolve Test ACL",
|
||||
},
|
||||
{
|
||||
name: "fail with non-existent numeric ID",
|
||||
idOrUUID: "9999",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail with non-existent UUID",
|
||||
idOrUUID: "non-existent-uuid",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail with empty string",
|
||||
idOrUUID: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := handler.resolveAccessList(tt.idOrUUID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tt.wantName, result.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user