diff --git a/backend/cmd/api/main_test.go b/backend/cmd/api/main_test.go index 17059b38..2d9481c1 100644 --- a/backend/cmd/api/main_test.go +++ b/backend/cmd/api/main_test.go @@ -4,7 +4,9 @@ import ( "os" "os/exec" "path/filepath" + "syscall" "testing" + "time" "github.com/Wikid82/charon/backend/internal/database" "github.com/Wikid82/charon/backend/internal/models" @@ -281,3 +283,74 @@ func TestMain_ResetPasswordCommand_InProcess(t *testing.T) { t.Fatalf("expected failed login attempts reset to 0, got %d", updated.FailedLoginAttempts) } } + +func TestMain_DefaultStartupGracefulShutdown_Subprocess(t *testing.T) { + if os.Getenv("CHARON_TEST_RUN_MAIN_SERVER") == "1" { + os.Args = []string{"charon"} + + go func() { + time.Sleep(500 * time.Millisecond) + process, err := os.FindProcess(os.Getpid()) + if err == nil { + _ = process.Signal(syscall.SIGTERM) + } + }() + + main() + return + } + + tmp := t.TempDir() + dbPath := filepath.Join(tmp, "data", "test.db") + if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { + t.Fatalf("mkdir db dir: %v", err) + } + + cmd := exec.Command(os.Args[0], "-test.run=TestMain_DefaultStartupGracefulShutdown_Subprocess") //nolint:gosec // G204: Test subprocess pattern using os.Args[0] is safe + cmd.Dir = tmp + cmd.Env = append(os.Environ(), + "CHARON_TEST_RUN_MAIN_SERVER=1", + "CHARON_DB_PATH="+dbPath, + "CHARON_HTTP_PORT=0", + "CHARON_EMERGENCY_SERVER_ENABLED=false", + "CHARON_CADDY_CONFIG_DIR="+filepath.Join(tmp, "caddy"), + "CHARON_IMPORT_DIR="+filepath.Join(tmp, "imports"), + "CHARON_IMPORT_CADDYFILE="+filepath.Join(tmp, "imports", "does-not-exist", "Caddyfile"), + "CHARON_FRONTEND_DIR="+filepath.Join(tmp, "frontend", "dist"), + ) + + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("expected startup/shutdown to exit 0; err=%v; output=%s", err, string(out)) + } +} + +func TestMain_DefaultStartupGracefulShutdown_InProcess(t *testing.T) { + tmp := t.TempDir() + dbPath := filepath.Join(tmp, "data", "test.db") + if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { + t.Fatalf("mkdir db dir: %v", err) + } + + originalArgs := os.Args + t.Cleanup(func() { os.Args = originalArgs }) + + t.Setenv("CHARON_DB_PATH", dbPath) + t.Setenv("CHARON_HTTP_PORT", "0") + t.Setenv("CHARON_EMERGENCY_SERVER_ENABLED", "false") + t.Setenv("CHARON_CADDY_CONFIG_DIR", filepath.Join(tmp, "caddy")) + t.Setenv("CHARON_IMPORT_DIR", filepath.Join(tmp, "imports")) + t.Setenv("CHARON_IMPORT_CADDYFILE", filepath.Join(tmp, "imports", "does-not-exist", "Caddyfile")) + t.Setenv("CHARON_FRONTEND_DIR", filepath.Join(tmp, "frontend", "dist")) + os.Args = []string{"charon"} + + go func() { + time.Sleep(500 * time.Millisecond) + process, err := os.FindProcess(os.Getpid()) + if err == nil { + _ = process.Signal(syscall.SIGTERM) + } + }() + + main() +} diff --git a/backend/internal/api/handlers/auth_handler_test.go b/backend/internal/api/handlers/auth_handler_test.go index afe25a03..4241adea 100644 --- a/backend/internal/api/handlers/auth_handler_test.go +++ b/backend/internal/api/handlers/auth_handler_test.go @@ -184,6 +184,132 @@ func TestSetSecureCookie_OriginLoopbackForcesInsecure(t *testing.T) { assert.Equal(t, http.SameSiteLaxMode, cookie.SameSite) } +func TestIsProduction(t *testing.T) { + t.Setenv("CHARON_ENV", "production") + assert.True(t, isProduction()) + + t.Setenv("CHARON_ENV", "prod") + assert.True(t, isProduction()) + + t.Setenv("CHARON_ENV", "development") + assert.False(t, isProduction()) +} + +func TestRequestScheme(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("forwarded proto first value wins", func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest("GET", "http://example.com", http.NoBody) + req.Header.Set("X-Forwarded-Proto", "HTTPS, http") + ctx.Request = req + + assert.Equal(t, "https", requestScheme(ctx)) + }) + + t.Run("tls request", func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest("GET", "https://example.com", http.NoBody) + req.TLS = &tls.ConnectionState{} + ctx.Request = req + + assert.Equal(t, "https", requestScheme(ctx)) + }) + + t.Run("url scheme fallback", func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest("GET", "http://example.com", http.NoBody) + req.URL.Scheme = "HTTP" + ctx.Request = req + + assert.Equal(t, "http", requestScheme(ctx)) + }) + + t.Run("default http fallback", func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest("GET", "/", http.NoBody) + req.URL.Scheme = "" + ctx.Request = req + + assert.Equal(t, "http", requestScheme(ctx)) + }) +} + +func TestHostHelpers(t *testing.T) { + t.Run("normalizeHost", func(t *testing.T) { + assert.Equal(t, "", normalizeHost(" ")) + assert.Equal(t, "example.com", normalizeHost("example.com:8080")) + assert.Equal(t, "::1", normalizeHost("[::1]:2020")) + assert.Equal(t, "localhost", normalizeHost("localhost")) + }) + + t.Run("originHost", func(t *testing.T) { + assert.Equal(t, "", originHost("")) + assert.Equal(t, "", originHost("::://bad-url")) + assert.Equal(t, "localhost", originHost("http://localhost:8080/path")) + }) + + t.Run("isLocalHost", func(t *testing.T) { + assert.True(t, isLocalHost("localhost")) + assert.True(t, isLocalHost("127.0.0.1")) + assert.True(t, isLocalHost("::1")) + assert.False(t, isLocalHost("example.com")) + }) +} + +func TestIsLocalRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("forwarded host list includes localhost", func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest("GET", "http://example.com", http.NoBody) + req.Host = "example.com" + req.Header.Set("X-Forwarded-Host", "example.com, localhost:8080") + ctx.Request = req + + assert.True(t, isLocalRequest(ctx)) + }) + + t.Run("origin loopback", func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest("GET", "http://example.com", http.NoBody) + req.Header.Set("Origin", "http://127.0.0.1:3000") + ctx.Request = req + + assert.True(t, isLocalRequest(ctx)) + }) + + t.Run("non local request", func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest("GET", "http://example.com", http.NoBody) + req.Host = "example.com" + ctx.Request = req + + assert.False(t, isLocalRequest(ctx)) + }) +} + +func TestClearSecureCookie(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest("POST", "http://example.com/logout", http.NoBody) + + clearSecureCookie(ctx, "auth_token") + + cookies := recorder.Result().Cookies() + require.Len(t, cookies, 1) + assert.Equal(t, "auth_token", cookies[0].Name) + assert.Equal(t, -1, cookies[0].MaxAge) +} + func TestAuthHandler_Login_Errors(t *testing.T) { t.Parallel() handler, _ := setupAuthHandler(t)