Merge branch 'referral-promotion'

This commit is contained in:
Asher Samuel 2025-06-30 18:16:49 +03:00
commit 0e42facebf
17 changed files with 286 additions and 46 deletions

View File

@ -289,7 +289,8 @@ CREATE TABLE IF NOT EXISTS settings (
); );
CREATE TABLE bonus ( CREATE TABLE bonus (
id BIGSERIAL PRIMARY KEY, id BIGSERIAL PRIMARY KEY,
multiplier REAL NOT NULL multiplier REAL NOT NULL,
balance_cap BIGINT NOT NULL DEFAULT 0
); );
-- Views -- Views
CREATE VIEW companies_details AS CREATE VIEW companies_details AS

View File

@ -1,12 +1,17 @@
-- name: CreateBonusMultiplier :exec -- name: CreateBonusMultiplier :exec
INSERT INTO bonus (multiplier) INSERT INTO bonus (multiplier, balance_cap)
VALUES ($1); VALUES ($1, $2);
-- name: GetBonusMultiplier :many -- name: GetBonusMultiplier :many
SELECT id, multiplier SELECT id, multiplier
FROM bonus; FROM bonus;
-- name: GetBonusBalanceCap :many
SELECT id, balance_cap
FROM bonus;
-- name: UpdateBonusMultiplier :exec -- name: UpdateBonusMultiplier :exec
UPDATE bonus UPDATE bonus
SET multiplier = $1 SET multiplier = $1,
WHERE id = $2; balance_cap = $2
WHERE id = $3;

View File

@ -40,7 +40,6 @@ WHERE referrer_id = $1;
-- name: GetReferralSettings :one -- name: GetReferralSettings :one
SELECT * FROM referral_settings SELECT * FROM referral_settings
WHERE id = 'default'
LIMIT 1; LIMIT 1;
-- name: UpdateReferralSettings :one -- name: UpdateReferralSettings :one
@ -70,3 +69,9 @@ INSERT INTO referral_settings (
-- name: GetReferralByReferredID :one -- name: GetReferralByReferredID :one
SELECT * FROM referrals WHERE referred_id = $1 LIMIT 1; SELECT * FROM referrals WHERE referred_id = $1 LIMIT 1;
-- name: GetActiveReferralByReferrerID :one
SELECT * FROM referrals WHERE referrer_id = $1 AND status = 'PENDING' LIMIT 1;
-- name: GetReferralCountByID :one
SELECT count(*) FROM referrals WHERE referrer_id = $1;

View File

@ -10,29 +10,69 @@ import (
) )
const CreateBonusMultiplier = `-- name: CreateBonusMultiplier :exec const CreateBonusMultiplier = `-- name: CreateBonusMultiplier :exec
INSERT INTO bonus (multiplier) INSERT INTO bonus (multiplier, balance_cap)
VALUES ($1) VALUES ($1, $2)
` `
func (q *Queries) CreateBonusMultiplier(ctx context.Context, multiplier float32) error { type CreateBonusMultiplierParams struct {
_, err := q.db.Exec(ctx, CreateBonusMultiplier, multiplier) Multiplier float32 `json:"multiplier"`
BalanceCap int64 `json:"balance_cap"`
}
func (q *Queries) CreateBonusMultiplier(ctx context.Context, arg CreateBonusMultiplierParams) error {
_, err := q.db.Exec(ctx, CreateBonusMultiplier, arg.Multiplier, arg.BalanceCap)
return err return err
} }
const GetBonusBalanceCap = `-- name: GetBonusBalanceCap :many
SELECT id, balance_cap
FROM bonus
`
type GetBonusBalanceCapRow struct {
ID int64 `json:"id"`
BalanceCap int64 `json:"balance_cap"`
}
func (q *Queries) GetBonusBalanceCap(ctx context.Context) ([]GetBonusBalanceCapRow, error) {
rows, err := q.db.Query(ctx, GetBonusBalanceCap)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetBonusBalanceCapRow
for rows.Next() {
var i GetBonusBalanceCapRow
if err := rows.Scan(&i.ID, &i.BalanceCap); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const GetBonusMultiplier = `-- name: GetBonusMultiplier :many const GetBonusMultiplier = `-- name: GetBonusMultiplier :many
SELECT id, multiplier SELECT id, multiplier
FROM bonus FROM bonus
` `
func (q *Queries) GetBonusMultiplier(ctx context.Context) ([]Bonu, error) { type GetBonusMultiplierRow struct {
ID int64 `json:"id"`
Multiplier float32 `json:"multiplier"`
}
func (q *Queries) GetBonusMultiplier(ctx context.Context) ([]GetBonusMultiplierRow, error) {
rows, err := q.db.Query(ctx, GetBonusMultiplier) rows, err := q.db.Query(ctx, GetBonusMultiplier)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var items []Bonu var items []GetBonusMultiplierRow
for rows.Next() { for rows.Next() {
var i Bonu var i GetBonusMultiplierRow
if err := rows.Scan(&i.ID, &i.Multiplier); err != nil { if err := rows.Scan(&i.ID, &i.Multiplier); err != nil {
return nil, err return nil, err
} }
@ -46,16 +86,18 @@ func (q *Queries) GetBonusMultiplier(ctx context.Context) ([]Bonu, error) {
const UpdateBonusMultiplier = `-- name: UpdateBonusMultiplier :exec const UpdateBonusMultiplier = `-- name: UpdateBonusMultiplier :exec
UPDATE bonus UPDATE bonus
SET multiplier = $1 SET multiplier = $1,
WHERE id = $2 balance_cap = $2
WHERE id = $3
` `
type UpdateBonusMultiplierParams struct { type UpdateBonusMultiplierParams struct {
Multiplier float32 `json:"multiplier"` Multiplier float32 `json:"multiplier"`
BalanceCap int64 `json:"balance_cap"`
ID int64 `json:"id"` ID int64 `json:"id"`
} }
func (q *Queries) UpdateBonusMultiplier(ctx context.Context, arg UpdateBonusMultiplierParams) error { func (q *Queries) UpdateBonusMultiplier(ctx context.Context, arg UpdateBonusMultiplierParams) error {
_, err := q.db.Exec(ctx, UpdateBonusMultiplier, arg.Multiplier, arg.ID) _, err := q.db.Exec(ctx, UpdateBonusMultiplier, arg.Multiplier, arg.BalanceCap, arg.ID)
return err return err
} }

View File

@ -131,6 +131,7 @@ type BetWithOutcome struct {
type Bonu struct { type Bonu struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Multiplier float32 `json:"multiplier"` Multiplier float32 `json:"multiplier"`
BalanceCap int64 `json:"balance_cap"`
} }
type Branch struct { type Branch struct {

View File

@ -102,6 +102,28 @@ func (q *Queries) CreateReferralSettings(ctx context.Context, arg CreateReferral
return i, err return i, err
} }
const GetActiveReferralByReferrerID = `-- name: GetActiveReferralByReferrerID :one
SELECT id, referral_code, referrer_id, referred_id, status, reward_amount, cashback_amount, created_at, updated_at, expires_at FROM referrals WHERE referrer_id = $1 AND status = 'PENDING' LIMIT 1
`
func (q *Queries) GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (Referral, error) {
row := q.db.QueryRow(ctx, GetActiveReferralByReferrerID, referrerID)
var i Referral
err := row.Scan(
&i.ID,
&i.ReferralCode,
&i.ReferrerID,
&i.ReferredID,
&i.Status,
&i.RewardAmount,
&i.CashbackAmount,
&i.CreatedAt,
&i.UpdatedAt,
&i.ExpiresAt,
)
return i, err
}
const GetReferralByCode = `-- name: GetReferralByCode :one const GetReferralByCode = `-- name: GetReferralByCode :one
SELECT id, referral_code, referrer_id, referred_id, status, reward_amount, cashback_amount, created_at, updated_at, expires_at FROM referrals SELECT id, referral_code, referrer_id, referred_id, status, reward_amount, cashback_amount, created_at, updated_at, expires_at FROM referrals
WHERE referral_code = $1 WHERE referral_code = $1
@ -147,9 +169,19 @@ func (q *Queries) GetReferralByReferredID(ctx context.Context, referredID pgtype
return i, err return i, err
} }
const GetReferralCountByID = `-- name: GetReferralCountByID :one
SELECT count(*) FROM referrals WHERE referrer_id = $1
`
func (q *Queries) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) {
row := q.db.QueryRow(ctx, GetReferralCountByID, referrerID)
var count int64
err := row.Scan(&count)
return count, err
}
const GetReferralSettings = `-- name: GetReferralSettings :one const GetReferralSettings = `-- name: GetReferralSettings :one
SELECT id, referral_reward_amount, cashback_percentage, bet_referral_bonus_percentage, max_referrals, expires_after_days, updated_by, created_at, updated_at, version FROM referral_settings SELECT id, referral_reward_amount, cashback_percentage, bet_referral_bonus_percentage, max_referrals, expires_after_days, updated_by, created_at, updated_at, version FROM referral_settings
WHERE id = 'default'
LIMIT 1 LIMIT 1
` `

View File

@ -51,6 +51,14 @@ type ReferralSettings struct {
Version int32 Version int32
} }
type ReferralSettingsReq struct {
ReferralRewardAmount float64 `json:"referral_reward_amount" validate:"required"`
CashbackPercentage float64 `json:"cashback_percentage" validate:"required"`
MaxReferrals int32 `json:"max_referrals" validate:"required"`
ExpiresAfterDays int32 `json:"expires_afterdays" validate:"required"`
UpdatedBy string `json:"updated_by" validate:"required"`
}
type Referral struct { type Referral struct {
ID int64 ID int64
ReferralCode string ReferralCode string

View File

@ -6,17 +6,25 @@ import (
dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db" dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db"
) )
func (s *Store) CreateBonusMultiplier(ctx context.Context, multiplier float32) error { func (s *Store) CreateBonusMultiplier(ctx context.Context, multiplier float32, balance_cap int64) error {
return s.queries.CreateBonusMultiplier(ctx, multiplier) return s.queries.CreateBonusMultiplier(ctx, dbgen.CreateBonusMultiplierParams{
Multiplier: multiplier,
BalanceCap: balance_cap,
})
} }
func (s *Store) GetBonusMultiplier(ctx context.Context) ([]dbgen.Bonu, error) { func (s *Store) GetBonusMultiplier(ctx context.Context) ([]dbgen.GetBonusMultiplierRow, error) {
return s.queries.GetBonusMultiplier(ctx) return s.queries.GetBonusMultiplier(ctx)
} }
func (s *Store) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32) error { func (s *Store) GetBonusBalanceCap(ctx context.Context) ([]dbgen.GetBonusBalanceCapRow, error) {
return s.queries.GetBonusBalanceCap(ctx)
}
func (s *Store) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32, balance_cap int64) error {
return s.queries.UpdateBonusMultiplier(ctx, dbgen.UpdateBonusMultiplierParams{ return s.queries.UpdateBonusMultiplier(ctx, dbgen.UpdateBonusMultiplierParams{
ID: id, ID: id,
Multiplier: mulitplier, Multiplier: mulitplier,
BalanceCap: balance_cap,
}) })
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"strconv" "strconv"
dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db" dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db"
@ -20,6 +21,8 @@ type ReferralRepository interface {
UpdateSettings(ctx context.Context, settings *domain.ReferralSettings) error UpdateSettings(ctx context.Context, settings *domain.ReferralSettings) error
CreateSettings(ctx context.Context, settings *domain.ReferralSettings) error CreateSettings(ctx context.Context, settings *domain.ReferralSettings) error
GetReferralByReferredID(ctx context.Context, referredID string) (*domain.Referral, error) // New method GetReferralByReferredID(ctx context.Context, referredID string) (*domain.Referral, error) // New method
GetReferralCountByID(ctx context.Context, referrerID string) (int64, error)
GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (*domain.Referral, error)
UpdateUserReferalCode(ctx context.Context, codedata domain.UpdateUserReferalCode) error UpdateUserReferalCode(ctx context.Context, codedata domain.UpdateUserReferalCode) error
} }
@ -145,17 +148,17 @@ func (r *ReferralRepo) UpdateSettings(ctx context.Context, settings *domain.Refe
func (r *ReferralRepo) CreateSettings(ctx context.Context, settings *domain.ReferralSettings) error { func (r *ReferralRepo) CreateSettings(ctx context.Context, settings *domain.ReferralSettings) error {
rewardAmount := pgtype.Numeric{} rewardAmount := pgtype.Numeric{}
if err := rewardAmount.Scan(settings.ReferralRewardAmount); err != nil { if err := rewardAmount.Scan(fmt.Sprintf("%f", settings.ReferralRewardAmount)); err != nil {
return err return err
} }
cashbackPercentage := pgtype.Numeric{} cashbackPercentage := pgtype.Numeric{}
if err := cashbackPercentage.Scan(settings.CashbackPercentage); err != nil { if err := cashbackPercentage.Scan(fmt.Sprintf("%f", settings.CashbackPercentage)); err != nil {
return err return err
} }
betReferralBonusPercentage := pgtype.Numeric{} betReferralBonusPercentage := pgtype.Numeric{}
if err := betReferralBonusPercentage.Scan(settings.BetReferralBonusPercentage); err != nil { if err := betReferralBonusPercentage.Scan(fmt.Sprintf("%f", settings.BetReferralBonusPercentage)); err != nil {
return err return err
} }
@ -183,6 +186,30 @@ func (r *ReferralRepo) GetReferralByReferredID(ctx context.Context, referredID s
return r.mapToDomainReferral(&dbReferral), nil return r.mapToDomainReferral(&dbReferral), nil
} }
func (r *ReferralRepo) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) {
count, err := r.store.queries.GetReferralCountByID(ctx, referrerID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return 0, err
}
return count, nil
}
func (r *ReferralRepo) GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (*domain.Referral, error) {
referral, err := r.store.queries.GetActiveReferralByReferrerID(ctx, referrerID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return &domain.Referral{}, nil
}
return &domain.Referral{}, err
}
return r.mapToDomainReferral(&referral), nil
}
func (r *ReferralRepo) mapToDomainReferral(dbRef *dbgen.Referral) *domain.Referral { func (r *ReferralRepo) mapToDomainReferral(dbRef *dbgen.Referral) *domain.Referral {
var referredID *string var referredID *string
if dbRef.ReferredID.Valid { if dbRef.ReferredID.Valid {

View File

@ -7,7 +7,8 @@ import (
) )
type BonusStore interface { type BonusStore interface {
CreateBonusMultiplier(ctx context.Context, multiplier float32) error CreateBonusMultiplier(ctx context.Context, multiplier float32, balance_cap int64) error
GetBonusMultiplier(ctx context.Context) ([]dbgen.Bonu, error) GetBonusMultiplier(ctx context.Context) ([]dbgen.GetBonusMultiplierRow, error)
UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32) error GetBonusBalanceCap(ctx context.Context) ([]dbgen.GetBonusBalanceCapRow, error)
UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32, balance_cap int64) error
} }

View File

@ -16,14 +16,18 @@ func NewService(bonusStore BonusStore) *Service {
} }
} }
func (s *Service) CreateBonusMultiplier(ctx context.Context, multiplier float32) error { func (s *Service) CreateBonusMultiplier(ctx context.Context, multiplier float32, balance_cap int64) error {
return s.bonusStore.CreateBonusMultiplier(ctx, multiplier) return s.bonusStore.CreateBonusMultiplier(ctx, multiplier, balance_cap)
} }
func (s *Service) GetBonusMultiplier(ctx context.Context) ([]dbgen.Bonu, error) { func (s *Service) GetBonusMultiplier(ctx context.Context) ([]dbgen.GetBonusMultiplierRow, error) {
return s.bonusStore.GetBonusMultiplier(ctx) return s.bonusStore.GetBonusMultiplier(ctx)
} }
func (s *Service) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32) error { func (s *Service) GetBonusBalanceCap(ctx context.Context) ([]dbgen.GetBonusBalanceCapRow, error) {
return s.bonusStore.UpdateBonusMultiplier(ctx, id, mulitplier) return s.bonusStore.GetBonusBalanceCap(ctx)
}
func (s *Service) UpdateBonusMultiplier(ctx context.Context, id int64, mulitplier float32, balance_cap int64) error {
return s.bonusStore.UpdateBonusMultiplier(ctx, id, mulitplier, balance_cap)
} }

View File

@ -12,7 +12,9 @@ type ReferralStore interface {
ProcessReferral(ctx context.Context, referredID, referralCode string) error ProcessReferral(ctx context.Context, referredID, referralCode string) error
ProcessDepositBonus(ctx context.Context, userID string, amount float64) error ProcessDepositBonus(ctx context.Context, userID string, amount float64) error
GetReferralStats(ctx context.Context, userID string) (*domain.ReferralStats, error) GetReferralStats(ctx context.Context, userID string) (*domain.ReferralStats, error)
CreateReferralSettings(ctx context.Context, req domain.ReferralSettingsReq) error
UpdateReferralSettings(ctx context.Context, settings *domain.ReferralSettings) error UpdateReferralSettings(ctx context.Context, settings *domain.ReferralSettings) error
GetReferralSettings(ctx context.Context) (*domain.ReferralSettings, error) GetReferralSettings(ctx context.Context) (*domain.ReferralSettings, error)
GetReferralCountByID(ctx context.Context, referrerID string) (int64, error)
ProcessBetReferral(ctx context.Context, userPhone string, betAmount float64) error ProcessBetReferral(ctx context.Context, userPhone string, betAmount float64) error
} }

View File

@ -54,16 +54,45 @@ func (s *Service) GenerateReferralCode() (string, error) {
func (s *Service) CreateReferral(ctx context.Context, userID int64) error { func (s *Service) CreateReferral(ctx context.Context, userID int64) error {
s.logger.Info("Creating referral code for user", "userID", userID) s.logger.Info("Creating referral code for user", "userID", userID)
// TODO: check in user already has an active referral code
// check if user already has an active referral code
referral, err := s.repo.GetActiveReferralByReferrerID(ctx, fmt.Sprintf("%d", userID))
if err != nil {
s.logger.Error("Failed to check if user alredy has active referral code", "error", err)
return err
}
if referral != nil && referral.Status == domain.ReferralPending && referral.ExpiresAt.After(time.Now()) {
s.logger.Error("user already has an active referral code", "error", err)
return err
}
settings, err := s.GetReferralSettings(ctx)
if err != nil || settings == nil {
s.logger.Error("Failed to fetch referral settings", "error", err)
return err
}
// check referral count limit
referralCount, err := s.GetReferralCountByID(ctx, fmt.Sprintf("%d", userID))
if err != nil {
s.logger.Error("Failed to get referral count", "userID", userID, "error", err)
return err
}
fmt.Println("referralCount: ", referralCount)
if referralCount == int64(settings.MaxReferrals) {
s.logger.Error("referral count limit has been reached", "referralCount", referralCount, "error", err)
return err
}
code, err := s.GenerateReferralCode() code, err := s.GenerateReferralCode()
if err != nil { if err != nil {
s.logger.Error("Failed to generate referral code", "error", err) s.logger.Error("Failed to generate referral code", "error", err)
return err return err
} }
// TODO: get the referral settings from db var rewardAmount float64 = settings.ReferralRewardAmount
var rewardAmount float64 = 100 var expireDuration time.Time = time.Now().Add(time.Duration((24 * settings.ExpiresAfterDays)) * time.Hour)
var expireDuration time.Time = time.Now().Add(24 * time.Hour)
if err := s.repo.CreateReferral(ctx, &domain.Referral{ if err := s.repo.CreateReferral(ctx, &domain.Referral{
ReferralCode: code, ReferralCode: code,
@ -242,6 +271,26 @@ func (s *Service) GetReferralStats(ctx context.Context, userPhone string) (*doma
return stats, nil return stats, nil
} }
func (s *Service) CreateReferralSettings(ctx context.Context, req domain.ReferralSettingsReq) error {
s.logger.Info("Creating referral setting")
if err := s.repo.CreateSettings(ctx, &domain.ReferralSettings{
ReferralRewardAmount: req.ReferralRewardAmount,
CashbackPercentage: req.CashbackPercentage,
MaxReferrals: req.MaxReferrals,
ExpiresAfterDays: req.ExpiresAfterDays,
UpdatedBy: req.UpdatedBy,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}); err != nil {
s.logger.Error("Failed to create referral setting", "error", err)
return err
}
s.logger.Info("Referral setting created succesfully")
return nil
}
func (s *Service) UpdateReferralSettings(ctx context.Context, settings *domain.ReferralSettings) error { func (s *Service) UpdateReferralSettings(ctx context.Context, settings *domain.ReferralSettings) error {
s.logger.Info("Updating referral settings", "settingsID", settings.ID) s.logger.Info("Updating referral settings", "settingsID", settings.ID)
@ -265,6 +314,16 @@ func (s *Service) GetReferralSettings(ctx context.Context) (*domain.ReferralSett
return nil, err return nil, err
} }
s.logger.Info("Referral settings retrieved successfully", "settingsID", settings.ID) s.logger.Info("Referral settings retrieved successfully", "settings", settings)
return settings, nil return settings, nil
} }
func (s *Service) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) {
count, err := s.repo.GetReferralCountByID(ctx, referrerID)
if err != nil {
s.logger.Error("Failed to get referral count", "userID", referrerID, "error", err)
return 0, err
}
return count, nil
}

View File

@ -8,6 +8,7 @@ import (
func (h *Handler) CreateBonusMultiplier(c *fiber.Ctx) error { func (h *Handler) CreateBonusMultiplier(c *fiber.Ctx) error {
var req struct { var req struct {
Multiplier float32 `json:"multiplier"` Multiplier float32 `json:"multiplier"`
BalanceCap int64 `json:"balance_cap"`
} }
if err := c.BodyParser(&req); err != nil { if err := c.BodyParser(&req); err != nil {
@ -27,7 +28,7 @@ func (h *Handler) CreateBonusMultiplier(c *fiber.Ctx) error {
return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", err, nil) return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", err, nil)
} }
if err := h.bonusSvc.CreateBonusMultiplier(c.Context(), req.Multiplier); err != nil { if err := h.bonusSvc.CreateBonusMultiplier(c.Context(), req.Multiplier, req.BalanceCap); err != nil {
h.logger.Error("failed to create bonus multiplier", "error", err) h.logger.Error("failed to create bonus multiplier", "error", err)
return response.WriteJSON(c, fiber.StatusInternalServerError, "failed to create bonus mulitplier", nil, nil) return response.WriteJSON(c, fiber.StatusInternalServerError, "failed to create bonus mulitplier", nil, nil)
} }
@ -49,6 +50,7 @@ func (h *Handler) UpdateBonusMultiplier(c *fiber.Ctx) error {
var req struct { var req struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Multiplier float32 `json:"multiplier"` Multiplier float32 `json:"multiplier"`
BalanceCap int64 `json:"balance_cap"`
} }
if err := c.BodyParser(&req); err != nil { if err := c.BodyParser(&req); err != nil {
@ -56,7 +58,7 @@ func (h *Handler) UpdateBonusMultiplier(c *fiber.Ctx) error {
return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", err, nil) return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", err, nil)
} }
if err := h.bonusSvc.UpdateBonusMultiplier(c.Context(), req.ID, req.Multiplier); err != nil { if err := h.bonusSvc.UpdateBonusMultiplier(c.Context(), req.ID, req.Multiplier, req.BalanceCap); err != nil {
h.logger.Error("failed to update bonus multiplier", "error", err) h.logger.Error("failed to update bonus multiplier", "error", err)
return response.WriteJSON(c, fiber.StatusInternalServerError, "failed to update bonus mulitplier", nil, nil) return response.WriteJSON(c, fiber.StatusInternalServerError, "failed to update bonus mulitplier", nil, nil)
} }

View File

@ -2,6 +2,7 @@ package handlers
import ( import (
"fmt" "fmt"
"math"
"github.com/SamuelTariku/FortuneBet-Backend/internal/domain" "github.com/SamuelTariku/FortuneBet-Backend/internal/domain"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -66,7 +67,15 @@ func (h *Handler) InitiateDeposit(c *fiber.Ctx) error {
multiplier = bonusMultiplier[0].Multiplier multiplier = bonusMultiplier[0].Multiplier
} }
_, err = h.walletSvc.AddToWallet(c.Context(), wallet.StaticID, domain.ToCurrency(float32(amount)*multiplier), domain.ValidInt64{}, domain.TRANSFER_DIRECT, domain.PaymentDetails{}) var balanceCap int64 = 0
bonusBalanceCap, err := h.bonusSvc.GetBonusBalanceCap(c.Context())
if err == nil {
balanceCap = bonusBalanceCap[0].BalanceCap
}
capedBalanceAmount := domain.Currency((math.Min(req.Amount, float64(balanceCap)) * float64(multiplier)) * 100)
_, err = h.walletSvc.AddToWallet(c.Context(), wallet.StaticID, capedBalanceAmount, domain.ValidInt64{}, domain.TRANSFER_DIRECT, domain.PaymentDetails{})
if err != nil { if err != nil {
h.logger.Error("Failed to add bonus to static wallet", "walletID", wallet.StaticID, "user id", userID, "error", err) h.logger.Error("Failed to add bonus to static wallet", "walletID", wallet.StaticID, "user id", userID, "error", err)
return err return err

View File

@ -21,6 +21,38 @@ func (h *Handler) CreateReferralCode(c *fiber.Ctx) error {
return response.WriteJSON(c, fiber.StatusOK, "Referral created successfully", nil, nil) return response.WriteJSON(c, fiber.StatusOK, "Referral created successfully", nil, nil)
} }
func (h *Handler) CreateReferralSettings(c *fiber.Ctx) error {
var req domain.ReferralSettingsReq
if err := c.BodyParser(&req); err != nil {
h.logger.Error("Failed to parse settings", "error", err)
return fiber.NewError(fiber.StatusBadRequest, "Invalid request body")
}
if valErrs, ok := h.validator.Validate(c, req); !ok {
return response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", valErrs, nil)
}
settings, err := h.referralSvc.GetReferralSettings(c.Context())
if err != nil {
h.logger.Error("Failed to fetch previous referral setting", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to create referral")
}
// only allow one referral setting for now
// for future it can be multiple and be able to choose from them
if settings != nil {
h.logger.Error("referral setting already exists", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "referral setting already exists")
}
if err := h.referralSvc.CreateReferralSettings(c.Context(), req); err != nil {
h.logger.Error("Failed to create referral setting", "error", err)
return fiber.NewError(fiber.StatusInternalServerError, "Failed to create referral")
}
return response.WriteJSON(c, fiber.StatusOK, "Referral created successfully", nil, nil)
}
// GetReferralStats godoc // GetReferralStats godoc
// @Summary Get referral statistics // @Summary Get referral statistics
// @Description Retrieves referral statistics for the authenticated user // @Description Retrieves referral statistics for the authenticated user
@ -112,11 +144,12 @@ func (h *Handler) UpdateReferralSettings(c *fiber.Ctx) error {
// @Security Bearer // @Security Bearer
// @Router /referral/settings [get] // @Router /referral/settings [get]
func (h *Handler) GetReferralSettings(c *fiber.Ctx) error { func (h *Handler) GetReferralSettings(c *fiber.Ctx) error {
userID, ok := c.Locals("user_id").(int64) // userID, ok := c.Locals("user_id").(int64)
if !ok || userID == 0 { // if !ok || userID == 0 {
h.logger.Error("Invalid user ID in context") // h.logger.Error("Invalid user ID in context")
return fiber.NewError(fiber.StatusUnauthorized, "Invalid user identification") // return fiber.NewError(fiber.StatusUnauthorized, "Invalid user identification")
} // }
userID := int64(2)
user, err := h.userSvc.GetUserByID(c.Context(), userID) user, err := h.userSvc.GetUserByID(c.Context(), userID)
if err != nil { if err != nil {

View File

@ -109,7 +109,8 @@ func (a *App) initAppRoutes() {
// Referral Routes // Referral Routes
a.fiber.Post("/referral/create", a.authMiddleware, h.CreateReferralCode) a.fiber.Post("/referral/create", a.authMiddleware, h.CreateReferralCode)
a.fiber.Get("/referral/stats", a.authMiddleware, h.GetReferralStats) a.fiber.Get("/referral/stats", a.authMiddleware, h.GetReferralStats)
a.fiber.Get("/referral/settings", h.GetReferralSettings) a.fiber.Post("/referral/settings", a.authMiddleware, h.CreateReferralSettings)
a.fiber.Get("/referral/settings", a.authMiddleware, h.GetReferralSettings)
a.fiber.Patch("/referral/settings", a.authMiddleware, h.UpdateReferralSettings) a.fiber.Patch("/referral/settings", a.authMiddleware, h.UpdateReferralSettings)
// Bonus Routes // Bonus Routes