From 41c9c552ae0b88777a7565a1ca190c012989370e Mon Sep 17 00:00:00 2001 From: Asher Samuel Date: Fri, 27 Jun 2025 14:35:04 +0300 Subject: [PATCH] referral bonus --- internal/repository/referal.go | 3 +- internal/services/referal/service.go | 43 +++++++++++++++------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/internal/repository/referal.go b/internal/repository/referal.go index 274acd9..a782cfb 100644 --- a/internal/repository/referal.go +++ b/internal/repository/referal.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "strconv" dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db" "github.com/SamuelTariku/FortuneBet-Backend/internal/domain" @@ -44,7 +45,7 @@ func (r *ReferralRepo) UpdateUserReferalCode(ctx context.Context, codedata domai func (r *ReferralRepo) CreateReferral(ctx context.Context, referral *domain.Referral) error { rewardAmount := pgtype.Numeric{} - if err := rewardAmount.Scan(referral.RewardAmount); err != nil { + if err := rewardAmount.Scan(strconv.Itoa(int(referral.RewardAmount))); err != nil { return err } diff --git a/internal/services/referal/service.go b/internal/services/referal/service.go index bbb0d43..5585d74 100644 --- a/internal/services/referal/service.go +++ b/internal/services/referal/service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base32" "errors" + "fmt" "log/slog" "strconv" "time" @@ -53,15 +54,23 @@ func (s *Service) GenerateReferralCode() (string, error) { func (s *Service) CreateReferral(ctx context.Context, userID int64) error { s.logger.Info("Creating referral code for user", "userID", userID) + // TODO: check in user already has an active referral code code, err := s.GenerateReferralCode() if err != nil { s.logger.Error("Failed to generate referral code", "error", err) return err } - if err := s.repo.UpdateUserReferalCode(ctx, domain.UpdateUserReferalCode{ - UserID: userID, - Code: code, + // TODO: get the referral settings from db + var rewardAmount float64 = 100 + var expireDuration time.Time = time.Now().Add(24 * time.Hour) + + if err := s.repo.CreateReferral(ctx, &domain.Referral{ + ReferralCode: code, + ReferrerID: fmt.Sprintf("%d", userID), + Status: domain.ReferralPending, + RewardAmount: rewardAmount, + ExpiresAt: expireDuration, }); err != nil { return err } @@ -73,12 +82,12 @@ func (s *Service) ProcessReferral(ctx context.Context, referredPhone, referralCo s.logger.Info("Processing referral", "referredPhone", referredPhone, "referralCode", referralCode) referral, err := s.repo.GetReferralByCode(ctx, referralCode) - if err != nil { + if err != nil || referral == nil { s.logger.Error("Failed to get referral by code", "referralCode", referralCode, "error", err) return err } - if referral == nil || referral.Status != domain.ReferralPending || referral.ExpiresAt.Before(time.Now()) { + if referral.Status != domain.ReferralPending || referral.ExpiresAt.Before(time.Now()) { s.logger.Warn("Invalid or expired referral", "referralCode", referralCode, "status", referral.Status) return ErrInvalidReferral } @@ -106,27 +115,21 @@ func (s *Service) ProcessReferral(ctx context.Context, referredPhone, referralCo return err } - referrerID, err := strconv.ParseInt(referral.ReferrerID, 10, 64) + referrerId, err := strconv.Atoi(referral.ReferrerID) if err != nil { - s.logger.Error("Invalid referrer phone number format", "referrerID", referral.ReferrerID, "error", err) - return errors.New("invalid referrer phone number format") - } - - wallets, err := s.walletSvc.GetWalletsByUser(ctx, referrerID) - if err != nil { - s.logger.Error("Failed to get wallets for referrer", "referrerID", referrerID, "error", err) + s.logger.Error("Failed to convert referrer id", "referrerId", referral.ReferrerID, "error", err) return err } - if len(wallets) == 0 { - s.logger.Error("Referrer has no wallet", "referrerID", referrerID) - return errors.New("referrer has no wallet") + + wallets, err := s.store.GetCustomerWallet(ctx, int64(referrerId)) + if err != nil { + s.logger.Error("Failed to get referrer wallets", "referrerId", referral.ReferrerID, "error", err) + return err } - walletID := wallets[0].ID - currentBonus := float64(wallets[0].Balance) - _, err = s.walletSvc.AddToWallet(ctx, walletID, domain.ToCurrency(float32(currentBonus+referral.RewardAmount)), domain.ValidInt64{}, domain.TRANSFER_DIRECT, domain.PaymentDetails{}) + _, err = s.walletSvc.AddToWallet(ctx, wallets.StaticID, domain.ToCurrency(float32(referral.RewardAmount)), domain.ValidInt64{}, domain.TRANSFER_DIRECT, domain.PaymentDetails{}) if err != nil { - s.logger.Error("Failed to add referral reward to wallet", "walletID", walletID, "referrerID", referrerID, "error", err) + s.logger.Error("Failed to add referral reward to static wallet", "walletID", wallets.StaticID, "referrer phone number", referredPhone, "error", err) return err }