package ws import ( "net/http" "strconv" "sync" "github.com/gofiber/fiber/v2" "github.com/gorilla/websocket" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpadaptor" ) type Client struct { Conn *websocket.Conn RecipientID int64 } type NotificationHub struct { Clients map[*Client]bool Broadcast chan interface{} Register chan *Client Unregister chan *Client mu sync.Mutex } func NewNotificationHub() *NotificationHub { return &NotificationHub{ Clients: make(map[*Client]bool), Broadcast: make(chan interface{}, 1000), Register: make(chan *Client), Unregister: make(chan *Client), } } func (h *NotificationHub) Run() { for { select { case client := <-h.Register: h.mu.Lock() h.Clients[client] = true h.mu.Unlock() case client := <-h.Unregister: h.mu.Lock() if _, ok := h.Clients[client]; ok { delete(h.Clients, client) client.Conn.Close() } h.mu.Unlock() case message := <-h.Broadcast: h.mu.Lock() for client := range h.Clients { if payload, ok := message.(map[string]interface{}); ok { if recipient, ok := payload["recipient_id"].(int64); ok && recipient == client.RecipientID { err := client.Conn.WriteJSON(payload) if err != nil { delete(h.Clients, client) client.Conn.Close() } } } } h.mu.Unlock() } } } var Upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { return true }, } type responseWriterWrapper struct { ctx *fasthttp.RequestCtx header http.Header wrote bool } func newResponseWriterWrapper(ctx *fasthttp.RequestCtx) *responseWriterWrapper { return &responseWriterWrapper{ ctx: ctx, header: make(http.Header), } } func (rw *responseWriterWrapper) Header() http.Header { return rw.header } func (rw *responseWriterWrapper) Write(b []byte) (int, error) { rw.writeHeaderIfNeeded(http.StatusOK) return rw.ctx.Write(b) } func (rw *responseWriterWrapper) WriteHeader(statusCode int) { if rw.wrote { return } rw.writeHeaderIfNeeded(statusCode) } func (rw *responseWriterWrapper) writeHeaderIfNeeded(statusCode int) { if rw.wrote { return } rw.wrote = true // Copy headers from rw.header to fasthttp.Response.Header for k, vv := range rw.header { for _, v := range vv { rw.ctx.Response.Header.Add(k, v) } } rw.ctx.SetStatusCode(statusCode) } // ✅ WebSocketHandler integrates Gorilla WebSocket with Fiber func (h *NotificationHub) WebSocketHandler() fiber.Handler { return func(c *fiber.Ctx) error { userIDStr := c.Params("user_id") userID, err := strconv.ParseInt(userIDStr, 10, 64) if err != nil { return c.Status(fiber.StatusBadRequest).SendString("Invalid user ID") } // Use your custom responseWriterWrapper here rw := newResponseWriterWrapper(c.Context()) stdReq := new(http.Request) if err := fasthttpadaptor.ConvertRequest(c.Context(), stdReq, true); err != nil { return err } conn, err := Upgrader.Upgrade(rw, stdReq, nil) if err != nil { return err } client := &Client{ Conn: conn, RecipientID: userID, } h.Register <- client defer func() { h.Unregister <- client }() for { if _, _, err := conn.ReadMessage(); err != nil { break } } return nil } } // func (h *NotificationHub) BroadcastWalletUpdate(userID int64, event event.WalletEvent) { // payload := map[string]interface{}{ // "type": event.EventType, // "wallet_id": event.WalletID, // "wallet_type": event.WalletType, // "user_id": event.UserID, // "balance": event.Balance, // "trigger": event.Trigger, // "recipient_id": userID, // } // h.Broadcast <- payload // }