feat: propagate request context in notification service and related handlers
This commit is contained in:
@@ -92,7 +92,7 @@ func (h *CertificateHandler) Upload(c *gin.Context) {
|
||||
|
||||
// Send Notification
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"cert",
|
||||
"Certificate Uploaded",
|
||||
fmt.Sprintf("Certificate %s uploaded", cert.Name),
|
||||
@@ -122,7 +122,7 @@ func (h *CertificateHandler) Delete(c *gin.Context) {
|
||||
|
||||
// Send Notification
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"cert",
|
||||
"Certificate Deleted",
|
||||
fmt.Sprintf("Certificate ID %d deleted", id),
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *DomainHandler) Create(c *gin.Context) {
|
||||
|
||||
// Send Notification
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"domain",
|
||||
"Domain Added",
|
||||
fmt.Sprintf("Domain %s added", domain.Name),
|
||||
@@ -72,7 +72,7 @@ func (h *DomainHandler) Delete(c *gin.Context) {
|
||||
if err := h.DB.Where("uuid = ?", id).First(&domain).Error; err == nil {
|
||||
// Send Notification before delete (or after if we keep the name)
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"domain",
|
||||
"Domain Deleted",
|
||||
fmt.Sprintf("Domain %s deleted", domain.Name),
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *ProxyHostHandler) Create(c *gin.Context) {
|
||||
|
||||
// Send Notification
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"proxy_host",
|
||||
"Proxy Host Created",
|
||||
fmt.Sprintf("Proxy Host %s (%s) created", host.Name, host.DomainNames),
|
||||
@@ -243,7 +243,7 @@ func (h *ProxyHostHandler) Delete(c *gin.Context) {
|
||||
|
||||
// Send Notification
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"proxy_host",
|
||||
"Proxy Host Deleted",
|
||||
fmt.Sprintf("Proxy Host %s deleted", host.Name),
|
||||
|
||||
@@ -68,7 +68,7 @@ func (h *RemoteServerHandler) Create(c *gin.Context) {
|
||||
|
||||
// Send Notification
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"remote_server",
|
||||
"Remote Server Added",
|
||||
fmt.Sprintf("Remote Server %s (%s:%d) added", server.Name, server.Host, server.Port),
|
||||
@@ -136,17 +136,17 @@ func (h *RemoteServerHandler) Delete(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Send Notification
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(
|
||||
"remote_server",
|
||||
"Remote Server Deleted",
|
||||
fmt.Sprintf("Remote Server %s deleted", server.Name),
|
||||
map[string]interface{}{
|
||||
"Name": server.Name,
|
||||
"Action": "deleted",
|
||||
},
|
||||
)
|
||||
}
|
||||
if h.notificationService != nil {
|
||||
h.notificationService.SendExternal(c.Request.Context(),
|
||||
"remote_server",
|
||||
"Remote Server Deleted",
|
||||
fmt.Sprintf("Remote Server %s deleted", server.Name),
|
||||
map[string]interface{}{
|
||||
"Name": server.Name,
|
||||
"Action": "deleted",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wikid82/charon/backend/internal/config"
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
@@ -31,7 +32,7 @@ func NewBackupService(cfg *config.Config) *BackupService {
|
||||
// Ensure backup directory exists
|
||||
backupDir := filepath.Join(filepath.Dir(cfg.DatabasePath), "backups")
|
||||
if err := os.MkdirAll(backupDir, 0755); err != nil {
|
||||
fmt.Printf("Failed to create backup directory: %v\n", err)
|
||||
logger.Log().WithError(err).Error("Failed to create backup directory")
|
||||
}
|
||||
|
||||
s := &BackupService{
|
||||
@@ -44,7 +45,7 @@ func NewBackupService(cfg *config.Config) *BackupService {
|
||||
// Schedule daily backup at 3 AM
|
||||
_, err := s.Cron.AddFunc("0 3 * * *", s.RunScheduledBackup)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to schedule backup: %v\n", err)
|
||||
logger.Log().WithError(err).Error("Failed to schedule backup")
|
||||
}
|
||||
s.Cron.Start()
|
||||
|
||||
@@ -52,11 +53,11 @@ func NewBackupService(cfg *config.Config) *BackupService {
|
||||
}
|
||||
|
||||
func (s *BackupService) RunScheduledBackup() {
|
||||
fmt.Println("Starting scheduled backup...")
|
||||
logger.Log().Info("Starting scheduled backup")
|
||||
if name, err := s.CreateBackup(); err != nil {
|
||||
fmt.Printf("Scheduled backup failed: %v\n", err)
|
||||
logger.Log().WithError(err).Error("Scheduled backup failed")
|
||||
} else {
|
||||
fmt.Printf("Scheduled backup created: %s\n", name)
|
||||
logger.Log().WithField("backup", name).Info("Scheduled backup created")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,7 +124,7 @@ func (s *BackupService) CreateBackup() (string, error) {
|
||||
caddyDir := filepath.Join(s.DataDir, "caddy")
|
||||
if err := s.addDirToZip(w, caddyDir, "caddy"); err != nil {
|
||||
// It's possible caddy dir doesn't exist yet, which is fine
|
||||
fmt.Printf("Warning: could not backup caddy dir: %v\n", err)
|
||||
logger.Log().WithError(err).Warn("Warning: could not backup caddy dir")
|
||||
}
|
||||
|
||||
// Close zip writer and check for errors (important for zip integrity)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -50,7 +50,7 @@ func NewCertificateService(dataDir string, db *gorm.DB) *CertificateService {
|
||||
// Perform initial scan in background
|
||||
go func() {
|
||||
if err := svc.SyncFromDisk(); err != nil {
|
||||
log.Printf("CertificateService: initial sync failed: %v", err)
|
||||
logger.Log().WithError(err).Error("CertificateService: initial sync failed")
|
||||
}
|
||||
}()
|
||||
return svc
|
||||
@@ -63,7 +63,7 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
defer s.cacheMu.Unlock()
|
||||
|
||||
certRoot := filepath.Join(s.dataDir, "certificates")
|
||||
log.Printf("CertificateService: scanning cert directory: %s", certRoot)
|
||||
logger.Log().WithField("certRoot", certRoot).Info("CertificateService: scanning cert directory")
|
||||
|
||||
foundDomains := map[string]struct{}{}
|
||||
|
||||
@@ -71,14 +71,14 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
if _, err := os.Stat(certRoot); err == nil {
|
||||
_ = filepath.Walk(certRoot, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Printf("CertificateService: walk error for %s: %v\n", path, err)
|
||||
logger.Log().WithField("path", path).WithError(err).Error("CertificateService: walk error")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".crt") {
|
||||
certData, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
log.Printf("CertificateService: failed to read cert file %s: %v", path, err)
|
||||
logger.Log().WithField("path", path).WithError(err).Error("CertificateService: failed to read cert file")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("CertificateService: failed to parse cert %s: %v", path, err)
|
||||
logger.Log().WithField("path", path).WithError(err).Error("CertificateService: failed to parse cert")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -134,10 +134,10 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := s.db.Create(&newCert).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to create DB cert for %s: %v\n", domain, err)
|
||||
logger.Log().WithField("domain", domain).WithError(err).Error("CertificateService: failed to create DB cert")
|
||||
}
|
||||
} else {
|
||||
log.Printf("CertificateService: db error querying cert %s: %v\n", domain, res.Error)
|
||||
logger.Log().WithField("domain", domain).WithError(res.Error).Error("CertificateService: db error querying cert")
|
||||
}
|
||||
} else {
|
||||
// Update expiry/certificate content and provider if changed
|
||||
@@ -169,12 +169,12 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
if updated {
|
||||
existing.UpdatedAt = time.Now()
|
||||
if err := s.db.Save(&existing).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to update DB cert for %s: %v\n", domain, err)
|
||||
logger.Log().WithField("domain", domain).WithError(err).Error("CertificateService: failed to update DB cert")
|
||||
}
|
||||
} else {
|
||||
// still update ExpiresAt if needed
|
||||
if err := s.db.Model(&existing).Update("expires_at", &expiresAt).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to update expiry for %s: %v\n", domain, err)
|
||||
logger.Log().WithField("domain", domain).WithError(err).Error("CertificateService: failed to update expiry")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -183,9 +183,9 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
})
|
||||
} else {
|
||||
if os.IsNotExist(err) {
|
||||
log.Printf("CertificateService: cert directory does not exist: %s\n", certRoot)
|
||||
logger.Log().WithField("certRoot", certRoot).Info("CertificateService: cert directory does not exist")
|
||||
} else {
|
||||
log.Printf("CertificateService: failed to stat cert directory: %v\n", err)
|
||||
logger.Log().WithError(err).Error("CertificateService: failed to stat cert directory")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,10 +195,10 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
for _, c := range acmeCerts {
|
||||
if _, ok := foundDomains[c.Domains]; !ok {
|
||||
// remove stale record
|
||||
if err := s.db.Delete(&models.SSLCertificate{}, "id = ?", c.ID).Error; err != nil {
|
||||
log.Printf("CertificateService: failed to delete stale cert %s: %v\n", c.Domains, err)
|
||||
if err := s.db.Delete(&models.SSLCertificate{}, "id = ?", c.ID).Error; err != nil {
|
||||
logger.Log().WithField("domain", c.Domains).WithError(err).Error("CertificateService: failed to delete stale cert")
|
||||
} else {
|
||||
log.Printf("CertificateService: removed stale DB cert for %s\n", c.Domains)
|
||||
logger.Log().WithField("domain", c.Domains).Info("CertificateService: removed stale DB cert")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -211,7 +211,7 @@ func (s *CertificateService) SyncFromDisk() error {
|
||||
|
||||
s.lastScan = time.Now()
|
||||
s.initialized = true
|
||||
log.Printf("CertificateService: disk sync complete, %d certificates cached", len(s.cache))
|
||||
logger.Log().WithField("count", len(s.cache)).Info("CertificateService: disk sync complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -319,7 +319,7 @@ func (s *CertificateService) ListCertificates() ([]CertificateInfo, error) {
|
||||
// Trigger background rescan for stale cache
|
||||
go func() {
|
||||
if err := s.SyncFromDisk(); err != nil {
|
||||
log.Printf("CertificateService: background sync failed: %v", err)
|
||||
logger.Log().WithError(err).Error("CertificateService: background sync failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -396,9 +396,9 @@ func (s *CertificateService) DeleteCertificate(id uint) error {
|
||||
if err == nil && !info.IsDir() && strings.HasSuffix(info.Name(), ".crt") {
|
||||
if info.Name() == cert.Domains+".crt" {
|
||||
// Found it
|
||||
log.Printf("CertificateService: deleting ACME cert file %s", path)
|
||||
logger.Log().WithField("path", path).Info("CertificateService: deleting ACME cert file")
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Printf("CertificateService: failed to delete cert file: %v", err)
|
||||
logger.Log().WithError(err).Error("CertificateService: failed to delete cert file")
|
||||
}
|
||||
// Try to delete key as well
|
||||
keyPath := strings.TrimSuffix(path, ".crt") + ".key"
|
||||
|
||||
@@ -3,7 +3,8 @@ package services
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"context"
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
"net"
|
||||
neturl "net/url"
|
||||
"strings"
|
||||
@@ -73,10 +74,10 @@ func (s *NotificationService) MarkAllAsRead() error {
|
||||
|
||||
// External Notifications (Shoutrrr & Custom Webhooks)
|
||||
|
||||
func (s *NotificationService) SendExternal(eventType, title, message string, data map[string]interface{}) {
|
||||
func (s *NotificationService) SendExternal(ctx context.Context, eventType, title, message string, data map[string]interface{}) {
|
||||
var providers []models.NotificationProvider
|
||||
if err := s.DB.Where("enabled = ?", true).Find(&providers).Error; err != nil {
|
||||
log.Printf("Failed to fetch notification providers: %v", err)
|
||||
logger.Log().WithError(err).Error("Failed to fetch notification providers")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -118,29 +119,29 @@ func (s *NotificationService) SendExternal(eventType, title, message string, dat
|
||||
|
||||
go func(p models.NotificationProvider) {
|
||||
if p.Type == "webhook" {
|
||||
if err := s.sendCustomWebhook(p, data); err != nil {
|
||||
log.Printf("Failed to send webhook to %s: %v", p.Name, err)
|
||||
if err := s.sendCustomWebhook(ctx, p, data); err != nil {
|
||||
logger.Log().WithError(err).WithField("provider", p.Name).Error("Failed to send webhook")
|
||||
}
|
||||
} else {
|
||||
url := normalizeURL(p.Type, p.URL)
|
||||
// Validate HTTP/HTTPS destinations used by shoutrrr to reduce SSRF risk
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
|
||||
if _, err := validateWebhookURL(url); err != nil {
|
||||
log.Printf("Skipping notification for provider %s due to invalid destination", p.Name)
|
||||
logger.Log().WithField("provider", p.Name).Warn("Skipping notification for provider due to invalid destination")
|
||||
return
|
||||
}
|
||||
}
|
||||
// Use newline for better formatting in chat apps
|
||||
msg := fmt.Sprintf("%s\n\n%s", title, message)
|
||||
if err := shoutrrr.Send(url, msg); err != nil {
|
||||
log.Printf("Failed to send notification to %s: %v", p.Name, err)
|
||||
logger.Log().WithError(err).WithField("provider", p.Name).Error("Failed to send notification")
|
||||
}
|
||||
}
|
||||
}(provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NotificationService) sendCustomWebhook(p models.NotificationProvider, data map[string]interface{}) error {
|
||||
func (s *NotificationService) sendCustomWebhook(ctx context.Context, p models.NotificationProvider, data map[string]interface{}) error {
|
||||
// Built-in templates
|
||||
const minimalTemplate = `{"message": {{toJSON .Message}}, "title": {{toJSON .Title}}, "time": {{toJSON .Time}}, "event": {{toJSON .EventType}}}`
|
||||
const detailedTemplate = `{"title": {{toJSON .Title}}, "message": {{toJSON .Message}}, "time": {{toJSON .Time}}, "event": {{toJSON .EventType}}, "host": {{toJSON .HostName}}, "host_ip": {{toJSON .HostIP}}, "service_count": {{toJSON .ServiceCount}}, "services": {{toJSON .Services}}, "data": {{toJSON .}}}`
|
||||
@@ -249,11 +250,17 @@ func (s *NotificationService) sendCustomWebhook(p models.NotificationProvider, d
|
||||
Path: u.Path,
|
||||
RawQuery: u.RawQuery,
|
||||
}
|
||||
req, err := http.NewRequest("POST", safeURL.String(), &body)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", safeURL.String(), &body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create webhook request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// Propagate request id header if present in context
|
||||
if rid := ctx.Value("requestID"); rid != nil {
|
||||
if ridStr, ok := rid.(string); ok {
|
||||
req.Header.Set("X-Request-ID", ridStr)
|
||||
}
|
||||
}
|
||||
// Preserve original hostname for virtual host (Host header)
|
||||
req.Host = u.Host
|
||||
|
||||
@@ -344,7 +351,7 @@ func (s *NotificationService) TestProvider(provider models.NotificationProvider)
|
||||
"Latency": 123,
|
||||
"Time": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
return s.sendCustomWebhook(provider, data)
|
||||
return s.sendCustomWebhook(context.Background(), provider, data)
|
||||
}
|
||||
url := normalizeURL(provider.Type, provider.URL)
|
||||
return shoutrrr.Send(url, "Test notification from Charon")
|
||||
|
||||
@@ -3,6 +3,7 @@ package services
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -165,7 +166,7 @@ func TestNotificationService_SendExternal(t *testing.T) {
|
||||
}
|
||||
svc.CreateProvider(&provider)
|
||||
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
|
||||
|
||||
select {
|
||||
case <-received:
|
||||
@@ -200,7 +201,7 @@ func TestNotificationService_SendExternal_MinimalVsDetailedTemplates(t *testing.
|
||||
svc.CreateProvider(&providerMin)
|
||||
|
||||
data := map[string]interface{}{"Title": "Min Title", "Message": "Min Message", "Time": time.Now().Format(time.RFC3339), "EventType": "uptime"}
|
||||
svc.SendExternal("uptime", "Min Title", "Min Message", data)
|
||||
svc.SendExternal(context.Background(), "uptime", "Min Title", "Min Message", data)
|
||||
|
||||
select {
|
||||
case body := <-rcvMinimal:
|
||||
@@ -235,7 +236,7 @@ func TestNotificationService_SendExternal_MinimalVsDetailedTemplates(t *testing.
|
||||
svc.CreateProvider(&providerDet)
|
||||
|
||||
dataDet := map[string]interface{}{"Title": "Det Title", "Message": "Det Message", "Time": time.Now().Format(time.RFC3339), "EventType": "uptime", "HostName": "example-host", "HostIP": "1.2.3.4", "ServiceCount": 1, "Services": []map[string]interface{}{{"Name": "svc1"}}}
|
||||
svc.SendExternal("uptime", "Det Title", "Det Message", dataDet)
|
||||
svc.SendExternal(context.Background(), "uptime", "Det Title", "Det Message", dataDet)
|
||||
|
||||
select {
|
||||
case body := <-rcvDetailed:
|
||||
@@ -275,7 +276,7 @@ func TestNotificationService_SendExternal_Filtered(t *testing.T) {
|
||||
// Force update to false because GORM default tag might override zero value (false) on Create
|
||||
db.Model(&provider).Update("notify_proxy_hosts", false)
|
||||
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
|
||||
|
||||
select {
|
||||
case <-received:
|
||||
@@ -299,7 +300,7 @@ func TestNotificationService_SendExternal_Shoutrrr(t *testing.T) {
|
||||
svc.CreateProvider(&provider)
|
||||
|
||||
// This will log an error but should cover the code path
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
|
||||
|
||||
// Give it a moment to run goroutine
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -356,7 +357,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
|
||||
URL: "://invalid-url",
|
||||
}
|
||||
data := map[string]interface{}{"Title": "Test", "Message": "Test Message"}
|
||||
err := svc.sendCustomWebhook(provider, data)
|
||||
err := svc.sendCustomWebhook(context.Background(), provider, data)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -373,7 +374,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
|
||||
// But for unit test speed, we should probably mock or use a closed port on localhost
|
||||
// Using a closed port on localhost is faster
|
||||
provider.URL = "http://127.0.0.1:54321" // Assuming this port is closed
|
||||
err := svc.sendCustomWebhook(provider, data)
|
||||
err := svc.sendCustomWebhook(context.Background(), provider, data)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -388,7 +389,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
|
||||
URL: ts.URL,
|
||||
}
|
||||
data := map[string]interface{}{"Title": "Test", "Message": "Test Message"}
|
||||
err := svc.sendCustomWebhook(provider, data)
|
||||
err := svc.sendCustomWebhook(context.Background(), provider, data)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "500")
|
||||
})
|
||||
@@ -413,7 +414,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
|
||||
Config: `{"custom": "Test: {{.Title}}"}`,
|
||||
}
|
||||
data := map[string]interface{}{"Title": "My Title", "Message": "Test Message"}
|
||||
svc.sendCustomWebhook(provider, data)
|
||||
svc.sendCustomWebhook(context.Background(), provider, data)
|
||||
|
||||
select {
|
||||
case <-received:
|
||||
@@ -443,7 +444,7 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
|
||||
// Config is empty, so default template is used: minimal
|
||||
}
|
||||
data := map[string]interface{}{"Title": "Default Title", "Message": "Test Message"}
|
||||
svc.sendCustomWebhook(provider, data)
|
||||
svc.sendCustomWebhook(context.Background(), provider, data)
|
||||
|
||||
select {
|
||||
case <-received:
|
||||
@@ -454,6 +455,32 @@ func TestNotificationService_SendCustomWebhook_Errors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationService_SendCustomWebhook_PropagatesRequestID(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
|
||||
received := make(chan string, 1)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
received <- r.Header.Get("X-Request-ID")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := models.NotificationProvider{Type: "webhook", URL: ts.URL}
|
||||
data := map[string]interface{}{"Title": "Test", "Message": "Test"}
|
||||
// Build context with requestID value
|
||||
ctx := context.WithValue(context.Background(), "requestID", "my-rid")
|
||||
err := svc.sendCustomWebhook(ctx, provider, data)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case rid := <-received:
|
||||
assert.Equal(t, "my-rid", rid)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for webhook request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationService_TestProvider_Errors(t *testing.T) {
|
||||
db := setupNotificationTestDB(t)
|
||||
svc := NewNotificationService(db)
|
||||
@@ -537,7 +564,7 @@ func TestNotificationService_SendExternal_EdgeCases(t *testing.T) {
|
||||
svc.CreateProvider(&provider)
|
||||
|
||||
// Should complete without error
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
@@ -580,9 +607,9 @@ func TestNotificationService_SendExternal_EdgeCases(t *testing.T) {
|
||||
require.False(t, saved.NotifyUptime, "NotifyUptime should be false")
|
||||
require.False(t, saved.NotifyCerts, "NotifyCerts should be false")
|
||||
|
||||
svc.SendExternal("proxy_host", "Title", "Message", nil)
|
||||
svc.SendExternal("uptime", "Title", "Message", nil)
|
||||
svc.SendExternal("cert", "Title", "Message", nil)
|
||||
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", nil)
|
||||
svc.SendExternal(context.Background(), "uptime", "Title", "Message", nil)
|
||||
svc.SendExternal(context.Background(), "cert", "Title", "Message", nil)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
@@ -615,7 +642,7 @@ func TestNotificationService_SendExternal_EdgeCases(t *testing.T) {
|
||||
customData := map[string]interface{}{
|
||||
"CustomField": "test-value",
|
||||
}
|
||||
svc.SendExternal("proxy_host", "Title", "Message", customData)
|
||||
svc.SendExternal(context.Background(), "proxy_host", "Title", "Message", customData)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, "test-value", receivedCustom.Load().(string))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/Wikid82/charon/backend/internal/logger"
|
||||
@@ -521,7 +522,7 @@ func (s *UptimeService) sendHostDownNotification(host *models.UptimeHost, downMo
|
||||
"Services": downMonitors,
|
||||
"Time": time.Now().Format(time.RFC1123),
|
||||
}
|
||||
s.NotificationService.SendExternal("uptime", title, sb.String(), data)
|
||||
s.NotificationService.SendExternal(context.Background(), "uptime", title, sb.String(), data)
|
||||
|
||||
logger.Log().WithField("host_name", host.Name).WithField("service_count", len(downMonitors)).Info("Sent consolidated DOWN notification")
|
||||
}
|
||||
@@ -751,7 +752,7 @@ func (s *UptimeService) flushPendingNotification(hostID string) {
|
||||
"Services": pending.downMonitors,
|
||||
"Time": time.Now().Format(time.RFC1123),
|
||||
}
|
||||
s.NotificationService.SendExternal("uptime", title, sb.String(), data)
|
||||
s.NotificationService.SendExternal(context.Background(), "uptime", title, sb.String(), data)
|
||||
|
||||
logger.Log().WithField("count", len(pending.downMonitors)).WithField("host", pending.hostName).Info("Sent batched DOWN notification")
|
||||
}
|
||||
@@ -781,7 +782,7 @@ func (s *UptimeService) sendRecoveryNotification(monitor models.UptimeMonitor, d
|
||||
"Time": time.Now().Format(time.RFC1123),
|
||||
"URL": monitor.URL,
|
||||
}
|
||||
s.NotificationService.SendExternal("uptime", title, sb.String(), data)
|
||||
s.NotificationService.SendExternal(context.Background(), "uptime", title, sb.String(), data)
|
||||
}
|
||||
|
||||
// FlushPendingNotifications flushes all pending batched notifications immediately.
|
||||
|
||||
Reference in New Issue
Block a user