Yimaru-BackEnd/internal/web_server/ws/ws.go

175 lines
3.7 KiB
Go

package ws
import (
	"bufio"
	"net"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/SamuelTariku/FortuneBet-Backend/internal/event"
	"github.com/gofiber/fiber/v2"
	"github.com/gorilla/websocket"
	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/fasthttpadaptor"
)
// Client is a single WebSocket connection registered with a NotificationHub.
type Client struct {
// Conn is the upgraded connection. The hub writes JSON payloads to it and
// closes it when the client is unregistered or a write to it fails.
Conn *websocket.Conn
// RecipientID is compared with a broadcast payload's "recipient_id" to decide
// whether that payload is delivered to this client.
RecipientID int64
}
// NotificationHub fans out notifications to connected WebSocket clients.
// Its event loop, Run, consumes the Register, Unregister and Broadcast channels.
type NotificationHub struct {
// Clients is the set of registered clients (Run only ever stores true).
Clients map[*Client]bool
// Broadcast carries outgoing messages. Run delivers only
// map[string]interface{} payloads whose int64 "recipient_id" equals a
// client's RecipientID; other messages are ignored.
Broadcast chan interface{}
// Register adds a client to Clients.
Register chan *Client
// Unregister removes a client from Clients and closes its connection.
Unregister chan *Client
// mu guards Clients.
mu sync.Mutex
}
// NewNotificationHub builds a hub with the following setup:
//   - an empty client set;
//   - a Broadcast queue that buffers up to 1000 pending messages;
//   - unbuffered Register/Unregister channels.
//
// Sends on Register/Unregister block until Run is consuming them.
func NewNotificationHub() *NotificationHub {
	const broadcastBacklog = 1000

	hub := new(NotificationHub)
	hub.Clients = map[*Client]bool{}
	hub.Broadcast = make(chan interface{}, broadcastBacklog)
	hub.Register = make(chan *Client)
	hub.Unregister = make(chan *Client)
	return hub
}
// Run is the hub's event loop. It serializes three kinds of events:
//   - client registration;
//   - client unregistration;
//   - message delivery.
//
// It loops forever and never returns, so callers typically start it with
// `go hub.Run()`.
//
// A Broadcast message is delivered only if it is a map[string]interface{}
// whose "recipient_id" is an int64. It goes to every client with that
// RecipientID; anything else is discarded. A client whose write fails or
// times out is removed and its connection closed.
func (h *NotificationHub) Run() {
	// writeWait bounds each write. Writes happen while h.mu is held, so without
	// a deadline a single stalled client could block every other delivery
	// indefinitely.
	const writeWait = 10 * time.Second

	for {
		select {
		case client := <-h.Register:
			h.mu.Lock()
			h.Clients[client] = true
			h.mu.Unlock()

		case client := <-h.Unregister:
			h.mu.Lock()
			if _, ok := h.Clients[client]; ok {
				delete(h.Clients, client)
				client.Conn.Close()
			}
			h.mu.Unlock()

		case message := <-h.Broadcast:
			// Decode the routing information once per message, not once per client.
			payload, ok := message.(map[string]interface{})
			if !ok {
				continue
			}
			recipient, ok := payload["recipient_id"].(int64)
			if !ok {
				continue
			}

			h.mu.Lock()
			for client := range h.Clients {
				if client.RecipientID != recipient {
					continue
				}
				err := client.Conn.SetWriteDeadline(time.Now().Add(writeWait))
				if err == nil {
					err = client.Conn.WriteJSON(payload)
				}
				if err != nil {
					// The connection is unusable (failed or timed-out write); drop it.
					// Deleting during range is safe in Go.
					delete(h.Clients, client)
					client.Conn.Close()
				}
			}
			h.mu.Unlock()
		}
	}
}
// Upgrader performs the HTTP-to-WebSocket handshake for WebSocketHandler,
// using 1024-byte read and write buffers.
//
// NOTE(review): CheckOrigin unconditionally returns true, so pages served from
// any origin can open connections (cross-site WebSocket hijacking risk).
// Restrict it to the front-end's origin(s) unless authentication elsewhere
// makes this safe.
var Upgrader = websocket.Upgrader{
	WriteBufferSize: 1024,
	ReadBufferSize:  1024,
	// Every Origin header is accepted; see the NOTE(review) above.
	CheckOrigin: func(_ *http.Request) bool { return true },
}
// responseWriterWrapper adapts a fasthttp.RequestCtx to the net/http
// ResponseWriter and Hijacker interfaces. This lets net/http-based code, here
// gorilla/websocket's Upgrader, run inside a Fiber/fasthttp handler.
//
// Headers are buffered in header and copied to the fasthttp response the
// first time Write or WriteHeader is called. Later header changes are
// ignored, matching net/http semantics.
type responseWriterWrapper struct {
	ctx      *fasthttp.RequestCtx
	header   http.Header
	wrote    bool // status and headers have been copied into ctx.Response
	hijacked bool // Hijack has handed the raw connection to the caller
}

// Compile-time checks that the wrapper offers what Upgrader.Upgrade needs.
var (
	_ http.ResponseWriter = (*responseWriterWrapper)(nil)
	_ http.Hijacker       = (*responseWriterWrapper)(nil)
)

// newResponseWriterWrapper returns a wrapper around ctx with an empty header map.
func newResponseWriterWrapper(ctx *fasthttp.RequestCtx) *responseWriterWrapper {
	return &responseWriterWrapper{
		ctx:    ctx,
		header: make(http.Header),
	}
}

// Header returns the buffered header map. Mutations only take effect if
// made before the first Write or WriteHeader.
func (rw *responseWriterWrapper) Header() http.Header {
	return rw.header
}

// Write commits headers with status 200 if none were committed yet, then
// appends b to the fasthttp response body. It fails with http.ErrHijacked
// once the connection has been hijacked.
func (rw *responseWriterWrapper) Write(b []byte) (int, error) {
	if rw.hijacked {
		return 0, http.ErrHijacked
	}
	rw.writeHeaderIfNeeded(http.StatusOK)
	return rw.ctx.Write(b)
}

// WriteHeader commits statusCode and the buffered headers. Only the first
// call (of WriteHeader or Write) has any effect, and calls after Hijack are
// ignored.
func (rw *responseWriterWrapper) WriteHeader(statusCode int) {
	if rw.hijacked {
		return
	}
	rw.writeHeaderIfNeeded(statusCode)
}

// writeHeaderIfNeeded copies the buffered headers and statusCode into the
// fasthttp response exactly once.
func (rw *responseWriterWrapper) writeHeaderIfNeeded(statusCode int) {
	if rw.wrote {
		return
	}
	rw.wrote = true
	for k, vv := range rw.header {
		for _, v := range vv {
			rw.ctx.Response.Header.Add(k, v)
		}
	}
	rw.ctx.SetStatusCode(statusCode)
}

// Hijack implements http.Hijacker. Upgrader.Upgrade requires it; without it,
// every upgrade is rejected with "response does not implement http.Hijacker".
//
// fasthttp has no synchronous hijack, so this does three things:
//   - returns the raw connection from ctx.Conn directly;
//   - registers a no-op hijack handler;
//   - calls HijackSetNoResponse(true).
//
// The effect is that once the Fiber handler returns, fasthttp does not write
// its own HTTP response onto what is now a WebSocket stream. It also stops
// serving requests on the connection and closes it. The caller must therefore
// finish with the connection before its handler returns.
//
// NOTE(review): read/write deadlines fasthttp may already have set on the
// connection (ReadTimeout/IdleTimeout/WriteTimeout server settings) still
// apply after the upgrade. Confirm these are unset for this server, or clear
// them on the upgraded conn.
func (rw *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if rw.hijacked {
		return nil, nil, http.ErrHijacked
	}
	rw.hijacked = true
	rw.ctx.HijackSetNoResponse(true)
	rw.ctx.Hijack(func(net.Conn) {})
	conn := rw.ctx.Conn()
	return conn, bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), nil
}
// WebSocketHandler returns a Fiber handler that upgrades the request to a
// WebSocket connection using gorilla/websocket. It registers the connection
// with the hub under the numeric ":user_id" route parameter, so broadcasts
// whose "recipient_id" equals that ID are delivered to it.
//
// A non-numeric user_id is answered with 400 "Invalid user ID". After a
// successful upgrade the handler blocks, reading and discarding client
// messages, until the first read error. It then unregisters the client.
// h.Run must be running, otherwise the Register/Unregister sends block.
//
// NOTE(review): user_id comes straight from the URL with no authentication
// visible here, and Upgrader accepts any Origin. Confirm that middleware in
// front of this route verifies the caller owns user_id.
func (h *NotificationHub) WebSocketHandler() fiber.Handler {
	return func(c *fiber.Ctx) error {
		userID, err := strconv.ParseInt(c.Params("user_id"), 10, 64)
		if err != nil {
			return c.Status(fiber.StatusBadRequest).SendString("Invalid user ID")
		}

		// Bridge fasthttp to the net/http types gorilla/websocket expects.
		rw := newResponseWriterWrapper(c.Context())
		stdReq := new(http.Request)
		if err := fasthttpadaptor.ConvertRequest(c.Context(), stdReq, true); err != nil {
			return err
		}

		conn, err := Upgrader.Upgrade(rw, stdReq, nil)
		if err != nil {
			if rw.wrote {
				// Upgrade already replied through rw (e.g. 400 for a plain HTTP
				// request). Pass that status on, so Fiber's error handler does not
				// turn every rejected handshake into a 500.
				return fiber.NewError(c.Response().StatusCode(), err.Error())
			}
			return err
		}

		client := &Client{Conn: conn, RecipientID: userID}
		h.Register <- client
		defer func() { h.Unregister <- client }()

		// Block until the first read error, which is treated as the client
		// having gone away.
		for {
			if _, _, err := conn.ReadMessage(); err != nil {
				return nil
			}
		}
	}
}
// BroadcastWalletUpdate queues evt for delivery to every client registered
// under userID.
//
// The queued payload carries these keys:
//   - type, wallet_id, user_id, balance and trigger, copied from evt;
//   - recipient_id, set to userID and used for routing.
//
// The send blocks while the Broadcast channel has no free capacity.
func (h *NotificationHub) BroadcastWalletUpdate(userID int64, evt event.WalletEvent) {
	msg := make(map[string]interface{}, 6)
	msg["recipient_id"] = userID
	msg["type"] = evt.EventType
	msg["wallet_id"] = evt.WalletID
	msg["user_id"] = evt.UserID
	msg["balance"] = evt.Balance
	msg["trigger"] = evt.Trigger
	h.Broadcast <- msg
}