// Yimaru-BackEnd/internal/web_server/middleware.go
//
// 117 lines
// 3.5 KiB
// Go

package httpserver
import (
"errors"
"fmt"
"strings"
"github.com/SamuelTariku/FortuneBet-Backend/internal/domain"
jwtutil "github.com/SamuelTariku/FortuneBet-Backend/internal/web_server/jwt"
"github.com/gofiber/fiber/v2"
)
// authMiddleware authenticates a request from its "Authorization: Bearer <token>"
// header and publishes the caller's identity in the request locals.
//
// On success it sets these locals and calls the next handler:
//   - "access_token":  the raw bearer token
//   - "user_id", "role", "company_id": copied from the parsed JWT claims
//   - "refresh_token": the "Refresh-Token" header value ("" when absent)
//   - "branch_id":     a domain.ValidInt64, Valid only for cashiers
//
// It responds 401 when the header is missing or malformed, or when the token
// fails to parse (with a distinct message for expired tokens). It responds 500
// when the claims violate the company-ID invariant or when a cashier's branch
// cannot be resolved.
func (a *App) authMiddleware(c *fiber.Ctx) error {
	// NOTE(review): the fmt.Println calls below bypass a.logger. They are kept
	// because fmt may have no other users in this file; migrate them together
	// with the import.
	authHeader := c.Get("Authorization")
	if authHeader == "" {
		fmt.Println("Auth Header Missing")
		return fiber.NewError(fiber.StatusUnauthorized, "Authorization header missing")
	}

	// Auth schemes are case-insensitive (RFC 7235 §2.1), so "bearer " and
	// "BEARER " are accepted alongside the canonical "Bearer ".
	const bearerPrefix = "Bearer "
	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		fmt.Println("Invalid authorization header format")
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid authorization header format")
	}
	accessToken := strings.TrimSpace(authHeader[len(bearerPrefix):])
	if accessToken == "" {
		// "Bearer " with nothing after it: reject before attempting to parse.
		fmt.Println("Invalid authorization header format")
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid authorization header format")
	}
	c.Locals("access_token", accessToken)

	claim, err := jwtutil.ParseJwt(accessToken, a.JwtConfig.JwtAccessKey)
	if err != nil {
		if errors.Is(err, jwtutil.ErrExpiredToken) {
			fmt.Println("Token Expired")
			return fiber.NewError(fiber.StatusUnauthorized, "Access token expired")
		}
		fmt.Println("Invalid Token")
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid access token")
	}

	// The refresh token is optional here. It is only forwarded via locals and
	// never validated by this middleware.
	refreshToken := c.Get("Refresh-Token")

	// Roles other than super admin and customer are company roles. A token
	// carrying such a role without a company ID is treated as a server-side
	// invariant violation (500) rather than a client error.
	if claim.Role != domain.RoleSuperAdmin && claim.Role != domain.RoleCustomer && !claim.CompanyID.Valid {
		fmt.Println("Company Role without Company ID")
		return fiber.NewError(fiber.StatusInternalServerError, "Company Role without Company ID")
	}

	c.Locals("user_id", claim.UserId)
	c.Locals("role", claim.Role)
	c.Locals("company_id", claim.CompanyID)
	c.Locals("refresh_token", refreshToken)

	// Only cashiers get a resolved branch. For every other role "branch_id"
	// stays the zero value (Valid == false).
	var branchID domain.ValidInt64
	if claim.Role == domain.RoleCashier {
		branch, branchErr := a.branchSvc.GetBranchByCashier(c.Context(), claim.UserId)
		if branchErr != nil {
			a.logger.Error("Failed to get branch id for bet", "user_id", claim.UserId, "error", branchErr)
			return fiber.NewError(fiber.StatusInternalServerError, "Failed to get branch id for bet")
		}
		branchID = domain.ValidInt64{
			Value: branch.ID,
			Valid: true,
		}
	}
	c.Locals("branch_id", branchID)

	return c.Next()
}
// SuperAdminOnly lets the request through only when the "role" local (set by
// authMiddleware) is domain.RoleSuperAdmin; otherwise it responds 401.
//
// A missing or non-domain.Role "role" local, such as on a route registered
// without authMiddleware, is rejected instead of panicking on the type
// assertion.
//
// NOTE(review): 403 Forbidden is arguably the correct status for an
// authenticated user lacking the role. 401 is kept for client compatibility.
func (a *App) SuperAdminOnly(c *fiber.Ctx) error {
	userRole, ok := c.Locals("role").(domain.Role)
	if !ok || userRole != domain.RoleSuperAdmin {
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid access token")
	}
	return c.Next()
}
// CompanyOnly rejects customers with 401 and lets every other role through.
// The role comes from the "role" local set by authMiddleware.
//
// A missing or non-domain.Role "role" local is rejected explicitly. With a
// plain comma-ok assertion, the zero-value role would not equal RoleCustomer
// and the check would fail open.
func (a *App) CompanyOnly(c *fiber.Ctx) error {
	userRole, ok := c.Locals("role").(domain.Role)
	if !ok || userRole == domain.RoleCustomer {
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid access token")
	}
	return c.Next()
}
// WebsocketAuthMiddleware authenticates a WebSocket request using the access
// token passed in the "token" query parameter.
//
// The token is verified with the same access key as authMiddleware. Claims
// with a zero user ID are rejected. On success the user ID is stored under the
// "userID" local (note: authMiddleware uses "user_id") before calling the next
// handler. Every failure responds 401.
func (a *App) WebsocketAuthMiddleware(c *fiber.Ctx) error {
	token := c.Query("token")
	if token == "" {
		a.logger.Error("Missing token in query parameter")
		return fiber.NewError(fiber.StatusUnauthorized, "Missing token")
	}

	claim, parseErr := jwtutil.ParseJwt(token, a.JwtConfig.JwtAccessKey)
	switch {
	case parseErr == nil:
		// Token is valid; continue with claim checks below.
	case errors.Is(parseErr, jwtutil.ErrExpiredToken):
		a.logger.Error("Token expired")
		return fiber.NewError(fiber.StatusUnauthorized, "Token expired")
	default:
		a.logger.Error("Invalid token", "error", parseErr)
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid token")
	}

	// A zero user ID means the claims do not identify anyone.
	if claim.UserId == 0 {
		a.logger.Error("Invalid user ID in token claims")
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID")
	}

	c.Locals("userID", claim.UserId)
	a.logger.Info("Authenticated WebSocket connection", "userID", claim.UserId)
	return c.Next()
}