package httpserver import ( "Yimaru-Backend/internal/domain" examprepsvc "Yimaru-Backend/internal/services/examprep" jwtutil "Yimaru-Backend/internal/web_server/jwt" "context" "errors" "fmt" "strconv" "strings" "time" "github.com/gofiber/fiber/v2" "go.uber.org/zap" ) var categorySubscriptionGateDisabled = true func (a *App) authMiddleware(c *fiber.Ctx) error { ip := c.IP() userAgent := c.Get("User-Agent") c.Locals("ip_address", ip) c.Locals("user_agent", userAgent) // Get Authorization header (case-insensitive) authHeader := strings.TrimSpace(c.Get("Authorization")) if authHeader == "" { // Try lowercase as fallback authHeader = strings.TrimSpace(c.Get("authorization")) } fmt.Println("--------------------------------") fmt.Println("All Headers:") allHeaders := c.GetReqHeaders() for key, value := range allHeaders { fmt.Printf(" %s: %s\n", key, value) } fmt.Println("userAgent", userAgent) fmt.Println("ip", ip) fmt.Println("authHeader", authHeader) fmt.Println("--------------------------------") if authHeader == "" { a.mongoLoggerSvc.Info("Authorization header missing", zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusUnauthorized, "Authorization header missing") } if !strings.HasPrefix(authHeader, "Bearer ") { a.mongoLoggerSvc.Info("Invalid authorization header format", zap.String("authHeader", authHeader), zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusUnauthorized, "Invalid authorization header format") } accessToken := strings.TrimPrefix(authHeader, "Bearer ") c.Locals("access_token", accessToken) claim, err := jwtutil.ParseJwt(accessToken, a.JwtConfig.JwtAccessKey) if err != nil { if errors.Is(err, jwtutil.ErrExpiredToken) { a.mongoLoggerSvc.Info("Access Token Expired", zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusUnauthorized, "Access token expired") } a.mongoLoggerSvc.Info("Invalid Access Token", zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), zap.Error(err), ) return fiber.NewError(fiber.StatusUnauthorized, "Invalid access token") } refreshToken := c.Get("Refresh-Token") if refreshToken == "" { // refreshToken = c.Cookies("refresh_token", "") // return fiber.NewError(fiber.StatusUnauthorized, "Refresh token missing") } // Asserting to make sure that only the super admin can have a nil company ID // if claim.Role != domain.RoleSuperAdmin && !claim.CompanyID.Valid { // a.mongoLoggerSvc.Error("Company Role without Company ID", // zap.Int64("userID", claim.UserId), // zap.Int("status_code", fiber.StatusInternalServerError), // zap.Error(err), // zap.Time("timestamp", time.Now()), // ) // return fiber.NewError(fiber.StatusInternalServerError, "Company Role without Company ID") // } c.Locals("user_id", claim.UserId) c.Locals("role", claim.Role) // c.Locals("company_id", domain.ValidInt64{ // Value: claim.CompanyID.Value, // Valid: claim.CompanyID.Valid, // }) c.Locals("refresh_token", refreshToken) // var branchID domain.ValidInt64 if claim.Role == domain.RoleAdmin { // branch, err := a.branchSvc.GetBranchByCashier(c.Context(), claim.UserId) // if err != nil { // a.mongoLoggerSvc.Error("Failed to get branch id for cashier", // zap.Int64("userID", claim.UserId), // zap.Int("status_code", fiber.StatusInternalServerError), // zap.Error(err), // zap.Time("timestamp", time.Now()), // ) // return fiber.NewError(fiber.StatusInternalServerError, "Failed to branch id for cashier") // } // branchID = domain.ValidInt64{ // Value: branch.ID, // Valid: true, // } } // c.Locals("branch_id", branchID) return c.Next() } func (a *App) SuperAdminOnly(c *fiber.Ctx) error { userID := c.Locals("user_id").(int64) userRole := c.Locals("role").(domain.Role) if userRole != domain.RoleSuperAdmin { a.mongoLoggerSvc.Warn("Attempt to access restricted SuperAdminOnly route", zap.Int64("userID", userID), zap.String("role", string(userRole)), zap.Int("status_code", fiber.StatusForbidden), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusForbidden, "This route is restricted") } return c.Next() } // func (a *App) CompanyOnly(c *fiber.Ctx) error { // userID := c.Locals("user_id").(int64) // userRole := c.Locals("role").(domain.Role) // if userRole == domain.RoleStudent { // a.mongoLoggerSvc.Warn("Attempt to access restricted CompanyOnly route", // zap.Int64("userID", userID), // zap.String("role", string(userRole)), // zap.Int("status_code", fiber.StatusForbidden), // zap.Time("timestamp", time.Now()), // ) // return fiber.NewError(fiber.StatusForbidden, "This route is restricted") // } // return c.Next() // } func (a *App) OnlyAdminAndAbove(c *fiber.Ctx) error { userID := c.Locals("user_id").(int64) userRole := c.Locals("role").(domain.Role) if userRole != domain.RoleSuperAdmin && userRole != domain.RoleAdmin { a.mongoLoggerSvc.Warn("Attempt to access restricted OnlyAdminAndAbove route", zap.Int64("userID", userID), zap.String("role", string(userRole)), zap.Int("status_code", fiber.StatusForbidden), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusForbidden, "This route is restricted") } return c.Next() } // RequireActiveSubscription enforces an active subscription for learner accounts. // Staff roles (SUPER_ADMIN, ADMIN, INSTRUCTOR, SUPPORT) bypass this check. // Use after authMiddleware on routes that deliver paid learning content. func (a *App) RequireActiveSubscription() fiber.Handler { return func(c *fiber.Ctx) error { role, ok := c.Locals("role").(domain.Role) if !ok { return fiber.NewError(fiber.StatusForbidden, "Role not found in context") } switch role { case domain.RoleSuperAdmin, domain.RoleAdmin, domain.RoleInstructor, domain.RoleSupport: return c.Next() case domain.RoleStudent, domain.RoleOpenLearner: userID, ok := c.Locals("user_id").(int64) if !ok || userID == 0 { return fiber.NewError(fiber.StatusUnauthorized, "Unauthorized") } active, err := a.subscriptionsSvc.HasActiveSubscription(c.Context(), userID) if err != nil { a.mongoLoggerSvc.Error("subscription check failed", zap.Int64("userID", userID), zap.String("path", c.Path()), zap.Error(err), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription") } if !active { // Temporary bypass: allow unsubscribed learners to access content. // Re-enable the previous 403 response when subscription gating is turned back on. return c.Next() } return c.Next() default: return c.Next() } } } func (a *App) RequireSubscriptionCategory(category domain.SubscriptionCategory) fiber.Handler { return func(c *fiber.Ctx) error { role, userID, err := subscriptionScopedUser(c) if err != nil { return err } if bypassSubscriptionForRole(role) { return c.Next() } if role != domain.RoleStudent && role != domain.RoleOpenLearner { return c.Next() } if categorySubscriptionGateDisabled { // Temporary bypass to disable category-aware learner access checks without changing route wiring. return c.Next() } active, err := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, category) if err != nil { a.mongoLoggerSvc.Error("category subscription check failed", zap.Int64("userID", userID), zap.String("category", string(category)), zap.String("path", c.Path()), zap.Error(err), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription") } if !active { return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("An active %s subscription is required", humanizeSubscriptionCategory(category))) } return c.Next() } } func (a *App) RequireExamPrepSubscription() fiber.Handler { return func(c *fiber.Ctx) error { role, userID, err := subscriptionScopedUser(c) if err != nil { return err } if bypassSubscriptionForRole(role) { return c.Next() } if role != domain.RoleStudent && role != domain.RoleOpenLearner { return c.Next() } if categorySubscriptionGateDisabled { // Temporary bypass to disable category-aware learner access checks without changing route wiring. return c.Next() } category, scoped, err := a.resolveExamPrepSubscriptionCategory(c) if err != nil { switch { case errors.Is(err, examprepsvc.ErrCatalogCourseNotFound), errors.Is(err, examprepsvc.ErrUnitNotFound), errors.Is(err, examprepsvc.ErrModuleNotFound), errors.Is(err, examprepsvc.ErrLessonNotFound), errors.Is(err, examprepsvc.ErrPracticeNotFound): return fiber.NewError(fiber.StatusNotFound, err.Error()) default: a.mongoLoggerSvc.Error("exam prep category resolution failed", zap.Int64("userID", userID), zap.String("path", c.Path()), zap.Error(err), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription") } } if !scoped { hasIELTS, err := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, domain.SubscriptionCategoryIELTS) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription") } hasDuolingo, err := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, domain.SubscriptionCategoryDuolingo) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription") } if !hasIELTS && !hasDuolingo { return fiber.NewError(fiber.StatusForbidden, "An active IELTS or Duolingo subscription is required") } return c.Next() } active, err := a.subscriptionsSvc.HasActiveSubscriptionByCategory(c.Context(), userID, category) if err != nil { a.mongoLoggerSvc.Error("exam prep subscription check failed", zap.Int64("userID", userID), zap.String("category", string(category)), zap.String("path", c.Path()), zap.Error(err), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusInternalServerError, "Failed to verify subscription") } if !active { return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("An active %s subscription is required", humanizeSubscriptionCategory(category))) } return c.Next() } } func subscriptionScopedUser(c *fiber.Ctx) (domain.Role, int64, error) { role, ok := c.Locals("role").(domain.Role) if !ok { return "", 0, fiber.NewError(fiber.StatusForbidden, "Role not found in context") } userID, ok := c.Locals("user_id").(int64) if !ok || userID == 0 { return role, 0, fiber.NewError(fiber.StatusUnauthorized, "Unauthorized") } return role, userID, nil } func bypassSubscriptionForRole(role domain.Role) bool { switch role { case domain.RoleSuperAdmin, domain.RoleAdmin, domain.RoleInstructor, domain.RoleSupport: return true default: return false } } func humanizeSubscriptionCategory(category domain.SubscriptionCategory) string { return strings.ToLower(strings.ReplaceAll(string(category), "_", " ")) } func parseRouteInt64(c *fiber.Ctx, name string) (int64, bool, error) { raw := strings.TrimSpace(c.Params(name)) if raw == "" { return 0, false, nil } id, err := strconv.ParseInt(raw, 10, 64) if err != nil { return 0, false, fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Invalid %s", name)) } return id, true, nil } func (a *App) resolveExamPrepSubscriptionCategory(c *fiber.Ctx) (domain.SubscriptionCategory, bool, error) { if catalogCourseID, ok, err := parseRouteInt64(c, "catalogCourseId"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByCatalogCourseID(c.Context(), catalogCourseID) } if unitID, ok, err := parseRouteInt64(c, "unitId"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByUnitID(c.Context(), unitID) } if moduleID, ok, err := parseRouteInt64(c, "moduleId"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByModuleID(c.Context(), moduleID) } if lessonID, ok, err := parseRouteInt64(c, "lessonId"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByLessonID(c.Context(), lessonID) } switch routePath := c.Route().Path; { case strings.Contains(routePath, "/catalog-courses/:id"): if id, ok, err := parseRouteInt64(c, "id"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByCatalogCourseID(c.Context(), id) } case strings.Contains(routePath, "/units/:id"): if id, ok, err := parseRouteInt64(c, "id"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByUnitID(c.Context(), id) } case strings.Contains(routePath, "/modules/:id"): if id, ok, err := parseRouteInt64(c, "id"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByModuleID(c.Context(), id) } case strings.Contains(routePath, "/lessons/:id"): if id, ok, err := parseRouteInt64(c, "id"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByLessonID(c.Context(), id) } case strings.Contains(routePath, "/practices/:id"): if id, ok, err := parseRouteInt64(c, "id"); err != nil { return "", false, err } else if ok { return a.examPrepCategoryByPracticeID(c.Context(), id) } } return "", false, nil } func (a *App) examPrepCategoryByCatalogCourseID(ctx context.Context, catalogCourseID int64) (domain.SubscriptionCategory, bool, error) { catalogCourse, err := a.examPrepSvc.GetCatalogCourseByID(ctx, catalogCourseID) if err != nil { return "", false, err } return domain.SubscriptionCategory(catalogCourse.Category), true, nil } func (a *App) examPrepCategoryByUnitID(ctx context.Context, unitID int64) (domain.SubscriptionCategory, bool, error) { unit, err := a.examPrepSvc.GetUnitByID(ctx, unitID) if err != nil { return "", false, err } return a.examPrepCategoryByCatalogCourseID(ctx, unit.CatalogCourseID) } func (a *App) examPrepCategoryByModuleID(ctx context.Context, moduleID int64) (domain.SubscriptionCategory, bool, error) { module, err := a.examPrepSvc.GetModuleByID(ctx, moduleID) if err != nil { return "", false, err } return a.examPrepCategoryByUnitID(ctx, module.UnitID) } func (a *App) examPrepCategoryByLessonID(ctx context.Context, lessonID int64) (domain.SubscriptionCategory, bool, error) { lesson, err := a.examPrepSvc.GetLessonByID(ctx, lessonID) if err != nil { return "", false, err } return a.examPrepCategoryByModuleID(ctx, lesson.UnitModuleID) } func (a *App) examPrepCategoryByPracticeID(ctx context.Context, practiceID int64) (domain.SubscriptionCategory, bool, error) { practice, err := a.examPrepSvc.GetExamPrepPracticeByID(ctx, practiceID) if err != nil { return "", false, err } return a.examPrepCategoryByLessonID(ctx, practice.LessonID) } func (a *App) RequirePermission(permKey string) fiber.Handler { return func(c *fiber.Ctx) error { userRole, ok := c.Locals("role").(domain.Role) if !ok { return fiber.NewError(fiber.StatusForbidden, "Role not found in context") } if !a.rbacSvc.HasPermission(string(userRole), permKey) { userID, _ := c.Locals("user_id").(int64) a.mongoLoggerSvc.Warn("Permission denied", zap.Int64("userID", userID), zap.String("role", string(userRole)), zap.String("permission", permKey), zap.Int("status_code", fiber.StatusForbidden), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusForbidden, "You don't have permission to access this resource") } return c.Next() } } func (a *App) OnlyBranchManagerAndAbove(c *fiber.Ctx) error { userID := c.Locals("user_id").(int64) userRole := c.Locals("role").(domain.Role) if userRole != domain.RoleSuperAdmin && userRole != domain.RoleAdmin { a.mongoLoggerSvc.Warn("Attempt to access restricted OnlyBranchMangerAndAbove route", zap.Int64("userID", userID), zap.String("role", string(userRole)), zap.Int("status_code", fiber.StatusForbidden), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusForbidden, "This route is restricted") } return c.Next() } func (a *App) WebsocketAuthMiddleware(c *fiber.Ctx) error { tokenStr := c.Query("token") ip := c.IP() userAgent := c.Get("User-Agent") if tokenStr == "" { a.mongoLoggerSvc.Info("Missing token in query parameter", zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusUnauthorized, "Missing token") } claim, err := jwtutil.ParseJwt(tokenStr, a.JwtConfig.JwtAccessKey) if err != nil { if errors.Is(err, jwtutil.ErrExpiredToken) { a.mongoLoggerSvc.Info("Token expired", zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusUnauthorized, "Token expired") } a.logger.Error("Invalid token", "error", err) a.mongoLoggerSvc.Info("Invalid token", zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), zap.Error(err), ) return fiber.NewError(fiber.StatusUnauthorized, "Invalid token") } userID := claim.UserId if userID == 0 { a.mongoLoggerSvc.Info("Invalid user ID in token claims", zap.Int("status_code", fiber.StatusUnauthorized), zap.String("ip_address", ip), zap.String("user_agent", userAgent), zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID") } c.Locals("userID", userID) a.mongoLoggerSvc.Info("Authenticated WebSocket connection", zap.Int64("userID", userID), zap.Time("timestamp", time.Now()), ) return c.Next() } func (a *App) TenantMiddleware(c *fiber.Ctx) error { tenantSlug := c.Params("tenant_slug") if tenantSlug == "" { a.mongoLoggerSvc.Info("blank tenant param", zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusBadRequest, "tenant is required for this route") } // company, err := a.companySvc.GetCompanyBySlug(c.Context(), tenantSlug) // if err != nil { // a.mongoLoggerSvc.Info("failed to resolve tenant", // zap.String("tenant_slug", tenantSlug), // zap.Time("timestamp", time.Now()), // ) // return fiber.NewError(fiber.StatusBadRequest, "failed to resolve tenant") // } // if !company.IsActive { // a.mongoLoggerSvc.Info("request using deactivated tenant", // zap.String("tenant_slug", tenantSlug), // zap.Time("timestamp", time.Now()), // ) // return fiber.NewError(fiber.StatusForbidden, "this tenant has been deactivated") // } // c.Locals("company_id", domain.ValidInt64{ // Value: company.ID, // Valid: true, // }) return c.Next() } func (a *App) TenantAuthMiddleware(c *fiber.Ctx) error { slugID, ok := c.Locals("tenant_id").(domain.ValidInt64) if !ok || !slugID.Valid { a.mongoLoggerSvc.Info("invalid tenant slug", zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusBadRequest, "invalid tenant slug") } tokenCID, ok := c.Locals("company_id").(domain.ValidInt64) if !ok || !tokenCID.Valid { a.mongoLoggerSvc.Error("invalid company id in token", zap.Time("timestamp", time.Now()), zap.Bool("tokenCID Valid", tokenCID.Valid), zap.Bool("ValidInt64 Type Check", ok), ) return fiber.NewError(fiber.StatusInternalServerError, "invalid company id in token") } if slugID.Value != tokenCID.Value { a.mongoLoggerSvc.Error("token company-id doesn't match the slug company_id", zap.Time("timestamp", time.Now()), ) return fiber.NewError(fiber.StatusInternalServerError, "invalid company_id") } fmt.Printf("\nTenant successfully authenticated!\n") return c.Next() }