diff --git a/db/query/referal.sql b/db/query/referal.sql index dd868ca..206606e 100644 --- a/db/query/referal.sql +++ b/db/query/referal.sql @@ -69,3 +69,9 @@ INSERT INTO referral_settings ( -- name: GetReferralByReferredID :one SELECT * FROM referrals WHERE referred_id = $1 LIMIT 1; + +-- name: GetActiveReferralByReferrerID :one +SELECT * FROM referrals WHERE referrer_id = $1 AND status = 'PENDING' LIMIT 1; + +-- name: GetReferralCountByID :one +SELECT count(*) FROM referrals WHERE referrer_id = $1; \ No newline at end of file diff --git a/gen/db/referal.sql.go b/gen/db/referal.sql.go index 9784440..b5ceeed 100644 --- a/gen/db/referal.sql.go +++ b/gen/db/referal.sql.go @@ -102,6 +102,28 @@ func (q *Queries) CreateReferralSettings(ctx context.Context, arg CreateReferral return i, err } +const GetActiveReferralByReferrerID = `-- name: GetActiveReferralByReferrerID :one +SELECT id, referral_code, referrer_id, referred_id, status, reward_amount, cashback_amount, created_at, updated_at, expires_at FROM referrals WHERE referrer_id = $1 AND status = 'PENDING' LIMIT 1 +` + +func (q *Queries) GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (Referral, error) { + row := q.db.QueryRow(ctx, GetActiveReferralByReferrerID, referrerID) + var i Referral + err := row.Scan( + &i.ID, + &i.ReferralCode, + &i.ReferrerID, + &i.ReferredID, + &i.Status, + &i.RewardAmount, + &i.CashbackAmount, + &i.CreatedAt, + &i.UpdatedAt, + &i.ExpiresAt, + ) + return i, err +} + const GetReferralByCode = `-- name: GetReferralByCode :one SELECT id, referral_code, referrer_id, referred_id, status, reward_amount, cashback_amount, created_at, updated_at, expires_at FROM referrals WHERE referral_code = $1 @@ -147,6 +169,17 @@ func (q *Queries) GetReferralByReferredID(ctx context.Context, referredID pgtype return i, err } +const GetReferralCountByID = `-- name: GetReferralCountByID :one +SELECT count(*) FROM referrals WHERE referrer_id = $1 +` + +func (q *Queries) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) { + row := q.db.QueryRow(ctx, GetReferralCountByID, referrerID) + var count int64 + err := row.Scan(&count) + return count, err +} + const GetReferralSettings = `-- name: GetReferralSettings :one SELECT id, referral_reward_amount, cashback_percentage, bet_referral_bonus_percentage, max_referrals, expires_after_days, updated_by, created_at, updated_at, version FROM referral_settings LIMIT 1 diff --git a/internal/repository/referal.go b/internal/repository/referal.go index 105ce5d..d214c54 100644 --- a/internal/repository/referal.go +++ b/internal/repository/referal.go @@ -21,6 +21,8 @@ type ReferralRepository interface { UpdateSettings(ctx context.Context, settings *domain.ReferralSettings) error CreateSettings(ctx context.Context, settings *domain.ReferralSettings) error GetReferralByReferredID(ctx context.Context, referredID string) (*domain.Referral, error) // New method + GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) + GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (*domain.Referral, error) UpdateUserReferalCode(ctx context.Context, codedata domain.UpdateUserReferalCode) error } @@ -184,6 +186,30 @@ func (r *ReferralRepo) GetReferralByReferredID(ctx context.Context, referredID s return r.mapToDomainReferral(&dbReferral), nil } +func (r *ReferralRepo) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) { + count, err := r.store.queries.GetReferralCountByID(ctx, referrerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + + return count, nil +} + +func (r *ReferralRepo) GetActiveReferralByReferrerID(ctx context.Context, referrerID string) (*domain.Referral, error) { + referral, err := r.store.queries.GetActiveReferralByReferrerID(ctx, referrerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return &domain.Referral{}, nil + } + return &domain.Referral{}, err + } + + return r.mapToDomainReferral(&referral), nil +} + func (r *ReferralRepo) mapToDomainReferral(dbRef *dbgen.Referral) *domain.Referral { var referredID *string if dbRef.ReferredID.Valid { diff --git a/internal/services/referal/port.go b/internal/services/referal/port.go index 703f986..1946e99 100644 --- a/internal/services/referal/port.go +++ b/internal/services/referal/port.go @@ -15,5 +15,6 @@ type ReferralStore interface { CreateReferralSettings(ctx context.Context, req domain.ReferralSettingsReq) error UpdateReferralSettings(ctx context.Context, settings *domain.ReferralSettings) error GetReferralSettings(ctx context.Context) (*domain.ReferralSettings, error) + GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) ProcessBetReferral(ctx context.Context, userPhone string, betAmount float64) error } diff --git a/internal/services/referal/service.go b/internal/services/referal/service.go index 0a812f0..aaa7af3 100644 --- a/internal/services/referal/service.go +++ b/internal/services/referal/service.go @@ -56,7 +56,7 @@ func (s *Service) CreateReferral(ctx context.Context, userID int64) error { s.logger.Info("Creating referral code for user", "userID", userID) // check if user already has an active referral code - referral, err := s.repo.GetReferralByReferredID(ctx, fmt.Sprintf("%d", userID)) + referral, err := s.repo.GetActiveReferralByReferrerID(ctx, fmt.Sprintf("%d", userID)) if err != nil { s.logger.Error("Failed to check if user alredy has active referral code", "error", err) return err @@ -66,18 +66,31 @@ func (s *Service) CreateReferral(ctx context.Context, userID int64) error { return err } - code, err := s.GenerateReferralCode() - if err != nil { - s.logger.Error("Failed to generate referral code", "error", err) - return err - } - settings, err := s.GetReferralSettings(ctx) if err != nil || settings == nil { s.logger.Error("Failed to fetch referral settings", "error", err) return err } + // check referral count limit + referralCount, err := s.GetReferralCountByID(ctx, fmt.Sprintf("%d", userID)) + if err != nil { + s.logger.Error("Failed to get referral count", "userID", userID, "error", err) + return err + } + + fmt.Println("referralCount: ", referralCount) + if referralCount == int64(settings.MaxReferrals) { + s.logger.Error("referral count limit has been reached", "referralCount", referralCount, "error", err) + return err + } + + code, err := s.GenerateReferralCode() + if err != nil { + s.logger.Error("Failed to generate referral code", "error", err) + return err + } + var rewardAmount float64 = settings.ReferralRewardAmount var expireDuration time.Time = time.Now().Add(time.Duration((24 * settings.ExpiresAfterDays)) * time.Hour) @@ -304,3 +317,13 @@ func (s *Service) GetReferralSettings(ctx context.Context) (*domain.ReferralSett s.logger.Info("Referral settings retrieved successfully", "settings", settings) return settings, nil } + +func (s *Service) GetReferralCountByID(ctx context.Context, referrerID string) (int64, error) { + count, err := s.repo.GetReferralCountByID(ctx, referrerID) + if err != nil { + s.logger.Error("Failed to get referral count", "userID", referrerID, "error", err) + return 0, err + } + + return count, nil +}