From 6177b8cf92ff25a9337fa3af6a949ceaee5c15c7 Mon Sep 17 00:00:00 2001 From: KidusAlemayehu Date: Tue, 13 May 2025 16:44:31 +0300 Subject: [PATCH] refactor notification service to use Gorilla WebSocket; implement WebSocket authentication middleware and notification hub --- internal/services/notfication/port.go | 2 +- internal/services/notfication/service.go | 76 +++++++++---- internal/web_server/app.go | 10 +- internal/web_server/handlers/handlers.go | 4 +- .../handlers/notification_handler.go | 104 +++++++++++++----- internal/web_server/middleware.go | 28 +++++ internal/web_server/routes.go | 2 +- internal/web_server/ws/ws.go | 73 ++++++++++++ 8 files changed, 241 insertions(+), 58 deletions(-) create mode 100644 internal/web_server/ws/ws.go diff --git a/internal/services/notfication/port.go b/internal/services/notfication/port.go index 9fa2f72..23120ee 100644 --- a/internal/services/notfication/port.go +++ b/internal/services/notfication/port.go @@ -4,7 +4,7 @@ import ( "context" "github.com/SamuelTariku/FortuneBet-Backend/internal/domain" - "github.com/gofiber/websocket/v2" + "github.com/gorilla/websocket" ) type NotificationStore interface { diff --git a/internal/services/notfication/service.go b/internal/services/notfication/service.go index e21f7da..9c5597e 100644 --- a/internal/services/notfication/service.go +++ b/internal/services/notfication/service.go @@ -11,12 +11,14 @@ import ( "github.com/SamuelTariku/FortuneBet-Backend/internal/domain" "github.com/SamuelTariku/FortuneBet-Backend/internal/pkgs/helpers" "github.com/SamuelTariku/FortuneBet-Backend/internal/repository" + "github.com/SamuelTariku/FortuneBet-Backend/internal/web_server/ws" afro "github.com/amanuelabay/afrosms-go" - "github.com/gofiber/websocket/v2" + "github.com/gorilla/websocket" ) type Service struct { repo repository.NotificationRepository + Hub *ws.NotificationHub connections sync.Map notificationCh chan *domain.Notification stopCh chan struct{} @@ -24,9 +26,11 @@ type Service struct { logger *slog.Logger } -func New(repo repository.NotificationRepository, logger *slog.Logger, cfg *config.Config) NotificationStore { +func New(repo repository.NotificationRepository, logger *slog.Logger, cfg *config.Config) *Service { + hub := ws.NewNotificationHub() svc := &Service{ repo: repo, + Hub: hub, logger: logger, connections: sync.Map{}, notificationCh: make(chan *domain.Notification, 1000), @@ -34,6 +38,7 @@ func New(repo repository.NotificationRepository, logger *slog.Logger, cfg *confi config: cfg, } + go hub.Run() go svc.startWorker() go svc.startRetryWorker() @@ -63,10 +68,18 @@ func (s *Service) SendNotification(ctx context.Context, notification *domain.Not notification = created + if notification.DeliveryChannel == domain.DeliveryChannelInApp { + s.Hub.Broadcast <- map[string]interface{}{ + "type": "CREATED_NOTIFICATION", + "recipient_id": notification.RecipientID, + "payload": notification, + } + } + select { case s.notificationCh <- notification: default: - s.logger.Error("[NotificationSvc.SendNotification] Notification channel full, dropping notification", "id", notification.ID) + s.logger.Warn("[NotificationSvc.SendNotification] Notification channel full, dropping notification", "id", notification.ID) } return nil @@ -78,6 +91,21 @@ func (s *Service) MarkAsRead(ctx context.Context, notificationID string, recipie s.logger.Error("[NotificationSvc.MarkAsRead] Failed to mark notification as read", "notificationID", notificationID, "recipientID", recipientID, "error", err) return err } + + // count, err := s.repo.CountUnreadNotifications(ctx, recipientID) + // if err != nil { + // s.logger.Error("[NotificationSvc.MarkAsRead] Failed to count unread notifications", "recipientID", recipientID, "error", err) + // return err + // } + + // s.Hub.Broadcast <- map[string]interface{}{ + // "type": "COUNT_NOT_OPENED_NOTIFICATION", + // "recipient_id": recipientID, + // "payload": map[string]int{ + // "not_opened_notifications_count": int(count), + // }, + // } + s.logger.Info("[NotificationSvc.MarkAsRead] Notification marked as read", "notificationID", notificationID, "recipientID", recipientID) return nil } @@ -99,7 +127,6 @@ func (s *Service) ConnectWebSocket(ctx context.Context, recipientID int64, c *we } func (s *Service) DisconnectWebSocket(recipientID int64) { - s.connections.Delete(recipientID) if conn, loaded := s.connections.LoadAndDelete(recipientID); loaded { conn.(*websocket.Conn).Close() s.logger.Info("[NotificationSvc.DisconnectWebSocket] Disconnected WebSocket", "recipientID", recipientID) @@ -160,21 +187,26 @@ func (s *Service) ListRecipientIDs(ctx context.Context, receiver domain.Notifica func (s *Service) handleNotification(notification *domain.Notification) { ctx := context.Background() - if conn, ok := s.connections.Load(notification.RecipientID); ok { - data, err := notification.ToJSON() + switch notification.DeliveryChannel { + case domain.DeliveryChannelSMS: + err := s.SendSMS(ctx, notification.RecipientID, notification.Payload.Message) if err != nil { - s.logger.Error("[NotificationSvc.HandleNotification] Failed to serialize notification", "id", notification.ID, "error", err) - return - } - if err := conn.(*websocket.Conn).WriteMessage(websocket.TextMessage, data); err != nil { - s.logger.Error("[NotificationSvc.HandleNotification] Failed to send WebSocket message", "id", notification.ID, "error", err) notification.DeliveryStatus = domain.DeliveryStatusFailed } else { notification.DeliveryStatus = domain.DeliveryStatusSent } - } else { - s.logger.Warn("[NotificationSvc.HandleNotification] No WebSocket connection for recipient", "recipientID", notification.RecipientID) - notification.DeliveryStatus = domain.DeliveryStatusFailed + case domain.DeliveryChannelEmail: + err := s.SendEmail(ctx, notification.RecipientID, notification.Payload.Headline, notification.Payload.Message) + if err != nil { + notification.DeliveryStatus = domain.DeliveryStatusFailed + } else { + notification.DeliveryStatus = domain.DeliveryStatusSent + } + default: + if notification.DeliveryChannel != domain.DeliveryChannelInApp { + s.logger.Warn("[NotificationSvc.HandleNotification] Unsupported delivery channel", "channel", notification.DeliveryChannel) + notification.DeliveryStatus = domain.DeliveryStatusFailed + } } if _, err := s.repo.UpdateNotificationStatus(ctx, notification.ID, string(notification.DeliveryStatus), notification.IsRead, notification.Metadata); err != nil { @@ -210,13 +242,17 @@ func (s *Service) retryFailedNotifications() { go func(notification *domain.Notification) { for attempt := 0; attempt < 3; attempt++ { time.Sleep(time.Duration(attempt) * time.Second) - if conn, ok := s.connections.Load(notification.RecipientID); ok { - data, err := notification.ToJSON() - if err != nil { - s.logger.Error("[NotificationSvc.RetryFailedNotifications] Failed to serialize notification for retry", "id", notification.ID, "error", err) - continue + if notification.DeliveryChannel == domain.DeliveryChannelSMS { + if err := s.SendSMS(ctx, notification.RecipientID, notification.Payload.Message); err == nil { + notification.DeliveryStatus = domain.DeliveryStatusSent + if _, err := s.repo.UpdateNotificationStatus(ctx, notification.ID, string(notification.DeliveryStatus), notification.IsRead, notification.Metadata); err != nil { + s.logger.Error("[NotificationSvc.RetryFailedNotifications] Failed to update after retry", "id", notification.ID, "error", err) + } + s.logger.Info("[NotificationSvc.RetryFailedNotifications] Successfully retried notification", "id", notification.ID) + return } - if err := conn.(*websocket.Conn).WriteMessage(websocket.TextMessage, data); err == nil { + } else if notification.DeliveryChannel == domain.DeliveryChannelEmail { + if err := s.SendEmail(ctx, notification.RecipientID, notification.Payload.Headline, notification.Payload.Message); err == nil { notification.DeliveryStatus = domain.DeliveryStatusSent if _, err := s.repo.UpdateNotificationStatus(ctx, notification.ID, string(notification.DeliveryStatus), notification.IsRead, notification.Metadata); err != nil { s.logger.Error("[NotificationSvc.RetryFailedNotifications] Failed to update after retry", "id", notification.ID, "error", err) diff --git a/internal/web_server/app.go b/internal/web_server/app.go index f3e50bd..e370f4e 100644 --- a/internal/web_server/app.go +++ b/internal/web_server/app.go @@ -29,7 +29,7 @@ import ( type App struct { fiber *fiber.App logger *slog.Logger - NotidicationStore notificationservice.NotificationStore + NotidicationStore *notificationservice.Service referralSvc referralservice.ReferralStore port int authSvc *authentication.Service @@ -61,7 +61,7 @@ func NewApp( transactionSvc *transaction.Service, branchSvc *branch.Service, companySvc *company.Service, - notidicationStore notificationservice.NotificationStore, + notidicationStore *notificationservice.Service, prematchSvc *odds.ServiceImpl, eventSvc event.Service, referralSvc referralservice.ReferralStore, @@ -76,9 +76,9 @@ func NewApp( }) app.Use(cors.New(cors.Config{ - AllowOrigins: "*", // Specify your frontend's origin - AllowMethods: "GET,POST,PUT,DELETE,OPTIONS", // Specify the allowed HTTP methods - AllowHeaders: "Content-Type,Authorization,platform", // Specify the allowed headers + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE,OPTIONS", + AllowHeaders: "Content-Type,Authorization,platform", // AllowCredentials: true, })) diff --git a/internal/web_server/handlers/handlers.go b/internal/web_server/handlers/handlers.go index b5f811d..a72e514 100644 --- a/internal/web_server/handlers/handlers.go +++ b/internal/web_server/handlers/handlers.go @@ -22,7 +22,7 @@ import ( type Handler struct { logger *slog.Logger - notificationSvc notificationservice.NotificationStore + notificationSvc *notificationservice.Service userSvc *user.Service referralSvc referralservice.ReferralStore walletSvc *wallet.Service @@ -41,7 +41,7 @@ type Handler struct { func New( logger *slog.Logger, - notificationSvc notificationservice.NotificationStore, + notificationSvc *notificationservice.Service, validator *customvalidator.CustomValidator, walletSvc *wallet.Service, referralSvc referralservice.ReferralStore, diff --git a/internal/web_server/handlers/notification_handler.go b/internal/web_server/handlers/notification_handler.go index 9d8ca1a..8c6337b 100644 --- a/internal/web_server/handlers/notification_handler.go +++ b/internal/web_server/handlers/notification_handler.go @@ -3,53 +3,99 @@ package handlers import ( "context" "encoding/json" + "net" + "net/http" "github.com/SamuelTariku/FortuneBet-Backend/internal/domain" + "github.com/SamuelTariku/FortuneBet-Backend/internal/web_server/ws" "github.com/gofiber/fiber/v2" - "github.com/gofiber/websocket/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" + "github.com/gorilla/websocket" + "github.com/valyala/fasthttp/fasthttpadaptor" ) -func (h *Handler) ConnectSocket(c *fiber.Ctx) error { - if !websocket.IsWebSocketUpgrade(c) { - h.logger.Warn("WebSocket upgrade required") - return fiber.ErrUpgradeRequired - } +func hijackHTTP(c *fiber.Ctx) (net.Conn, http.ResponseWriter, error) { + var rw http.ResponseWriter + var conn net.Conn + // This is a trick: fasthttpadaptor gives us the HTTP interfaces + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + return + } + var err error + conn, _, err = hj.Hijack() + if err != nil { + return + } + rw = w + }) + + fasthttpadaptor.NewFastHTTPHandler(handler)(c.Context()) + + if conn == nil || rw == nil { + return nil, nil, fiber.NewError(fiber.StatusInternalServerError, "Failed to hijack connection") + } + return conn, rw, nil +} + +func (h *Handler) ConnectSocket(c *fiber.Ctx) error { userID, ok := c.Locals("userID").(int64) if !ok || userID == 0 { h.logger.Error("Invalid user ID in context") - return fiber.NewError(fiber.StatusUnauthorized, "invalid user identification") + return fiber.NewError(fiber.StatusUnauthorized, "Invalid user identification") } - c.Locals("allowed", true) + // Convert *fiber.Ctx to *http.Request + req, err := adaptor.ConvertRequest(c, false) + if err != nil { + h.logger.Error("Failed to convert request", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to convert request") + } - return websocket.New(func(conn *websocket.Conn) { - ctx := context.Background() - logger := h.logger.With("userID", userID, "remoteAddr", conn.RemoteAddr()) + // Create a net.Conn hijacked from the fasthttp context + netConn, rw, err := hijackHTTP(c) + if err != nil { + h.logger.Error("Failed to hijack connection", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to hijack connection") + } - if err := h.notificationSvc.ConnectWebSocket(ctx, userID, conn); err != nil { - logger.Error("Failed to connect WebSocket", "error", err) - _ = conn.Close() - return - } + // Upgrade the connection using Gorilla's Upgrader + conn, err := ws.Upgrader.Upgrade(rw, req, nil) + if err != nil { + h.logger.Error("WebSocket upgrade failed", "error", err) + netConn.Close() + return fiber.NewError(fiber.StatusInternalServerError, "WebSocket upgrade failed") + } - logger.Info("WebSocket connection established") + client := &ws.Client{ + Conn: conn, + RecipientID: userID, + } - defer func() { - h.notificationSvc.DisconnectWebSocket(userID) - logger.Info("WebSocket connection closed") - _ = conn.Close() - }() + h.notificationSvc.Hub.Register <- client + h.logger.Info("WebSocket connection established", "userID", userID) - for { - if _, _, err := conn.ReadMessage(); err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - logger.Warn("WebSocket unexpected close", "error", err) - } - break + defer func() { + h.notificationSvc.Hub.Unregister <- client + h.logger.Info("WebSocket connection closed", "userID", userID) + conn.Close() + }() + + for { + _, _, err := conn.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + h.logger.Info("WebSocket closed normally", "userID", userID) + } else { + h.logger.Warn("Unexpected WebSocket closure", "userID", userID, "error", err) } + break } - })(c) + } + + return nil } func (h *Handler) MarkNotificationAsRead(c *fiber.Ctx) error { diff --git a/internal/web_server/middleware.go b/internal/web_server/middleware.go index c7c90d5..63cfe6e 100644 --- a/internal/web_server/middleware.go +++ b/internal/web_server/middleware.go @@ -71,3 +71,31 @@ func (a *App) CompanyOnly(c *fiber.Ctx) error { } return c.Next() } + +func (a *App) WebsocketAuthMiddleware(c *fiber.Ctx) error { + tokenStr := c.Query("token") + if tokenStr == "" { + a.logger.Error("Missing token in query parameter") + return fiber.NewError(fiber.StatusUnauthorized, "Missing token") + } + + claim, err := jwtutil.ParseJwt(tokenStr, a.JwtConfig.JwtAccessKey) + if err != nil { + if errors.Is(err, jwtutil.ErrExpiredToken) { + a.logger.Error("Token expired") + return fiber.NewError(fiber.StatusUnauthorized, "Token expired") + } + a.logger.Error("Invalid token", "error", err) + return fiber.NewError(fiber.StatusUnauthorized, "Invalid token") + } + + userID := claim.UserId + if userID == 0 { + a.logger.Error("Invalid user ID in token claims") + return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID") + } + + c.Locals("userID", userID) + a.logger.Info("Authenticated WebSocket connection", "userID", userID) + return c.Next() +} diff --git a/internal/web_server/routes.go b/internal/web_server/routes.go index 7b7e22a..6095931 100644 --- a/internal/web_server/routes.go +++ b/internal/web_server/routes.go @@ -169,7 +169,7 @@ func (a *App) initAppRoutes() { a.fiber.Put("/transaction/:id", a.authMiddleware, h.UpdateTransactionVerified) // Notification Routes - a.fiber.Get("/notifications/ws/connect/:recipientID", h.ConnectSocket) + a.fiber.Get("/ws/connect", a.WebsocketAuthMiddleware, h.ConnectSocket) a.fiber.Post("/notifications/mark-as-read", h.MarkNotificationAsRead) a.fiber.Post("/notifications/create", h.CreateAndSendNotification) diff --git a/internal/web_server/ws/ws.go b/internal/web_server/ws/ws.go new file mode 100644 index 0000000..28fb860 --- /dev/null +++ b/internal/web_server/ws/ws.go @@ -0,0 +1,73 @@ +package ws + +import ( + "log" + "net/http" + "sync" + + "github.com/gorilla/websocket" +) + +type Client struct { + Conn *websocket.Conn + RecipientID int64 +} + +type NotificationHub struct { + Clients map[*Client]bool + Broadcast chan interface{} + Register chan *Client + Unregister chan *Client + mu sync.Mutex +} + +func NewNotificationHub() *NotificationHub { + return &NotificationHub{ + Clients: make(map[*Client]bool), + Broadcast: make(chan interface{}, 1000), + Register: make(chan *Client), + Unregister: make(chan *Client), + } +} + +func (h *NotificationHub) Run() { + for { + select { + case client := <-h.Register: + h.mu.Lock() + h.Clients[client] = true + h.mu.Unlock() + log.Printf("Client registered: %d", client.RecipientID) + case client := <-h.Unregister: + h.mu.Lock() + if _, ok := h.Clients[client]; ok { + delete(h.Clients, client) + client.Conn.Close() + } + h.mu.Unlock() + log.Printf("Client unregistered: %d", client.RecipientID) + case message := <-h.Broadcast: + h.mu.Lock() + for client := range h.Clients { + if payload, ok := message.(map[string]interface{}); ok { + if recipient, ok := payload["recipient_id"].(int64); ok && recipient == client.RecipientID { + err := client.Conn.WriteJSON(payload) + if err != nil { + delete(h.Clients, client) + client.Conn.Close() + } + } + } + } + h.mu.Unlock() + } + } +} + +var Upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, +}