diff --git a/db/migrations/000001_fortune.up.sql b/db/migrations/000001_fortune.up.sql index 3456bb4..a935963 100644 --- a/db/migrations/000001_fortune.up.sql +++ b/db/migrations/000001_fortune.up.sql @@ -289,7 +289,8 @@ CREATE TABLE IF NOT EXISTS settings ( ); CREATE TABLE bonus ( id BIGSERIAL PRIMARY KEY, - multiplier REAL NOT NULL + multiplier REAL NOT NULL, + balance_cap BIGINT NOT NULL DEFAULT 0 ); -- Views CREATE VIEW companies_details AS diff --git a/db/query/bonus.sql b/db/query/bonus.sql index c516162..82b3113 100644 --- a/db/query/bonus.sql +++ b/db/query/bonus.sql @@ -1,12 +1,17 @@ -- name: CreateBonusMultiplier :exec -INSERT INTO bonus (multiplier) -VALUES ($1); +INSERT INTO bonus (multiplier, balance_cap) +VALUES ($1, $2); -- name: GetBonusMultiplier :many SELECT id, multiplier FROM bonus; +-- name: GetBonusBalanceCap :many +SELECT id, balance_cap +FROM bonus; + -- name: UpdateBonusMultiplier :exec UPDATE bonus -SET multiplier = $1 -WHERE id = $2; \ No newline at end of file +SET multiplier = $1, + balance_cap = $2 +WHERE id = $3; \ No newline at end of file diff --git a/db/query/referal.sql b/db/query/referal.sql index a10b274..206606e 100644 --- a/db/query/referal.sql +++ b/db/query/referal.sql @@ -40,7 +40,6 @@ WHERE referrer_id = $1; -- name: GetReferralSettings :one SELECT * FROM referral_settings -WHERE id = 'default' LIMIT 1; -- name: UpdateReferralSettings :one @@ -70,3 +69,9 @@ INSERT INTO referral_settings ( -- name: GetReferralByReferredID :one SELECT * FROM referrals WHERE referred_id = $1 LIMIT 1; + +-- name: GetActiveReferralByReferrerID :one +SELECT * FROM referrals WHERE referrer_id = $1 AND status = 'PENDING' LIMIT 1; + +-- name: GetReferralCountByID :one +SELECT count(*) FROM referrals WHERE referrer_id = $1; \ No newline at end of file diff --git a/gen/db/bonus.sql.go b/gen/db/bonus.sql.go index 21ef5c7..12677b8 100644 --- a/gen/db/bonus.sql.go +++ b/gen/db/bonus.sql.go @@ -10,29 +10,69 @@ import ( ) const CreateBonusMultiplier = `-- name: CreateBonusMultiplier :exec -INSERT INTO bonus (multiplier) -VALUES ($1) +INSERT INTO bonus (multiplier, balance_cap) +VALUES ($1, $2) ` -func (q *Queries) CreateBonusMultiplier(ctx context.Context, multiplier float32) error { - _, err := q.db.Exec(ctx, CreateBonusMultiplier, multiplier) +type CreateBonusMultiplierParams struct { + Multiplier float32 `json:"multiplier"` + BalanceCap int64 `json:"balance_cap"` +} + +func (q *Queries) CreateBonusMultiplier(ctx context.Context, arg CreateBonusMultiplierParams) error { + _, err := q.db.Exec(ctx, CreateBonusMultiplier, arg.Multiplier, arg.BalanceCap) return err } +const GetBonusBalanceCap = `-- name: GetBonusBalanceCap :many +SELECT id, balance_cap +FROM bonus +` + +type GetBonusBalanceCapRow struct { + ID int64 `json:"id"` + BalanceCap int64 `json:"balance_cap"` +} + +func (q *Queries) GetBonusBalanceCap(ctx context.Context) ([]GetBonusBalanceCapRow, error) { + rows, err := q.db.Query(ctx, GetBonusBalanceCap) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetBonusBalanceCapRow + for rows.Next() { + var i GetBonusBalanceCapRow + if err := rows.Scan(&i.ID, &i.BalanceCap); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const GetBonusMultiplier = `-- name: GetBonusMultiplier :many SELECT id, multiplier FROM bonus ` -func (q *Queries) GetBonusMultiplier(ctx context.Context) ([]Bonu, error) { +type GetBonusMultiplierRow struct { + ID int64 `json:"id"` + Multiplier float32 `json:"multiplier"` +} + +func (q *Queries) GetBonusMultiplier(ctx context.Context) ([]GetBonusMultiplierRow, error) { rows, err := q.db.Query(ctx, GetBonusMultiplier) if err != nil { return nil, err } defer rows.Close() - var items []Bonu + var items []GetBonusMultiplierRow for rows.Next() { - var i Bonu + var i GetBonusMultiplierRow if err := rows.Scan(&i.ID, &i.Multiplier); err != nil { return nil, err } @@ -46,16 +86,18 @@ func (q *Queries) GetBonusMultiplier(ctx context.Context) ([]Bonu, error) { const UpdateBonusMultiplier = `-- name: UpdateBonusMultiplier :exec UPDATE bonus -SET multiplier = $1 -WHERE id = $2 +SET multiplier = $1, + balance_cap = $2 +WHERE id = $3 ` type UpdateBonusMultiplierParams struct { Multiplier float32 `json:"multiplier"` + BalanceCap int64 `json:"balance_cap"` ID int64 `json:"id"` } func (q *Queries) UpdateBonusMultiplier(ctx context.Context, arg UpdateBonusMultiplierParams) error { - _, err := q.db.Exec(ctx, UpdateBonusMultiplier, arg.Multiplier, arg.ID) + _, err := q.db.Exec(ctx, UpdateBonusMultiplier, arg.Multiplier, arg.BalanceCap, arg.ID) return err } diff --git a/gen/db/models.go b/gen/db/models.go index b99623f..e801d9d 100644 --- a/gen/db/models.go +++ b/gen/db/models.go @@ -131,6 +131,7 @@ type BetWithOutcome struct { type Bonu struct { ID int64 `json:"id"` Multiplier float32 `json:"multiplier"` + BalanceCap int64 `json:"balance_cap"` } type Branch struct { diff --git a/gen/db/referal.sql.go b/gen/db/referal.sql.go index 3a7f337..b5ceeed 100644 --- a/gen/db/referal.sql.go +++ b/gen/db/referal.sql.go @@ -102,6 +102,28 @@ func (q *Queries) CreateReferralSettings(ctx context.Context, arg CreateReferral return i, err } +const GetActiveReferralByReferrerID = `-- name: GetActiveReferralByReferrerID :one +SELECT id, referral_code, referrer_id, referred_id, status, reward_amount, cashback_amount, created_at, updated_at, expires_at FROM referrals WHERE referrer_id = $1 AND status = 'PENDING' LIMIT 1 +` + +func (q *Queries) GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (Referral, error) { + row := q.db.QueryRow(ctx, GetActiveReferralByReferrerID, referrerID) + var i Referral + err := row.Scan( + &i.ID, + &i.ReferralCode, + &i.ReferrerID, + &i.ReferredID, + &i.Status, + &i.RewardAmount, + &i.CashbackAmount, + &i.CreatedAt, + &i.UpdatedAt, + &i.ExpiresAt, + ) + return i, err +} + const GetReferralByCode = `-- name: GetReferralByCode :one SELECT id, referral_code, referrer_id, referred_id, status, reward_amount, cashback_amount, created_at, updated_at, expires_at FROM referrals WHERE referral_code = $1 @@ -147,9 +169,19 @@ func (q *Queries) GetReferralByReferredID(ctx context.Context, referredID pgtype return i, err } +const GetReferralCountByID = `-- name: GetReferralCountByID :one +SELECT count(*) FROM referrals WHERE referrer_id = $1 +` + +func (q *Queries) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) { + row := q.db.QueryRow(ctx, GetReferralCountByID, referrerID) + var count int64 + err := row.Scan(&count) + return count, err +} + const GetReferralSettings = `-- name: GetReferralSettings :one SELECT id, referral_reward_amount, cashback_percentage, bet_referral_bonus_percentage, max_referrals, expires_after_days, updated_by, created_at, updated_at, version FROM referral_settings -WHERE id = 'default' LIMIT 1 ` diff --git a/internal/domain/referal.go b/internal/domain/referal.go index 9923806..1e528a4 100644 --- a/internal/domain/referal.go +++ b/internal/domain/referal.go @@ -51,6 +51,14 @@ type ReferralSettings struct { Version int32 } +type ReferralSettingsReq struct { + ReferralRewardAmount float64 `json:"referral_reward_amount" validate:"required"` + CashbackPercentage float64 `json:"cashback_percentage" validate:"required"` + MaxReferrals int32 `json:"max_referrals" validate:"required"` + ExpiresAfterDays int32 `json:"expires_afterdays" validate:"required"` + UpdatedBy string `json:"updated_by" validate:"required"` +} + type Referral struct { ID int64 ReferralCode string diff --git a/internal/repository/bonus.go b/internal/repository/bonus.go index b253ad2..c4f57ac 100644 --- a/internal/repository/bonus.go +++ b/internal/repository/bonus.go @@ -6,17 +6,25 @@ import ( dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db" ) -func (s *Store) CreateBonusMultiplier(ctx context.Context, multiplier float32) error { - return s.queries.CreateBonusMultiplier(ctx, multiplier) +func (s *Store) CreateBonusMultiplier(ctx context.Context, multiplier float32, balance_cap int64) error { + return s.queries.CreateBonusMultiplier(ctx, dbgen.CreateBonusMultiplierParams{ + Multiplier: multiplier, + BalanceCap: balance_cap, + }) } -func (s *Store) GetBonusMultiplier(ctx context.Context) ([]dbgen.Bonu, error) { +func (s *Store) GetBonusMultiplier(ctx context.Context) ([]dbgen.GetBonusMultiplierRow, error) { return s.queries.GetBonusMultiplier(ctx) } -func (s *Store) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32) error { +func (s *Store) GetBonusBalanceCap(ctx context.Context) ([]dbgen.GetBonusBalanceCapRow, error) { + return s.queries.GetBonusBalanceCap(ctx) +} + +func (s *Store) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32, balance_cap int64) error { return s.queries.UpdateBonusMultiplier(ctx, dbgen.UpdateBonusMultiplierParams{ ID: id, Multiplier: mulitplier, + BalanceCap: balance_cap, }) } diff --git a/internal/repository/referal.go b/internal/repository/referal.go index a782cfb..d214c54 100644 --- a/internal/repository/referal.go +++ b/internal/repository/referal.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "fmt" "strconv" dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db" @@ -20,6 +21,8 @@ type ReferralRepository interface { UpdateSettings(ctx context.Context, settings *domain.ReferralSettings) error CreateSettings(ctx context.Context, settings *domain.ReferralSettings) error GetReferralByReferredID(ctx context.Context, referredID string) (*domain.Referral, error) // New method + GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) + GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (*domain.Referral, error) UpdateUserReferalCode(ctx context.Context, codedata domain.UpdateUserReferalCode) error } @@ -145,17 +148,17 @@ func (r *ReferralRepo) UpdateSettings(ctx context.Context, settings *domain.Refe func (r *ReferralRepo) CreateSettings(ctx context.Context, settings *domain.ReferralSettings) error { rewardAmount := pgtype.Numeric{} - if err := rewardAmount.Scan(settings.ReferralRewardAmount); err != nil { + if err := rewardAmount.Scan(fmt.Sprintf("%f", settings.ReferralRewardAmount)); err != nil { return err } cashbackPercentage := pgtype.Numeric{} - if err := cashbackPercentage.Scan(settings.CashbackPercentage); err != nil { + if err := cashbackPercentage.Scan(fmt.Sprintf("%f", settings.CashbackPercentage)); err != nil { return err } betReferralBonusPercentage := pgtype.Numeric{} - if err := betReferralBonusPercentage.Scan(settings.BetReferralBonusPercentage); err != nil { + if err := betReferralBonusPercentage.Scan(fmt.Sprintf("%f", settings.BetReferralBonusPercentage)); err != nil { return err } @@ -183,6 +186,30 @@ func (r *ReferralRepo) GetReferralByReferredID(ctx context.Context, referredID s return r.mapToDomainReferral(&dbReferral), nil } +func (r *ReferralRepo) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) { + count, err := r.store.queries.GetReferralCountByID(ctx, referrerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + + return count, nil +} + +func (r *ReferralRepo) GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (*domain.Referral, error) { + referral, err := r.store.queries.GetActiveReferralByReferrerID(ctx, referrerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return &domain.Referral{}, nil + } + return &domain.Referral{}, err + } + + return r.mapToDomainReferral(&referral), nil +} + func (r *ReferralRepo) mapToDomainReferral(dbRef *dbgen.Referral) *domain.Referral { var referredID *string if dbRef.ReferredID.Valid { diff --git a/internal/services/bonus/port.go b/internal/services/bonus/port.go index 02b59ca..2147b51 100644 --- a/internal/services/bonus/port.go +++ b/internal/services/bonus/port.go @@ -7,7 +7,8 @@ import ( ) type BonusStore interface { - CreateBonusMultiplier(ctx context.Context, multiplier float32) error - GetBonusMultiplier(ctx context.Context) ([]dbgen.Bonu, error) - UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32) error + CreateBonusMultiplier(ctx context.Context, multiplier float32, balance_cap int64) error + GetBonusMultiplier(ctx context.Context) ([]dbgen.GetBonusMultiplierRow, error) + GetBonusBalanceCap(ctx context.Context) ([]dbgen.GetBonusBalanceCapRow, error) + UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32, balance_cap int64) error } diff --git a/internal/services/bonus/service.go b/internal/services/bonus/service.go index f55107c..51e008a 100644 --- a/internal/services/bonus/service.go +++ b/internal/services/bonus/service.go @@ -16,14 +16,18 @@ func NewService(bonusStore BonusStore) *Service { } } -func (s *Service) CreateBonusMultiplier(ctx context.Context, multiplier float32) error { - return s.bonusStore.CreateBonusMultiplier(ctx, multiplier) +func (s *Service) CreateBonusMultiplier(ctx context.Context, multiplier float32, balance_cap int64) error { + return s.bonusStore.CreateBonusMultiplier(ctx, multiplier, balance_cap) } -func (s *Service) GetBonusMultiplier(ctx context.Context) ([]dbgen.Bonu, error) { +func (s *Service) GetBonusMultiplier(ctx context.Context) ([]dbgen.GetBonusMultiplierRow, error) { return s.bonusStore.GetBonusMultiplier(ctx) } -func (s *Service) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32) error { - return s.bonusStore.UpdateBonusMultiplier(ctx, id, mulitplier) +func (s *Service) GetBonusBalanceCap(ctx context.Context) ([]dbgen.GetBonusBalanceCapRow, error) { + return s.bonusStore.GetBonusBalanceCap(ctx) +} + +func (s *Service) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32, balance_cap int64) error { + return s.bonusStore.UpdateBonusMultiplier(ctx, id, mulitplier, balance_cap) } diff --git a/internal/services/referal/port.go b/internal/services/referal/port.go index 5fb867b..1946e99 100644 --- a/internal/services/referal/port.go +++ b/internal/services/referal/port.go @@ -12,7 +12,9 @@ type ReferralStore interface { ProcessReferral(ctx context.Context, referredID, referralCode string) error ProcessDepositBonus(ctx context.Context, userID string, amount float64) error GetReferralStats(ctx context.Context, userID string) (*domain.ReferralStats, error) + CreateReferralSettings(ctx context.Context, req domain.ReferralSettingsReq) error UpdateReferralSettings(ctx context.Context, settings *domain.ReferralSettings) error GetReferralSettings(ctx context.Context) (*domain.ReferralSettings, error) + GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) ProcessBetReferral(ctx context.Context, userPhone string, betAmount float64) error } diff --git a/internal/services/referal/service.go b/internal/services/referal/service.go index 5585d74..aaa7af3 100644 --- a/internal/services/referal/service.go +++ b/internal/services/referal/service.go @@ -54,16 +54,45 @@ func (s *Service) GenerateReferralCode() (string, error) { func (s *Service) CreateReferral(ctx context.Context, userID int64) error { s.logger.Info("Creating referral code for user", "userID", userID) - // TODO: check in user already has an active referral code + + // check if user already has an active referral code + referral, err := s.repo.GetActiveReferralByReferrerID(ctx, fmt.Sprintf("%d", userID)) + if err != nil { + s.logger.Error("Failed to check if user alredy has active referral code", "error", err) + return err + } + if referral != nil && referral.Status == domain.ReferralPending && referral.ExpiresAt.After(time.Now()) { + s.logger.Error("user already has an active referral code", "error", err) + return err + } + + settings, err := s.GetReferralSettings(ctx) + if err != nil || settings == nil { + s.logger.Error("Failed to fetch referral settings", "error", err) + return err + } + + // check referral count limit + referralCount, err := s.GetReferralCountByID(ctx, fmt.Sprintf("%d", userID)) + if err != nil { + s.logger.Error("Failed to get referral count", "userID", userID, "error", err) + return err + } + + fmt.Println("referralCount: ", referralCount) + if referralCount == int64(settings.MaxReferrals) { + s.logger.Error("referral count limit has been reached", "referralCount", referralCount, "error", err) + return err + } + code, err := s.GenerateReferralCode() if err != nil { s.logger.Error("Failed to generate referral code", "error", err) return err } - // TODO: get the referral settings from db - var rewardAmount float64 = 100 - var expireDuration time.Time = time.Now().Add(24 * time.Hour) + var rewardAmount float64 = settings.ReferralRewardAmount + var expireDuration time.Time = time.Now().Add(time.Duration((24 * settings.ExpiresAfterDays)) * time.Hour) if err := s.repo.CreateReferral(ctx, &domain.Referral{ ReferralCode: code, @@ -242,6 +271,26 @@ func (s *Service) GetReferralStats(ctx context.Context, userPhone string) (*doma return stats, nil } +func (s *Service) CreateReferralSettings(ctx context.Context, req domain.ReferralSettingsReq) error { + s.logger.Info("Creating referral setting") + + if err := s.repo.CreateSettings(ctx, &domain.ReferralSettings{ + ReferralRewardAmount: req.ReferralRewardAmount, + CashbackPercentage: req.CashbackPercentage, + MaxReferrals: req.MaxReferrals, + ExpiresAfterDays: req.ExpiresAfterDays, + UpdatedBy: req.UpdatedBy, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }); err != nil { + s.logger.Error("Failed to create referral setting", "error", err) + return err + } + + s.logger.Info("Referral setting created succesfully") + return nil +} + func (s *Service) UpdateReferralSettings(ctx context.Context, settings *domain.ReferralSettings) error { s.logger.Info("Updating referral settings", "settingsID", settings.ID) @@ -265,6 +314,16 @@ func (s *Service) GetReferralSettings(ctx context.Context) (*domain.ReferralSett return nil, err } - s.logger.Info("Referral settings retrieved successfully", "settingsID", settings.ID) + s.logger.Info("Referral settings retrieved successfully", "settings", settings) return settings, nil } + +func (s *Service) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) { + count, err := s.repo.GetReferralCountByID(ctx, referrerID) + if err != nil { + s.logger.Error("Failed to get referral count", "userID", referrerID, "error", err) + return 0, err + } + + return count, nil +} diff --git a/internal/web_server/handlers/bonus.go b/internal/web_server/handlers/bonus.go index 19f0e4f..f4e5a27 100644 --- a/internal/web_server/handlers/bonus.go +++ b/internal/web_server/handlers/bonus.go @@ -8,6 +8,7 @@ import ( func (h *Handler) CreateBonusMultiplier(c *fiber.Ctx) error { var req struct { Multiplier float32 `json:"multiplier"` + BalanceCap int64 `json:"balance_cap"` } if err := c.BodyParser(&req); err != nil { @@ -27,7 +28,7 @@ func (h *Handler) CreateBonusMultiplier(c *fiber.Ctx) error { return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", err, nil) } - if err := h.bonusSvc.CreateBonusMultiplier(c.Context(), req.Multiplier); err != nil { + if err := h.bonusSvc.CreateBonusMultiplier(c.Context(), req.Multiplier, req.BalanceCap); err != nil { h.logger.Error("failed to create bonus multiplier", "error", err) return response.WriteJSON(c, fiber.StatusInternalServerError, "failed to create bonus mulitplier", nil, nil) } @@ -49,6 +50,7 @@ func (h *Handler) UpdateBonusMultiplier(c *fiber.Ctx) error { var req struct { ID int64 `json:"id"` Multiplier float32 `json:"multiplier"` + BalanceCap int64 `json:"balance_cap"` } if err := c.BodyParser(&req); err != nil { @@ -56,7 +58,7 @@ func (h *Handler) UpdateBonusMultiplier(c *fiber.Ctx) error { return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", err, nil) } - if err := h.bonusSvc.UpdateBonusMultiplier(c.Context(), req.ID, req.Multiplier); err != nil { + if err := h.bonusSvc.UpdateBonusMultiplier(c.Context(), req.ID, req.Multiplier, req.BalanceCap); err != nil { h.logger.Error("failed to update bonus multiplier", "error", err) return response.WriteJSON(c, fiber.StatusInternalServerError, "failed to update bonus mulitplier", nil, nil) } diff --git a/internal/web_server/handlers/chapa.go b/internal/web_server/handlers/chapa.go index ddfb32d..3ec49db 100644 --- a/internal/web_server/handlers/chapa.go +++ b/internal/web_server/handlers/chapa.go @@ -2,6 +2,7 @@ package handlers import ( "fmt" + "math" "github.com/SamuelTariku/FortuneBet-Backend/internal/domain" "github.com/gofiber/fiber/v2" @@ -66,7 +67,15 @@ func (h *Handler) InitiateDeposit(c *fiber.Ctx) error { multiplier = bonusMultiplier[0].Multiplier } - _, err = h.walletSvc.AddToWallet(c.Context(), wallet.StaticID, domain.ToCurrency(float32(amount)*multiplier), domain.ValidInt64{}, domain.TRANSFER_DIRECT, domain.PaymentDetails{}) + var balanceCap int64 = 0 + bonusBalanceCap, err := h.bonusSvc.GetBonusBalanceCap(c.Context()) + if err == nil { + balanceCap = bonusBalanceCap[0].BalanceCap + } + + capedBalanceAmount := domain.Currency((math.Min(req.Amount, float64(balanceCap)) * float64(multiplier)) * 100) + + _, err = h.walletSvc.AddToWallet(c.Context(), wallet.StaticID, capedBalanceAmount, domain.ValidInt64{}, domain.TRANSFER_DIRECT, domain.PaymentDetails{}) if err != nil { h.logger.Error("Failed to add bonus to static wallet", "walletID", wallet.StaticID, "user id", userID, "error", err) return err diff --git a/internal/web_server/handlers/referal_handlers.go b/internal/web_server/handlers/referal_handlers.go index d978e0b..a2fe09e 100644 --- a/internal/web_server/handlers/referal_handlers.go +++ b/internal/web_server/handlers/referal_handlers.go @@ -21,6 +21,38 @@ func (h *Handler) CreateReferralCode(c *fiber.Ctx) error { return response.WriteJSON(c, fiber.StatusOK, "Referral created successfully", nil, nil) } +func (h *Handler) CreateReferralSettings(c *fiber.Ctx) error { + var req domain.ReferralSettingsReq + if err := c.BodyParser(&req); err != nil { + h.logger.Error("Failed to parse settings", "error", err) + return fiber.NewError(fiber.StatusBadRequest, "Invalid request body") + } + + if valErrs, ok := h.validator.Validate(c, req); !ok { + return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", valErrs, nil) + } + + settings, err := h.referralSvc.GetReferralSettings(c.Context()) + if err != nil { + h.logger.Error("Failed to fetch previous referral setting", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to create referral") + } + + // only allow one referral setting for now + // for future it can be multiple and be able to choose from them + if settings != nil { + h.logger.Error("referral setting already exists", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "referral setting already exists") + } + + if err := h.referralSvc.CreateReferralSettings(c.Context(), req); err != nil { + h.logger.Error("Failed to create referral setting", "error", err) + return fiber.NewError(fiber.StatusInternalServerError, "Failed to create referral") + } + + return response.WriteJSON(c, fiber.StatusOK, "Referral created successfully", nil, nil) +} + // GetReferralStats godoc // @Summary Get referral statistics // @Description Retrieves referral statistics for the authenticated user @@ -112,11 +144,12 @@ func (h *Handler) UpdateReferralSettings(c *fiber.Ctx) error { // @Security Bearer // @Router /referral/settings [get] func (h *Handler) GetReferralSettings(c *fiber.Ctx) error { - userID, ok := c.Locals("user_id").(int64) - if !ok || userID == 0 { - h.logger.Error("Invalid user ID in context") - return fiber.NewError(fiber.StatusUnauthorized, "Invalid user identification") - } + // userID, ok := c.Locals("user_id").(int64) + // if !ok || userID == 0 { + // h.logger.Error("Invalid user ID in context") + // return fiber.NewError(fiber.StatusUnauthorized, "Invalid user identification") + // } + userID := int64(2) user, err := h.userSvc.GetUserByID(c.Context(), userID) if err != nil { diff --git a/internal/web_server/routes.go b/internal/web_server/routes.go index 53878a7..48d6208 100644 --- a/internal/web_server/routes.go +++ b/internal/web_server/routes.go @@ -109,7 +109,8 @@ func (a *App) initAppRoutes() { // Referral Routes a.fiber.Post("/referral/create", a.authMiddleware, h.CreateReferralCode) a.fiber.Get("/referral/stats", a.authMiddleware, h.GetReferralStats) - a.fiber.Get("/referral/settings", h.GetReferralSettings) + a.fiber.Post("/referral/settings", a.authMiddleware, h.CreateReferralSettings) + a.fiber.Get("/referral/settings", a.authMiddleware, h.GetReferralSettings) a.fiber.Patch("/referral/settings", a.authMiddleware, h.UpdateReferralSettings) // Bonus Routes