Yimaru-BackEnd/internal/web_server/middleware.go
Yared Yemane 79fb95ce36 Add category-based subscription controls for LMS and exam prep.
Introduce plan and content categories across programs and exam-prep catalog roots, wire category-aware checkout and access checks, and keep learner gating temporarily bypassed until data migration is ready.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 06:20:49 -07:00

615 lines
20 KiB
Go

package httpserver
import (
"Yimaru-Backend/internal/domain"
examprepsvc "Yimaru-Backend/internal/services/examprep"
jwtutil "Yimaru-Backend/internal/web_server/jwt"
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"go.uber.org/zap"
)
// categorySubscriptionGateDisabled is a temporary kill switch for the
// category-aware learner checks in RequireSubscriptionCategory and
// RequireExamPrepSubscription. While it is true those middlewares let learners
// through without looking up their category subscriptions; it is meant to be
// flipped once the category data migration is ready.
var (
	categorySubscriptionGateDisabled = true
)
func (a *App) authMiddleware(c *fiber.Ctx) error {
ip := c.IP()
userAgent := c.Get("User-Agent")
c.Locals("ip_address", ip)
c.Locals("user_agent", userAgent)
// Get Authorization header (case-insensitive)
authHeader := strings.TrimSpace(c.Get("Authorization"))
if authHeader == "" {
// Try lowercase as fallback
authHeader = strings.TrimSpace(c.Get("authorization"))
}
fmt.Println("--------------------------------")
fmt.Println("All Headers:")
allHeaders := c.GetReqHeaders()
for key, value := range allHeaders {
fmt.Printf(" %s: %s\n", key, value)
}
fmt.Println("userAgent", userAgent)
fmt.Println("ip", ip)
fmt.Println("authHeader", authHeader)
fmt.Println("--------------------------------")
if authHeader == "" {
a.mongoLoggerSvc.Info("Authorization header missing",
zap.Int("status_code", fiber.StatusUnauthorized),
zap.String("ip_address", ip),
zap.String("user_agent", userAgent),
zap.Time("timestamp", time.Now()),
)
return fiber.NewError(fiber.StatusUnauthorized, "Authorization header missing")
}
if !strings.HasPrefix(authHeader, "Bearer ") {
a.mongoLoggerSvc.Info("Invalid authorization header format",
zap.String("authHeader", authHeader),
zap.Int("status_code", fiber.StatusUnauthorized),
zap.String("ip_address", ip),
zap.String("user_agent", userAgent),
zap.Time("timestamp", time.Now()),
)
return fiber.NewError(fiber.StatusUnauthorized, "Invalid authorization header format")
}
accessToken := strings.TrimPrefix(authHeader, "Bearer ")
c.Locals("access_token", accessToken)
claim, err := jwtutil.ParseJwt(accessToken, a.JwtConfig.JwtAccessKey)
if err != nil {
if errors.Is(err, jwtutil.ErrExpiredToken) {
a.mongoLoggerSvc.Info("Access Token Expired",
zap.Int("status_code", fiber.StatusUnauthorized),
zap.String("ip_address", ip),
zap.String("user_agent", userAgent),
zap.Time("timestamp", time.Now()),
)
return fiber.NewError(fiber.StatusUnauthorized, "Access token expired")
}
a.mongoLoggerSvc.Info("Invalid Access Token",
zap.Int("status_code", fiber.StatusUnauthorized),
zap.String("ip_address", ip),
zap.String("user_agent", userAgent),
zap.Time("timestamp", time.Now()),
zap.Error(err),
)
return fiber.NewError(fiber.StatusUnauthorized, "Invalid access token")
}
refreshToken := c.Get("Refresh-Token")
if refreshToken == "" {
// refreshToken = c.Cookies("refresh_token", "")
// return fiber.NewError(fiber.StatusUnauthorized, "Refresh token missing")
}
// Asserting to make sure that only the super admin can have a nil company ID
// if claim.Role != domain.RoleSuperAdmin && !claim.CompanyID.Valid {
// a.mongoLoggerSvc.Error("Company Role without Company ID",
// zap.Int64("userID", claim.UserId),
// zap.Int("status_code", fiber.StatusInternalServerError),
// zap.Error(err),
// zap.Time("timestamp", time.Now()),
// )
// return fiber.NewError(fiber.StatusInternalServerError, "Company Role without Company ID")
// }
c.Locals("user_id", claim.UserId)
c.Locals("role", claim.Role)
// c.Locals("company_id", domain.ValidInt64{
// Value: claim.CompanyID.Value,
// Valid: claim.CompanyID.Valid,
// })
c.Locals("refresh_token", refreshToken)
// var branchID domain.ValidInt64
if claim.Role == domain.RoleAdmin {
// branch, err := a.branchSvc.GetBranchByCashier(c.Context(), claim.UserId)
// if err != nil {
// a.mongoLoggerSvc.Error("Failed to get branch id for cashier",
// zap.Int64("userID", claim.UserId),
// zap.Int("status_code", fiber.StatusInternalServerError),
// zap.Error(err),
// zap.Time("timestamp", time.Now()),
// )
// return fiber.NewError(fiber.StatusInternalServerError, "Failed to branch id for cashier")
// }
// branchID = domain.ValidInt64{
// Value: branch.ID,
// Valid: true,
// }
}
// c.Locals("branch_id", branchID)
return c.Next()
}
// SuperAdminOnly restricts a route to callers whose role is
// domain.RoleSuperAdmin. It must run after authMiddleware, which populates the
// "user_id" and "role" locals. A missing role is answered with 403 instead of
// panicking on an unchecked type assertion.
func (a *App) SuperAdminOnly(c *fiber.Ctx) error {
	userRole, ok := c.Locals("role").(domain.Role)
	if !ok {
		return fiber.NewError(fiber.StatusForbidden, "Role not found in context")
	}
	if userRole != domain.RoleSuperAdmin {
		// user_id is only used for the audit entry; zero if absent.
		userID, _ := c.Locals("user_id").(int64)
		a.mongoLoggerSvc.Warn("Attempt to access restricted SuperAdminOnly route",
			zap.Int64("userID", userID),
			zap.String("role", string(userRole)),
			zap.Int("status_code", fiber.StatusForbidden),
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusForbidden, "This route is restricted")
	}
	return c.Next()
}
// func (a *App) CompanyOnly(c *fiber.Ctx) error {
// userID := c.Locals("user_id").(int64)
// userRole := c.Locals("role").(domain.Role)
// if userRole == domain.RoleStudent {
// a.mongoLoggerSvc.Warn("Attempt to access restricted CompanyOnly route",
// zap.Int64("userID", userID),
// zap.String("role", string(userRole)),
// zap.Int("status_code", fiber.StatusForbidden),
// zap.Time("timestamp", time.Now()),
// )
// return fiber.NewError(fiber.StatusForbidden, "This route is restricted")
// }
// return c.Next()
// }
// OnlyAdminAndAbove restricts a route to super admins and admins. It must run
// after authMiddleware, which populates the "user_id" and "role" locals. A
// missing role is answered with 403 instead of panicking on an unchecked type
// assertion.
func (a *App) OnlyAdminAndAbove(c *fiber.Ctx) error {
	userRole, ok := c.Locals("role").(domain.Role)
	if !ok {
		return fiber.NewError(fiber.StatusForbidden, "Role not found in context")
	}
	if userRole != domain.RoleSuperAdmin && userRole != domain.RoleAdmin {
		// user_id is only used for the audit entry; zero if absent.
		userID, _ := c.Locals("user_id").(int64)
		a.mongoLoggerSvc.Warn("Attempt to access restricted OnlyAdminAndAbove route",
			zap.Int64("userID", userID),
			zap.String("role", string(userRole)),
			zap.Int("status_code", fiber.StatusForbidden),
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusForbidden, "This route is restricted")
	}
	return c.Next()
}
// activeSubscriptionGateDisabled temporarily turns off the learner gate in
// RequireActiveSubscription. While true, learners are let through without
// querying their subscription status, so a failing subscription lookup cannot
// block access that would have been granted anyway. Set it to false to
// enforce the 403 again.
var activeSubscriptionGateDisabled = true

// RequireActiveSubscription enforces an active subscription for learner accounts.
// Only learner roles (student, open learner) are gated. Staff roles
// (SUPER_ADMIN, ADMIN, INSTRUCTOR, SUPPORT) and any other role pass through.
// Use after authMiddleware on routes that deliver paid learning content.
//
// Responses: 403 if no role is in context, 401 if a learner has no user ID,
// 500 if the subscription lookup fails, and 403 if the learner has no active
// subscription (the last two only while the gate is enabled).
func (a *App) RequireActiveSubscription() fiber.Handler {
	return func(c *fiber.Ctx) error {
		role, ok := c.Locals("role").(domain.Role)
		if !ok {
			return fiber.NewError(fiber.StatusForbidden, "Role not found in context")
		}
		if role != domain.RoleStudent && role != domain.RoleOpenLearner {
			return c.Next()
		}

		userID, ok := c.Locals("user_id").(int64)
		if !ok || userID == 0 {
			return fiber.NewError(fiber.StatusUnauthorized, "Unauthorized")
		}

		if activeSubscriptionGateDisabled {
			// Temporary bypass: skip the lookup entirely while gating is off.
			return c.Next()
		}

		active, err := a.subscriptionsSvc.HasActiveSubscription(c.Context(), userID)
		if err != nil {
			a.mongoLoggerSvc.Error("subscription check failed",
				zap.Int64("userID", userID),
				zap.String("path", c.Path()),
				zap.Error(err),
				zap.Time("timestamp", time.Now()),
			)
			return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription")
		}
		if !active {
			return fiber.NewError(fiber.StatusForbidden, "An active subscription is required")
		}
		return c.Next()
	}
}
// RequireSubscriptionCategory builds middleware that admits a learner (student
// or open learner) only when they hold an active subscription in category.
//
// The caller's role and user ID are always read first, and a failure there is
// returned as-is. After that:
//   - staff roles and every non-learner role pass through;
//   - while categorySubscriptionGateDisabled is set, learners also pass;
//   - otherwise a failed lookup yields 500, and a missing subscription yields
//     403 naming the category.
func (a *App) RequireSubscriptionCategory(category domain.SubscriptionCategory) fiber.Handler {
	return func(c *fiber.Ctx) error {
		role, userID, scopeErr := subscriptionScopedUser(c)
		if scopeErr != nil {
			return scopeErr
		}

		isLearner := role == domain.RoleStudent || role == domain.RoleOpenLearner
		switch {
		case bypassSubscriptionForRole(role), !isLearner:
			return c.Next()
		case categorySubscriptionGateDisabled:
			// Temporary: category gating is switched off without touching route wiring.
			return c.Next()
		}

		hasSubscription, lookupErr := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, category)
		if lookupErr != nil {
			a.mongoLoggerSvc.Error("category subscription check failed",
				zap.Int64("userID", userID),
				zap.String("category", string(category)),
				zap.String("path", c.Path()),
				zap.Error(lookupErr),
				zap.Time("timestamp", time.Now()),
			)
			return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription")
		}
		if hasSubscription {
			return c.Next()
		}
		return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("An active %s subscription is required", humanizeSubscriptionCategory(category)))
	}
}
func (a *App) RequireExamPrepSubscription() fiber.Handler {
return func(c *fiber.Ctx) error {
role, userID, err := subscriptionScopedUser(c)
if err != nil {
return err
}
if bypassSubscriptionForRole(role) {
return c.Next()
}
if role != domain.RoleStudent && role != domain.RoleOpenLearner {
return c.Next()
}
if categorySubscriptionGateDisabled {
// Temporary bypass to disable category-aware learner access checks without changing route wiring.
return c.Next()
}
category, scoped, err := a.resolveExamPrepSubscriptionCategory(c)
if err != nil {
switch {
case errors.Is(err, examprepsvc.ErrCatalogCourseNotFound),
errors.Is(err, examprepsvc.ErrUnitNotFound),
errors.Is(err, examprepsvc.ErrModuleNotFound),
errors.Is(err, examprepsvc.ErrLessonNotFound),
errors.Is(err, examprepsvc.ErrPracticeNotFound):
return fiber.NewError(fiber.StatusNotFound, err.Error())
default:
a.mongoLoggerSvc.Error("exam prep category resolution failed",
zap.Int64("userID", userID),
zap.String("path", c.Path()),
zap.Error(err),
zap.Time("timestamp", time.Now()),
)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription")
}
}
if !scoped {
hasIELTS, err := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, domain.SubscriptionCategoryIELTS)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription")
}
hasDuolingo, err := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, domain.SubscriptionCategoryDuolingo)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription")
}
if !hasIELTS && !hasDuolingo {
return fiber.NewError(fiber.StatusForbidden, "An active IELTS or Duolingo subscription is required")
}
return c.Next()
}
active, err := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, category)
if err != nil {
a.mongoLoggerSvc.Error("exam prep subscription check failed",
zap.Int64("userID", userID),
zap.String("category", string(category)),
zap.String("path", c.Path()),
zap.Error(err),
zap.Time("timestamp", time.Now()),
)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription")
}
if !active {
return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("An active %s subscription is required", humanizeSubscriptionCategory(category)))
}
return c.Next()
}
}
// subscriptionScopedUser reads the caller's role and user ID from the locals
// written by authMiddleware. If no role is present it returns 403. If the user
// ID is missing or zero it returns 401, still alongside the role that was found.
func subscriptionScopedUser(c *fiber.Ctx) (domain.Role, int64, error) {
	role, hasRole := c.Locals("role").(domain.Role)
	if !hasRole {
		return "", 0, fiber.NewError(fiber.StatusForbidden, "Role not found in context")
	}
	if userID, hasUser := c.Locals("user_id").(int64); hasUser && userID != 0 {
		return role, userID, nil
	}
	return role, 0, fiber.NewError(fiber.StatusUnauthorized, "Unauthorized")
}
// bypassSubscriptionForRole reports whether role is one of the staff roles
// (super admin, admin, instructor, support) that are exempt from subscription
// checks.
func bypassSubscriptionForRole(role domain.Role) bool {
	return role == domain.RoleSuperAdmin ||
		role == domain.RoleAdmin ||
		role == domain.RoleInstructor ||
		role == domain.RoleSupport
}
// humanizeSubscriptionCategory turns a category identifier into lowercase text
// for user-facing messages, with underscores replaced by spaces
// (e.g. "SOME_CATEGORY" -> "some category").
func humanizeSubscriptionCategory(category domain.SubscriptionCategory) string {
	// Lowercasing never creates or removes '_' or ' ', so the order of the two
	// steps does not affect the result.
	lowered := strings.ToLower(string(category))
	return strings.ReplaceAll(lowered, "_", " ")
}
// parseRouteInt64 reads route parameter name as a base-10 int64.
//   - A missing or blank (whitespace-only) parameter returns found == false and
//     no error.
//   - A parameter that is present but not a valid integer returns a 400 Bad
//     Request error.
func parseRouteInt64(c *fiber.Ctx, name string) (int64, bool, error) {
	value := strings.TrimSpace(c.Params(name))
	if len(value) == 0 {
		return 0, false, nil
	}
	parsed, parseErr := strconv.ParseInt(value, 10, 64)
	if parseErr == nil {
		return parsed, true, nil
	}
	return 0, false, fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Invalid %s", name))
}
func (a *App) resolveExamPrepSubscriptionCategory(c *fiber.Ctx) (domain.SubscriptionCategory, bool, error) {
if catalogCourseID, ok, err := parseRouteInt64(c, "catalogCourseId"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByCatalogCourseID(c.Context(), catalogCourseID)
}
if unitID, ok, err := parseRouteInt64(c, "unitId"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByUnitID(c.Context(), unitID)
}
if moduleID, ok, err := parseRouteInt64(c, "moduleId"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByModuleID(c.Context(), moduleID)
}
if lessonID, ok, err := parseRouteInt64(c, "lessonId"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByLessonID(c.Context(), lessonID)
}
switch routePath := c.Route().Path; {
case strings.Contains(routePath, "/catalog-courses/:id"):
if id, ok, err := parseRouteInt64(c, "id"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByCatalogCourseID(c.Context(), id)
}
case strings.Contains(routePath, "/units/:id"):
if id, ok, err := parseRouteInt64(c, "id"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByUnitID(c.Context(), id)
}
case strings.Contains(routePath, "/modules/:id"):
if id, ok, err := parseRouteInt64(c, "id"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByModuleID(c.Context(), id)
}
case strings.Contains(routePath, "/lessons/:id"):
if id, ok, err := parseRouteInt64(c, "id"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByLessonID(c.Context(), id)
}
case strings.Contains(routePath, "/practices/:id"):
if id, ok, err := parseRouteInt64(c, "id"); err != nil {
return "", false, err
} else if ok {
return a.examPrepCategoryByPracticeID(c.Context(), id)
}
}
return "", false, nil
}
// examPrepCategoryByCatalogCourseID loads an exam-prep catalog course and
// returns its Category converted to a subscription category. The boolean is
// true whenever the lookup succeeds; lookup errors are returned unchanged.
func (a *App) examPrepCategoryByCatalogCourseID(ctx context.Context, catalogCourseID int64) (domain.SubscriptionCategory, bool, error) {
	course, err := a.examPrepSvc.GetCatalogCourseByID(ctx, catalogCourseID)
	if err == nil {
		return domain.SubscriptionCategory(course.Category), true, nil
	}
	return "", false, err
}
// examPrepCategoryByUnitID loads a unit and resolves the category of its
// parent catalog course. Lookup errors are returned unchanged.
func (a *App) examPrepCategoryByUnitID(ctx context.Context, unitID int64) (domain.SubscriptionCategory, bool, error) {
	unit, err := a.examPrepSvc.GetUnitByID(ctx, unitID)
	if err != nil {
		return "", false, err
	}
	parentCourseID := unit.CatalogCourseID
	return a.examPrepCategoryByCatalogCourseID(ctx, parentCourseID)
}
// examPrepCategoryByModuleID loads a module and resolves the category through
// its parent unit. Lookup errors are returned unchanged.
func (a *App) examPrepCategoryByModuleID(ctx context.Context, moduleID int64) (domain.SubscriptionCategory, bool, error) {
	module, err := a.examPrepSvc.GetModuleByID(ctx, moduleID)
	if err != nil {
		return "", false, err
	}
	parentUnitID := module.UnitID
	return a.examPrepCategoryByUnitID(ctx, parentUnitID)
}
// examPrepCategoryByLessonID loads a lesson and resolves the category through
// its parent unit module. Lookup errors are returned unchanged.
func (a *App) examPrepCategoryByLessonID(ctx context.Context, lessonID int64) (domain.SubscriptionCategory, bool, error) {
	lesson, err := a.examPrepSvc.GetLessonByID(ctx, lessonID)
	if err != nil {
		return "", false, err
	}
	parentModuleID := lesson.UnitModuleID
	return a.examPrepCategoryByModuleID(ctx, parentModuleID)
}
// examPrepCategoryByPracticeID loads an exam-prep practice and resolves the
// category through its parent lesson. Lookup errors are returned unchanged.
func (a *App) examPrepCategoryByPracticeID(ctx context.Context, practiceID int64) (domain.SubscriptionCategory, bool, error) {
	practice, err := a.examPrepSvc.GetExamPrepPracticeByID(ctx, practiceID)
	if err != nil {
		return "", false, err
	}
	parentLessonID := practice.LessonID
	return a.examPrepCategoryByLessonID(ctx, parentLessonID)
}
// RequirePermission builds middleware that admits the request only when the
// caller's role holds permission permKey according to the RBAC service.
// It returns 403 when the role is missing from context or lacks the
// permission; denials are logged as warnings.
func (a *App) RequirePermission(permKey string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		userRole, hasRole := c.Locals("role").(domain.Role)
		if !hasRole {
			return fiber.NewError(fiber.StatusForbidden, "Role not found in context")
		}
		if a.rbacSvc.HasPermission(string(userRole), permKey) {
			return c.Next()
		}

		// user_id is only used for the audit entry; zero if absent.
		userID, _ := c.Locals("user_id").(int64)
		a.mongoLoggerSvc.Warn("Permission denied",
			zap.Int64("userID", userID),
			zap.String("role", string(userRole)),
			zap.String("permission", permKey),
			zap.Int("status_code", fiber.StatusForbidden),
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusForbidden, "You don't have permission to access this resource")
	}
}
// OnlyBranchManagerAndAbove restricts a route to super admins and admins (the
// current role checks match OnlyAdminAndAbove). It must run after
// authMiddleware, which populates the "user_id" and "role" locals. A missing
// role is answered with 403 instead of panicking on an unchecked type
// assertion.
func (a *App) OnlyBranchManagerAndAbove(c *fiber.Ctx) error {
	userRole, ok := c.Locals("role").(domain.Role)
	if !ok {
		return fiber.NewError(fiber.StatusForbidden, "Role not found in context")
	}
	if userRole != domain.RoleSuperAdmin && userRole != domain.RoleAdmin {
		// user_id is only used for the audit entry; zero if absent.
		userID, _ := c.Locals("user_id").(int64)
		a.mongoLoggerSvc.Warn("Attempt to access restricted OnlyBranchManagerAndAbove route",
			zap.Int64("userID", userID),
			zap.String("role", string(userRole)),
			zap.Int("status_code", fiber.StatusForbidden),
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusForbidden, "This route is restricted")
	}
	return c.Next()
}
// WebsocketAuthMiddleware authenticates a WebSocket upgrade request using a
// JWT access token taken from the "token" query parameter.
// NOTE(review): the query parameter is presumably used because browser
// WebSocket clients cannot set an Authorization header — confirm.
//
// It returns 401 when the token is missing, expired, invalid, or carries a zero
// user ID. On success it stores the user ID under the legacy "userID" local
// (kept for existing consumers) and under "user_id", and the role under
// "role". The latter two are the keys the other middlewares read, so RBAC and
// subscription checks also work on WebSocket routes.
func (a *App) WebsocketAuthMiddleware(c *fiber.Ctx) error {
	tokenStr := c.Query("token")
	ip := c.IP()
	userAgent := c.Get("User-Agent")
	if tokenStr == "" {
		a.mongoLoggerSvc.Info("Missing token in query parameter",
			zap.Int("status_code", fiber.StatusUnauthorized),
			zap.String("ip_address", ip),
			zap.String("user_agent", userAgent),
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusUnauthorized, "Missing token")
	}

	claim, err := jwtutil.ParseJwt(tokenStr, a.JwtConfig.JwtAccessKey)
	if err != nil {
		if errors.Is(err, jwtutil.ErrExpiredToken) {
			a.mongoLoggerSvc.Info("Token expired",
				zap.Int("status_code", fiber.StatusUnauthorized),
				zap.String("ip_address", ip),
				zap.String("user_agent", userAgent),
				zap.Time("timestamp", time.Now()),
			)
			return fiber.NewError(fiber.StatusUnauthorized, "Token expired")
		}
		a.logger.Error("Invalid token", "error", err)
		a.mongoLoggerSvc.Info("Invalid token",
			zap.Int("status_code", fiber.StatusUnauthorized),
			zap.String("ip_address", ip),
			zap.String("user_agent", userAgent),
			zap.Time("timestamp", time.Now()),
			zap.Error(err),
		)
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid token")
	}

	userID := claim.UserId
	if userID == 0 {
		a.mongoLoggerSvc.Info("Invalid user ID in token claims",
			zap.Int("status_code", fiber.StatusUnauthorized),
			zap.String("ip_address", ip),
			zap.String("user_agent", userAgent),
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID")
	}

	c.Locals("userID", userID) // legacy key, kept for backward compatibility
	c.Locals("user_id", userID)
	c.Locals("role", claim.Role)

	a.mongoLoggerSvc.Info("Authenticated WebSocket connection",
		zap.Int64("userID", userID),
		zap.Time("timestamp", time.Now()),
	)
	return c.Next()
}
// TenantMiddleware requires a non-empty "tenant_slug" route parameter and
// responds with 400 Bad Request when it is blank. Resolving the slug to a
// company is currently disabled (see the commented-out lookup below), so no
// tenant or company locals are populated.
func (a *App) TenantMiddleware(c *fiber.Ctx) error {
	tenantSlug := c.Params("tenant_slug")
	if len(tenantSlug) == 0 {
		a.mongoLoggerSvc.Info("blank tenant param",
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusBadRequest, "tenant is required for this route")
	}

	// Disabled tenant resolution (look up the company, reject inactive ones,
	// store its ID in locals):
	// company, err := a.companySvc.GetCompanyBySlug(c.Context(), tenantSlug)
	// if err != nil {
	// a.mongoLoggerSvc.Info("failed to resolve tenant",
	// zap.String("tenant_slug", tenantSlug),
	// zap.Time("timestamp", time.Now()),
	// )
	// return fiber.NewError(fiber.StatusBadRequest, "failed to resolve tenant")
	// }
	// if !company.IsActive {
	// a.mongoLoggerSvc.Info("request using deactivated tenant",
	// zap.String("tenant_slug", tenantSlug),
	// zap.Time("timestamp", time.Now()),
	// )
	// return fiber.NewError(fiber.StatusForbidden, "this tenant has been deactivated")
	// }
	// c.Locals("company_id", domain.ValidInt64{
	// Value: company.ID,
	// Valid: true,
	// })
	return c.Next()
}
// TenantAuthMiddleware checks that the tenant addressed by the request (the
// "tenant_id" local) matches the company ID carried by the caller's token (the
// "company_id" local).
//
// Responses:
//   - 400 if the tenant ID is missing or invalid;
//   - 500 if the token's company ID is missing (treated as a server-side
//     invariant violation);
//   - 403 if the two IDs differ, since the caller is authenticated but is not
//     allowed to act on another tenant.
func (a *App) TenantAuthMiddleware(c *fiber.Ctx) error {
	slugID, ok := c.Locals("tenant_id").(domain.ValidInt64)
	if !ok || !slugID.Valid {
		a.mongoLoggerSvc.Info("invalid tenant slug",
			zap.Time("timestamp", time.Now()),
		)
		return fiber.NewError(fiber.StatusBadRequest, "invalid tenant slug")
	}

	tokenCID, ok := c.Locals("company_id").(domain.ValidInt64)
	if !ok || !tokenCID.Valid {
		a.mongoLoggerSvc.Error("invalid company id in token",
			zap.Time("timestamp", time.Now()),
			zap.Bool("tokenCID Valid", tokenCID.Valid),
			zap.Bool("ValidInt64 Type Check", ok),
		)
		return fiber.NewError(fiber.StatusInternalServerError, "invalid company id in token")
	}

	if slugID.Value != tokenCID.Value {
		a.mongoLoggerSvc.Error("token company-id doesn't match the slug company_id",
			zap.Time("timestamp", time.Now()),
		)
		// A cross-tenant request is an authorization failure, not a server error.
		return fiber.NewError(fiber.StatusForbidden, "invalid company_id")
	}

	return c.Next()
}