175 lines
3.7 KiB
Go
175 lines
3.7 KiB
Go
package ws
|
|
|
|
import (
	"bufio"
	"net"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/SamuelTariku/FortuneBet-Backend/internal/event"
	"github.com/gofiber/fiber/v2"
	"github.com/gorilla/websocket"
	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/fasthttpadaptor"
)
|
|
|
|
type Client struct {
|
|
Conn *websocket.Conn
|
|
RecipientID int64
|
|
}
|
|
|
|
type NotificationHub struct {
|
|
Clients map[*Client]bool
|
|
Broadcast chan interface{}
|
|
Register chan *Client
|
|
Unregister chan *Client
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func NewNotificationHub() *NotificationHub {
|
|
return &NotificationHub{
|
|
Clients: make(map[*Client]bool),
|
|
Broadcast: make(chan interface{}, 1000),
|
|
Register: make(chan *Client),
|
|
Unregister: make(chan *Client),
|
|
}
|
|
}
|
|
|
|
func (h *NotificationHub) Run() {
|
|
for {
|
|
select {
|
|
case client := <-h.Register:
|
|
h.mu.Lock()
|
|
h.Clients[client] = true
|
|
h.mu.Unlock()
|
|
case client := <-h.Unregister:
|
|
h.mu.Lock()
|
|
if _, ok := h.Clients[client]; ok {
|
|
delete(h.Clients, client)
|
|
client.Conn.Close()
|
|
}
|
|
h.mu.Unlock()
|
|
case message := <-h.Broadcast:
|
|
h.mu.Lock()
|
|
for client := range h.Clients {
|
|
if payload, ok := message.(map[string]interface{}); ok {
|
|
if recipient, ok := payload["recipient_id"].(int64); ok && recipient == client.RecipientID {
|
|
err := client.Conn.WriteJSON(payload)
|
|
if err != nil {
|
|
delete(h.Clients, client)
|
|
client.Conn.Close()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
var Upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true
|
|
},
|
|
}
|
|
|
|
type responseWriterWrapper struct {
|
|
ctx *fasthttp.RequestCtx
|
|
header http.Header
|
|
wrote bool
|
|
}
|
|
|
|
func newResponseWriterWrapper(ctx *fasthttp.RequestCtx) *responseWriterWrapper {
|
|
return &responseWriterWrapper{
|
|
ctx: ctx,
|
|
header: make(http.Header),
|
|
}
|
|
}
|
|
|
|
func (rw *responseWriterWrapper) Header() http.Header {
|
|
return rw.header
|
|
}
|
|
|
|
func (rw *responseWriterWrapper) Write(b []byte) (int, error) {
|
|
rw.writeHeaderIfNeeded(http.StatusOK)
|
|
return rw.ctx.Write(b)
|
|
}
|
|
|
|
func (rw *responseWriterWrapper) WriteHeader(statusCode int) {
|
|
if rw.wrote {
|
|
return
|
|
}
|
|
rw.writeHeaderIfNeeded(statusCode)
|
|
}
|
|
|
|
func (rw *responseWriterWrapper) writeHeaderIfNeeded(statusCode int) {
|
|
if rw.wrote {
|
|
return
|
|
}
|
|
rw.wrote = true
|
|
|
|
// Copy headers from rw.header to fasthttp.Response.Header
|
|
for k, vv := range rw.header {
|
|
for _, v := range vv {
|
|
rw.ctx.Response.Header.Add(k, v)
|
|
}
|
|
}
|
|
|
|
rw.ctx.SetStatusCode(statusCode)
|
|
}
|
|
|
|
// WebSocketHandler returns a Fiber handler that upgrades the request to a
// WebSocket connection for the user identified by the "user_id" route
// parameter and registers that connection with the hub.
//
// A user_id that is not a base-10 int64 gets a 400 "Invalid user ID" response.
// Errors from request conversion or from the upgrade are returned to Fiber as is.
//
// The handler blocks for the lifetime of the connection: inbound messages are
// read and discarded, and on the first read error the client is sent to
// Unregister. Because Register and Unregister are sends on hub channels, Run
// must be servicing the hub.
//
// NOTE(review): gorilla/websocket's Upgrader hijacks the connection, so the
// ResponseWriter passed to Upgrade must implement http.Hijacker — confirm
// responseWriterWrapper does, otherwise every upgrade fails.
func (h *NotificationHub) WebSocketHandler() fiber.Handler {
	return func(c *fiber.Ctx) error {
		recipientID, parseErr := strconv.ParseInt(c.Params("user_id"), 10, 64)
		if parseErr != nil {
			return c.Status(fiber.StatusBadRequest).SendString("Invalid user ID")
		}

		fctx := c.Context()

		// Build the net/http view of the request that gorilla's Upgrader expects.
		httpReq := &http.Request{}
		if err := fasthttpadaptor.ConvertRequest(fctx, httpReq, true); err != nil {
			return err
		}

		wsConn, err := Upgrader.Upgrade(newResponseWriterWrapper(fctx), httpReq, nil)
		if err != nil {
			return err
		}

		client := &Client{Conn: wsConn, RecipientID: recipientID}
		h.Register <- client
		defer func() { h.Unregister <- client }()

		// Drain inbound frames; a read error means the peer is gone.
		for {
			if _, _, readErr := wsConn.ReadMessage(); readErr != nil {
				return nil
			}
		}
	}
}
|
|
|
|
func (h *NotificationHub) BroadcastWalletUpdate(userID int64, event event.WalletEvent) {
|
|
payload := map[string]interface{}{
|
|
"type": event.EventType,
|
|
"wallet_id": event.WalletID,
|
|
"user_id": event.UserID,
|
|
"balance": event.Balance,
|
|
"trigger": event.Trigger,
|
|
"recipient_id": userID,
|
|
}
|
|
h.Broadcast <- payload
|
|
}
|