feat: propagate request context in notification service and related handlers

This commit is contained in:
GitHub Actions
2025-11-30 22:56:31 +00:00
parent fe1e62a360
commit d27f28e20c
9 changed files with 107 additions and 71 deletions
@@ -92,7 +92,7 @@ func (h *CertificateHandler) Upload(c *gin.Context) {
// Send Notification
if h.notificationService != nil {
h.notificationService.SendExternal(
h.notificationService.SendExternal(c.Request.Context(),
"cert",
"Certificate Uploaded",
fmt.Sprintf("Certificate %s uploaded", cert.Name),
@@ -122,7 +122,7 @@ func (h *CertificateHandler) Delete(c *gin.Context) {
// Send Notification
if h.notificationService != nil {
h.notificationService.SendExternal(
h.notificationService.SendExternal(c.Request.Context(),
"cert",
"Certificate Deleted",
fmt.Sprintf("Certificate ID %d deleted", id),
@@ -52,7 +52,7 @@ func (h *DomainHandler) Create(c *gin.Context) {
// Send Notification
if h.notificationService != nil {
h.notificationService.SendExternal(
h.notificationService.SendExternal(c.Request.Context(),
"domain",
"Domain Added",
fmt.Sprintf("Domain %s added", domain.Name),
@@ -72,7 +72,7 @@ func (h *DomainHandler) Delete(c *gin.Context) {
if err := h.DB.Where("uuid = ?", id).First(&domain).Error; err == nil {
// Send Notification before delete (or after if we keep the name)
if h.notificationService != nil {
h.notificationService.SendExternal(
h.notificationService.SendExternal(c.Request.Context(),
"domain",
"Domain Deleted",
fmt.Sprintf("Domain %s deleted", domain.Name),
@@ -107,7 +107,7 @@ func (h *ProxyHostHandler) Create(c *gin.Context) {
// Send Notification
if h.notificationService != nil {
h.notificationService.SendExternal(
h.notificationService.SendExternal(c.Request.Context(),
"proxy_host",
"Proxy Host Created",
fmt.Sprintf("Proxy Host %s (%s) created", host.Name, host.DomainNames),
@@ -243,7 +243,7 @@ func (h *ProxyHostHandler) Delete(c *gin.Context) {
// Send Notification
if h.notificationService != nil {
h.notificationService.SendExternal(
h.notificationService.SendExternal(c.Request.Context(),
"proxy_host",
"Proxy Host Deleted",
fmt.Sprintf("Proxy Host %s deleted", host.Name),
@@ -68,7 +68,7 @@ func (h *RemoteServerHandler) Create(c *gin.Context) {
// Send Notification
if h.notificationService != nil {
h.notificationService.SendExternal(
h.notificationService.SendExternal(c.Request.Context(),
"remote_server",
"Remote Server Added",
fmt.Sprintf("Remote Server %s (%s:%d) added", server.Name, server.Host, server.Port),
@@ -136,17 +136,17 @@ func (h *RemoteServerHandler) Delete(c *gin.Context) {
}
// Send Notification
if h.notificationService != nil {
h.notificationService.SendExternal(
"remote_server",
"Remote Server Deleted",
fmt.Sprintf("Remote Server %s deleted", server.Name),
map[string]interface{}{
"Name": server.Name,
"Action": "deleted",
},
)
}
if h.notificationService != nil {
h.notificationService.SendExternal(c.Request.Context(),
"remote_server",
"Remote Server Deleted",
fmt.Sprintf("Remote Server %s deleted", server.Name),
map[string]interface{}{
"Name": server.Name,
"Action": "deleted",
},
)
}
c.JSON(http.StatusNoContent, nil)
}
+7 -6
View File
@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wikid82/charon/backend/internal/config"
"github.com/Wikid82/charon/backend/internal/logger"
"github.com/robfig/cron/v3"
)
@@ -31,7 +32,7 @@ func NewBackupService(cfg *config.Config) *BackupService {
// Ensure backup directory exists
backupDir := filepath.Join(filepath.Dir(cfg.DatabasePath), "backups")
if err := os.MkdirAll(backupDir, 0755); err != nil {
fmt.Printf("Failed to create backup directory: %v\n", err)
logger.Log().WithError(err).Error("Failed to create backup directory")
}
s := &BackupService{
@@ -44,7 +45,7 @@ func NewBackupService(cfg *config.Config) *BackupService {
// Schedule daily backup at 3 AM
_, err := s.Cron.AddFunc("0 3 * * *", s.RunScheduledBackup)
if err != nil {
fmt.Printf("Failed to schedule backup: %v\n", err)
logger.Log().WithError(err).Error("Failed to schedule backup")
}
s.Cron.Start()
@@ -52,11 +53,11 @@ func NewBackupService(cfg *config.Config) *BackupService {
}
func (s *BackupService) RunScheduledBackup() {
fmt.Println("Starting scheduled backup...")
logger.Log().Info("Starting scheduled backup")
if name, err := s.CreateBackup(); err != nil {
fmt.Printf("Scheduled backup failed: %v\n", err)
logger.Log().WithError(err).Error("Scheduled backup failed")
} else {
fmt.Printf("Scheduled backup created: %s\n", name)
logger.Log().WithField("backup", name).Info("Scheduled backup created")
}
}
@@ -123,7 +124,7 @@ func (s *BackupService) CreateBackup() (string, error) {
caddyDir := filepath.Join(s.DataDir, "caddy")
if err := s.addDirToZip(w, caddyDir, "caddy"); err != nil {
// It's possible caddy dir doesn't exist yet, which is fine
fmt.Printf("Warning: could not backup caddy dir: %v\n", err)
logger.Log().WithError(err).Warn("Warning: could not backup caddy dir")
}
// Close zip writer and check for errors (important for zip integrity)
@@ -4,7 +4,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"github.com/Wikid82/charon/backend/internal/logger"
"os"
"path/filepath"
"strings"
@@ -50,7 +50,7 @@ func NewCertificateService(dataDir string, db *gorm.DB) *CertificateService {
// Perform initial scan in background
go func() {
if err := svc.SyncFromDisk(); err != nil {
log.Printf("CertificateService: initial sync failed: %v", err)
logger.Log().WithError(err).Error("CertificateService: initial sync failed")
}
}()
return svc
@@ -63,7 +63,7 @@ func (s *CertificateService) SyncFromDisk() error {
defer s.cacheMu.Unlock()
certRoot := filepath.Join(s.dataDir, "certificates")
log.Printf("CertificateService: scanning cert directory: %s", certRoot)
logger.Log().WithField("certRoot", certRoot).Info("CertificateService: scanning cert directory")
foundDomains := map[string]struct{}{}
@@ -71,14 +71,14 @@ func (s *CertificateService) SyncFromDisk() error {
if _, err := os.Stat(certRoot); err == nil {
_ = filepath.Walk(certRoot, func(path string, info os.FileInfo, err error) error {
if err != nil {
log.Printf("CertificateService: walk error for %s: %v\n", path, err)
logger.Log().WithField("path", path).WithError(err).Error("CertificateService: walk error")
return nil
}
if !info.IsDir() && strings.HasSuffix(info.Name(), ".crt") {
certData, err := os.ReadFile(path)
if err != nil {
log.Printf("CertificateService: failed to read cert file %s: %v", path, err)
logger.Log().WithField("path", path).WithError(err).Error("CertificateService: failed to read cert file")
return nil
}
@@ -90,7 +90,7 @@ func (s *CertificateService) SyncFromDisk() error {
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Printf("CertificateService: failed to parse cert %s: %v", path, err)
logger.Log().WithField("path", path).WithError(err).Error("CertificateService: failed to parse cert")
return nil
}
@@ -134,10 +134,10 @@ func (s *CertificateService) SyncFromDisk() error {
UpdatedAt: now,
}
if err := s.db.Create(&newCert).Error; err != nil {
log.Printf("CertificateService: failed to create DB cert for %s: %v\n", domain, err)
logger.Log().WithField("domain", domain).WithError(err).Error("CertificateService: failed to create DB cert")
}
} else {
log.Printf("CertificateService: db error querying cert %s: %v\n", domain, res.Error)
logger.Log().WithField("domain", domain).WithError(res.Error).Error("CertificateService: db error querying cert")
}
} else {
// Update expiry/certificate content and provider if changed
@@ -169,12 +169,12 @@ func (s *CertificateService) SyncFromDisk() error {
if updated {
existing.UpdatedAt = time.Now()
if err := s.db.Save(&existing).Error; err != nil {
log.Printf("CertificateService: failed to update DB cert for %s: %v\n", domain, err)
logger.Log().WithField("domain", domain).WithError(err).Error("CertificateService: failed to update DB cert")
}
} else {
// still update ExpiresAt if needed
if err := s.db.Model(&existing).Update("expires_at", &expiresAt).Error; err != nil {
log.Printf("CertificateService: failed to update expiry for %s: %v\n", domain, err)
logger.Log().WithField("domain", domain).WithError(err).Error("CertificateService: failed to update expiry")
}
}
}
@@ -183,9 +183,9 @@ func (s *CertificateService) SyncFromDisk() error {
})
} else {
if os.IsNotExist(err) {
log.Printf("CertificateService: cert directory does not exist: %s\n", certRoot)
logger.Log().WithField("certRoot", certRoot).Info("CertificateService: cert directory does not exist")
} else {
log.Printf("CertificateService: failed to stat cert directory: %v\n", err)
logger.Log().WithError(err).Error("CertificateService: failed to stat cert directory")
}
}
@@ -195,10 +195,10 @@ func (s *CertificateService) SyncFromDisk() error {
for _, c := range acmeCerts {
if _, ok := foundDomains[c.Domains]; !ok {
// remove stale record
if err := s.db.Delete(&models.SSLCertificate{}, "id = ?", c.ID).Error; err != nil {
log.Printf("CertificateService: failed to delete stale cert %s: %v\n", c.Domains, err)
if err := s.db.Delete(&models.SSLCertificate{}, "id = ?", c.ID).Error; err != nil {
logger.Log().WithField("domain", c.Domains).WithError(err).Error("CertificateService: failed to delete stale cert")
} else {
log.Printf("CertificateService: removed stale DB cert for %s\n", c.Domains)
logger.Log().WithField("domain", c.Domains).Info("CertificateService: removed stale DB cert")
}
}
}
@@ -211,7 +211,7 @@ func (s *CertificateService) SyncFromDisk() error {
s.lastScan = time.Now()
s.initialized = true
log.Printf("CertificateService: disk sync complete, %d certificates cached", len(s.cache))
logger.Log().WithField("count", len(s.cache)).Info("CertificateService: disk sync complete")
return nil
}
@@ -319,7 +319,7 @@ func (s *CertificateService) ListCertificates() ([]CertificateInfo, error) {
// Trigger background rescan for stale cache
go func() {
if err := s.SyncFromDisk(); err != nil {
log.Printf("CertificateService: background sync failed: %v", err)
logger.Log().WithError(err).Error("CertificateService: background sync failed")
}
}()
}
@@ -396,9 +396,9 @@ func (s *CertificateService) DeleteCertificate(id uint) error {
if err == nil && !info.IsDir() && strings.HasSuffix(info.Name(), ".crt") {
if info.Name() == cert.Domains+".crt" {
// Found it
log.Printf("CertificateService: deleting ACME cert file %s", path)
logger.Log().WithField("path", path).Info("CertificateService: deleting ACME cert file")
if err := os.Remove(path); err != nil {
log.Printf("CertificateService: failed to delete cert file: %v", err)
logger.Log().WithError(err).Error("CertificateService: failed to delete cert file")
}
// Try to delete key as well
keyPath := strings.TrimSuffix(path, ".crt") + ".key"
@@ -3,7 +3,8 @@ package services
import (
"bytes"
"fmt"
"log"
"context"
"github.com/Wikid82/charon/backend/internal/logger"
"net"
neturl "net/url"
"strings"
@@ -73,10 +74,10 @@ func (s *NotificationService) MarkAllAsRead() error {
// External Notifications (Shoutrrr & Custom Webhooks)
func (s *NotificationService) SendExternal(eventType, title, message string, data map[string]interface{}) {
func (s *NotificationService) SendExternal(ctx context.Context, eventType, title, message string, data map[string]interface{}) {
var providers []models.NotificationProvider
if err := s.DB.Where("enabled = ?", true).Find(&providers).Error; err != nil {
log.Printf("Failed to fetch notification providers: %v", err)
logger.Log().WithError(err).Error("Failed to fetch notification providers")
return
}
@@ -118,29 +119,29 @@ func (s *NotificationService) SendExternal(eventType, title, message string, dat
go func(p models.NotificationProvider) {
if p.Type == "webhook" {
if err := s.sendCustomWebhook(p, data); err != nil {
log.Printf("Failed to send webhook to %s: %v", p.Name, err)
if err := s.sendCustomWebhook(ctx, p, data); err != nil {
logger.Log().WithError(err).WithField("provider", p.Name).Error("Failed to send webhook")
}
} else {
url := normalizeURL(p.Type, p.URL)
// Validate HTTP/HTTPS destinations used by shoutrrr to reduce SSRF risk
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
if _, err := validateWebhookURL(url); err != nil {
log.Printf("Skipping notification for provider %s due to invalid destination", p.Name)
logger.Log().WithField("provider", p.Name).Warn("Skipping notification for provider due to invalid destination")
return
}
}
// Use newline for better formatting in chat apps
msg := fmt.Sprintf("%s\n\n%s", title, message)
if err := shoutrrr.Send(url, msg); err != nil {
log.Printf("Failed to send notification to %s: %v", p.Name, err)
logger.Log().WithError(err).WithField("provider", p.Name).Error("Failed to send notification")
}
}
}(provider)
}
}
func (s *NotificationService) sendCustomWebhook(p models.NotificationProvider, data map[string]interface{}) error {
func (s *NotificationService) sendCustomWebhook(ctx context.Context, p models.NotificationProvider, data map[string]interface{}) error {
// Built-in templates
const minimalTemplate = `{"message": {{toJSON .Message}}, "title": {{toJSON .Title}}, "time": {{toJSON .Time}}, "event": {{toJSON .EventType}}}`
const detailedTemplate = `{"title": {{toJSON .Title}}, "message": {{toJSON .Message}}, "time": {{toJSON .Time}}, "event": {{toJSON .EventType}}, "host": {{toJSON .HostName}}, "host_ip": {{toJSON .HostIP}}, "service_count": {{toJSON .ServiceCount}}, "services": {{toJSON .Services}}, "data": {{toJSON .}}}`
@@ -249,11 +250,17 @@ func (s *NotificationService) sendCustomWebhook(p models.NotificationProvider, d
Path: u.Path,
RawQuery: u.RawQuery,
}
req, err := http.NewRequest("POST", safeURL.String(), &body)
req, err := http.NewRequestWithContext(ctx, "POST", safeURL.String(), &body)
if err != nil {
return fmt.Errorf("failed to create webhook request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
// Propagate request id header if present in context
if rid := ctx.Value("requestID"); rid != nil {
if ridStr, ok := rid.(string); ok {
req.Header.Set("X-Request-ID", ridStr)
}
}
// Preserve original hostname for virtual host (Host header)
req.Host = u.Host
@@ -344,7 +351,7 @@ func (s *NotificationService) TestProvider(provider models.NotificationProvider)
"Latency": 123,
"Time": time.Now().Format(time.RFC3339),
}
return s.sendCustomWebhook(provider, data)
return s.sendCustomWebhook(context.Background(), provider, data)
}
url := normalizeURL(provider.Type, provider.URL)
return shoutrrr.Send(url, "Test notification from Charon")
@@ -3,6 +3,7 @@ package services
import (
"encoding/json"
"fmt"
"context"
"net/http"
"sync/atomic"
"testing"
@@ -165,7 +166,7 @@ func TestNotificationService_SendExternal(t *testing.T) {
}
svc.CreateProvider(&provider)
svc.SendExternal("proxy_host", "Title", "Message", nil)
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
select {
case <-received:
@@ -200,7 +201,7 @@ func TestNotificationService_SendExternal_MinimalVsDetailedTemplates(t *testing.
svc.CreateProvider(&providerMin)
data := map[string]interface{}{"Title": "Min Title", "Message": "Min Message", "Time": time.Now().Format(time.RFC3339), "EventType": "uptime"}
svc.SendExternal("uptime", "Min Title", "Min Message", data)
svc.SendExternal(context.Background(), "uptime", "Min Title", "Min Message", data)
select {
case body := <-rcvMinimal:
@@ -235,7 +236,7 @@ func TestNotificationService_SendExternal_MinimalVsDetailedTemplates(t *testing.
svc.CreateProvider(&providerDet)
dataDet := map[string]interface{}{"Title": "Det Title", "Message": "Det Message", "Time": time.Now().Format(time.RFC3339), "EventType": "uptime", "HostName": "example-host", "HostIP": "1.2.3.4", "ServiceCount": 1, "Services": []map[string]interface{}{{"Name": "svc1"}}}
svc.SendExternal("uptime", "Det Title", "Det Message", dataDet)
svc.SendExternal(context.Background(), "uptime", "Det Title", "Det Message", dataDet)
select {
case body := <-rcvDetailed:
@@ -275,7 +276,7 @@ func TestNotificationService_SendExternal_Filtered(t *testing.T) {
// Force update to false because GORM default tag might override zero value (false) on Create
db.Model(&provider).Update("notify_proxy_hosts", false)
svc.SendExternal("proxy_host", "Title", "Message", nil)
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
select {
case <-received:
@@ -299,7 +300,7 @@ func TestNotificationService_SendExternal_Shoutrrr(t *testing.T) {
svc.CreateProvider(&provider)
// This will log an error but should cover the code path
svc.SendExternal("proxy_host", "Title", "Message", nil)
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
// Give it a moment to run goroutine
time.Sleep(100 * time.Millisecond)
@@ -356,7 +357,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
URL: "://invalid-url",
}
data := map[string]interface{}{"Title": "Test", "Message": "Test Message"}
err := svc.sendCustomWebhook(provider, data)
err := svc.sendCustomWebhook(context.Background(), provider, data)
assert.Error(t, err)
})
@@ -373,7 +374,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
// But for unit test speed, we should probably mock or use a closed port on localhost
// Using a closed port on localhost is faster
provider.URL = "http://127.0.0.1:54321" // Assuming this port is closed
err := svc.sendCustomWebhook(provider, data)
err := svc.sendCustomWebhook(context.Background(), provider, data)
assert.Error(t, err)
})
@@ -388,7 +389,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
URL: ts.URL,
}
data := map[string]interface{}{"Title": "Test", "Message": "Test Message"}
err := svc.sendCustomWebhook(provider, data)
err := svc.sendCustomWebhook(context.Background(), provider, data)
assert.Error(t, err)
assert.Contains(t, err.Error(), "500")
})
@@ -413,7 +414,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
Config: `{"custom": "Test: {{.Title}}"}`,
}
data := map[string]interface{}{"Title": "My Title", "Message": "Test Message"}
svc.sendCustomWebhook(provider, data)
svc.sendCustomWebhook(context.Background(), provider, data)
select {
case <-received:
@@ -443,7 +444,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
// Config is empty, so default template is used: minimal
}
data := map[string]interface{}{"Title": "Default Title", "Message": "Test Message"}
svc.sendCustomWebhook(provider, data)
svc.sendCustomWebhook(context.Background(), provider, data)
select {
case <-received:
@@ -454,6 +455,32 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
})
}
func TestNotificationService_SendCustomWebhook_PropagatesRequestID(t *testing.T) {
db := setupNotificationTestDB(t)
svc := NewNotificationService(db)
received := make(chan string, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
received <- r.Header.Get("X-Request-ID")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
provider := models.NotificationProvider{Type: "webhook", URL: ts.URL}
data := map[string]interface{}{"Title": "Test", "Message": "Test"}
// Build context with requestID value
ctx := context.WithValue(context.Background(), "requestID", "my-rid")
err := svc.sendCustomWebhook(ctx, provider, data)
require.NoError(t, err)
select {
case rid := <-received:
assert.Equal(t, "my-rid", rid)
case <-time.After(500 * time.Millisecond):
t.Fatal("Timed out waiting for webhook request")
}
}
func TestNotificationService_TestProvider_Errors(t *testing.T) {
db := setupNotificationTestDB(t)
svc := NewNotificationService(db)
@@ -537,7 +564,7 @@ func TestNotificationService_SendExternal_EdgeCases(t *testing.T) {
svc.CreateProvider(&provider)
// Should complete without error
svc.SendExternal("proxy_host", "Title", "Message", nil)
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
time.Sleep(50 * time.Millisecond)
})
@@ -580,9 +607,9 @@ func TestNotificationService_SendExternal_EdgeCases(t *testing.T) {
require.False(t, saved.NotifyUptime, "NotifyUptime should be false")
require.False(t, saved.NotifyCerts, "NotifyCerts should be false")
svc.SendExternal("proxy_host", "Title", "Message", nil)
svc.SendExternal("uptime", "Title", "Message", nil)
svc.SendExternal("cert", "Title", "Message", nil)
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
svc.SendExternal(context.Background(), "uptime", "Title", "Message", nil)
svc.SendExternal(context.Background(), "cert", "Title", "Message", nil)
time.Sleep(50 * time.Millisecond)
})
@@ -615,7 +642,7 @@ func TestNotificationService_SendExternal_EdgeCases(t *testing.T) {
customData := map[string]interface{}{
"CustomField": "test-value",
}
svc.SendExternal("proxy_host", "Title", "Message", customData)
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", customData)
time.Sleep(100 * time.Millisecond)
assert.Equal(t, "test-value", receivedCustom.Load().(string))
+4 -3
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"encoding/json"
"fmt"
"github.com/Wikid82/charon/backend/internal/logger"
@@ -521,7 +522,7 @@ func (s *UptimeService) sendHostDownNotification(host *models.UptimeHost, downMo
"Services": downMonitors,
"Time": time.Now().Format(time.RFC1123),
}
s.NotificationService.SendExternal("uptime", title, sb.String(), data)
s.NotificationService.SendExternal(context.Background(), "uptime", title, sb.String(), data)
logger.Log().WithField("host_name", host.Name).WithField("service_count", len(downMonitors)).Info("Sent consolidated DOWN notification")
}
@@ -751,7 +752,7 @@ func (s *UptimeService) flushPendingNotification(hostID string) {
"Services": pending.downMonitors,
"Time": time.Now().Format(time.RFC1123),
}
s.NotificationService.SendExternal("uptime", title, sb.String(), data)
s.NotificationService.SendExternal(context.Background(), "uptime", title, sb.String(), data)
logger.Log().WithField("count", len(pending.downMonitors)).WithField("host", pending.hostName).Info("Sent batched DOWN notification")
}
@@ -781,7 +782,7 @@ func (s *UptimeService) sendRecoveryNotification(monitor models.UptimeMonitor, d
"Time": time.Now().Format(time.RFC1123),
"URL": monitor.URL,
}
s.NotificationService.SendExternal("uptime", title, sb.String(), data)
s.NotificationService.SendExternal(context.Background(), "uptime", title, sb.String(), data)
}
// FlushPendingNotifications flushes all pending batched notifications immediately.