Yimaru-BackEnd/internal/repository/user.go

780 lines
21 KiB
Go

package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
dbgen "github.com/SamuelTariku/FortuneBet-Backend/gen/db"
"github.com/SamuelTariku/FortuneBet-Backend/internal/domain"
"github.com/jackc/pgx/v5/pgtype"
)
// CreateUser registers a new user after consuming the OTP that authorised the
// registration.
//
// The OTP identified by usedOtpId is marked as used first; only then is the
// user row inserted. The two statements are issued separately through
// s.queries, so if the insert fails the OTP stays consumed.
// NOTE(review): consider running both statements in one transaction.
//
// Empty Email / PhoneNumber values are stored as NULL. Suspended and CompanyID
// from user are not passed to the insert, and is_company is currently unused.
// The returned value mirrors the inserted row but never carries the password.
func (s *Store) CreateUser(ctx context.Context, user domain.User, usedOtpId int64, is_company bool) (domain.User, error) {
	now := time.Now()

	err := s.queries.MarkOtpAsUsed(ctx, dbgen.MarkOtpAsUsedParams{
		ID:     usedOtpId,
		UsedAt: pgtype.Timestamptz{Time: now, Valid: true},
	})
	if err != nil {
		return domain.User{}, err
	}

	// A single timestamp keeps created_at and updated_at identical for a new row.
	stamp := pgtype.Timestamptz{Time: now, Valid: true}
	userRes, err := s.queries.CreateUser(ctx, dbgen.CreateUserParams{
		FirstName: user.FirstName,
		LastName:  user.LastName,
		Email: pgtype.Text{
			String: user.Email,
			Valid:  user.Email != "",
		},
		PhoneNumber: pgtype.Text{
			String: user.PhoneNumber,
			Valid:  user.PhoneNumber != "",
		},
		Password:      user.Password,
		Role:          string(user.Role),
		EmailVerified: user.EmailVerified,
		PhoneVerified: user.PhoneVerified,
		CreatedAt:     stamp,
		UpdatedAt:     stamp,
	})
	if err != nil {
		return domain.User{}, err
	}

	return domain.User{
		ID:            userRes.ID,
		FirstName:     userRes.FirstName,
		LastName:      userRes.LastName,
		Email:         userRes.Email.String,
		PhoneNumber:   userRes.PhoneNumber.String,
		Role:          domain.Role(userRes.Role),
		EmailVerified: userRes.EmailVerified,
		PhoneVerified: userRes.PhoneVerified,
		CreatedAt:     userRes.CreatedAt.Time,
		UpdatedAt:     userRes.UpdatedAt.Time,
		Suspended:     userRes.Suspended,
	}, nil
}
// GetUserByID loads the user with the given id, including the stored password
// and the (optional) company association.
//
// A sql.ErrNoRows result is translated into domain.ErrUserNotFound; any other
// query error is returned as-is.
// NOTE(review): with the pgx driver, whether errors.Is(err, sql.ErrNoRows)
// matches depends on the pgx version — confirm the not-found path is reached.
func (s *Store) GetUserByID(ctx context.Context, id int64) (domain.User, error) {
	row, err := s.queries.GetUserByID(ctx, id)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return domain.User{}, domain.ErrUserNotFound
	case err != nil:
		return domain.User{}, err
	}

	company := domain.ValidInt64{
		Value: row.CompanyID.Int64,
		Valid: row.CompanyID.Valid,
	}

	var out domain.User
	out.ID = row.ID
	out.FirstName = row.FirstName
	out.LastName = row.LastName
	out.Email = row.Email.String
	out.PhoneNumber = row.PhoneNumber.String
	out.Role = domain.Role(row.Role)
	out.EmailVerified = row.EmailVerified
	out.Password = row.Password
	out.PhoneVerified = row.PhoneVerified
	out.CreatedAt = row.CreatedAt.Time
	out.UpdatedAt = row.UpdatedAt.Time
	out.SuspendedAt = row.SuspendedAt.Time
	out.Suspended = row.Suspended
	out.CompanyID = company
	return out, nil
}
// GetAllUsers returns one page of users matching filter together with a total
// count.
//
// Role, CompanyID, PageSize, Page, Query, CreatedBefore and CreatedAfter are
// all passed to the GetAllUsers query. The total comes from GetTotalUsers,
// which only receives Role and CompanyID, so it does not reflect the
// Query/CreatedBefore/CreatedAfter filters.
// NOTE(review): filter.Page is forwarded directly as the SQL offset — confirm
// whether callers already convert page numbers into row offsets.
//
// A failure of either query is returned; previously a failing count was
// silently reported as zero.
func (s *Store) GetAllUsers(ctx context.Context, filter domain.UserFilter) ([]domain.User, int64, error) {
	companyID := pgtype.Int8{
		Int64: filter.CompanyID.Value,
		Valid: filter.CompanyID.Valid,
	}

	users, err := s.queries.GetAllUsers(ctx, dbgen.GetAllUsersParams{
		Role:      filter.Role,
		CompanyID: companyID,
		Limit: pgtype.Int4{
			Int32: int32(filter.PageSize.Value),
			Valid: filter.PageSize.Valid,
		},
		Offset: pgtype.Int4{
			Int32: int32(filter.Page.Value),
			Valid: filter.Page.Valid,
		},
		Query: pgtype.Text{
			String: filter.Query.Value,
			Valid:  filter.Query.Valid,
		},
		CreatedBefore: pgtype.Timestamptz{
			Time:  filter.CreatedBefore.Value,
			Valid: filter.CreatedBefore.Valid,
		},
		CreatedAfter: pgtype.Timestamptz{
			Time:  filter.CreatedAfter.Value,
			Valid: filter.CreatedAfter.Valid,
		},
	})
	if err != nil {
		return nil, 0, err
	}

	userList := make([]domain.User, len(users))
	for i, user := range users {
		userList[i] = domain.User{
			ID:            user.ID,
			FirstName:     user.FirstName,
			LastName:      user.LastName,
			Email:         user.Email.String,
			EmailVerified: user.EmailVerified,
			PhoneNumber:   user.PhoneNumber.String,
			Role:          domain.Role(user.Role),
			PhoneVerified: user.PhoneVerified,
			CreatedAt:     user.CreatedAt.Time,
			UpdatedAt:     user.UpdatedAt.Time,
			SuspendedAt:   user.SuspendedAt.Time,
			Suspended:     user.Suspended,
			CompanyID: domain.ValidInt64{
				Value: user.CompanyID.Int64,
				Valid: user.CompanyID.Valid,
			},
		}
	}

	totalCount, err := s.queries.GetTotalUsers(ctx, dbgen.GetTotalUsersParams{
		Role:      filter.Role,
		CompanyID: companyID,
	})
	if err != nil {
		return nil, 0, fmt.Errorf("counting users: %w", err)
	}
	return userList, totalCount, nil
}
// GetAllCashiers lists cashiers together with their branch details
// (id, name, wallet, location).
//
// Only filter.Query, CreatedBefore and CreatedAfter are passed to the list
// query; Role, CompanyID, Page and PageSize are not applied here. The total is
// the GetTotalUsers count for the cashier role, which does not take the
// query/date filters into account.
//
// A failure of either query is returned; previously a failing count was
// silently reported as zero.
func (s *Store) GetAllCashiers(ctx context.Context, filter domain.UserFilter) ([]domain.GetCashier, int64, error) {
	users, err := s.queries.GetAllCashiers(ctx, dbgen.GetAllCashiersParams{
		Query: pgtype.Text{
			String: filter.Query.Value,
			Valid:  filter.Query.Valid,
		},
		CreatedBefore: pgtype.Timestamptz{
			Time:  filter.CreatedBefore.Value,
			Valid: filter.CreatedBefore.Valid,
		},
		CreatedAfter: pgtype.Timestamptz{
			Time:  filter.CreatedAfter.Value,
			Valid: filter.CreatedAfter.Valid,
		},
	})
	if err != nil {
		return nil, 0, err
	}

	userList := make([]domain.GetCashier, len(users))
	for i, user := range users {
		userList[i] = domain.GetCashier{
			ID:             user.ID,
			FirstName:      user.FirstName,
			LastName:       user.LastName,
			Email:          user.Email.String,
			PhoneNumber:    user.PhoneNumber.String,
			Role:           domain.Role(user.Role),
			EmailVerified:  user.EmailVerified,
			PhoneVerified:  user.PhoneVerified,
			CreatedAt:      user.CreatedAt.Time,
			UpdatedAt:      user.UpdatedAt.Time,
			SuspendedAt:    user.SuspendedAt.Time,
			Suspended:      user.Suspended,
			BranchID:       user.BranchID,
			BranchName:     user.BranchName,
			BranchWallet:   user.BranchWallet,
			BranchLocation: user.BranchLocation,
		}
	}

	totalCount, err := s.queries.GetTotalUsers(ctx, dbgen.GetTotalUsersParams{
		Role: string(domain.RoleCashier),
	})
	if err != nil {
		return nil, 0, fmt.Errorf("counting cashiers: %w", err)
	}
	return userList, totalCount, nil
}
// GetCashierByID fetches one cashier together with the branch fields returned
// by the GetCashierByID query (id, name, wallet, location).
// Any query error, including a missing row, is returned unchanged.
func (s *Store) GetCashierByID(ctx context.Context, cashierID int64) (domain.GetCashier, error) {
	c, err := s.queries.GetCashierByID(ctx, cashierID)
	if err != nil {
		return domain.GetCashier{}, err
	}

	var out domain.GetCashier
	out.ID = c.ID
	out.FirstName = c.FirstName
	out.LastName = c.LastName
	out.Email = c.Email.String
	out.PhoneNumber = c.PhoneNumber.String
	out.Role = domain.Role(c.Role)
	out.EmailVerified = c.EmailVerified
	out.PhoneVerified = c.PhoneVerified
	out.CreatedAt = c.CreatedAt.Time
	out.UpdatedAt = c.UpdatedAt.Time
	out.SuspendedAt = c.SuspendedAt.Time
	out.Suspended = c.Suspended
	out.BranchID = c.BranchID
	out.BranchName = c.BranchName
	out.BranchWallet = c.BranchWallet
	out.BranchLocation = c.BranchLocation
	return out, nil
}
// GetCashiersByBranch returns the users the GetCashiersByBranch query links to
// branchID. Password and company fields are left unset in the results.
// On success the slice is non-nil, even when empty.
func (s *Store) GetCashiersByBranch(ctx context.Context, branchID int64) ([]domain.User, error) {
	rows, err := s.queries.GetCashiersByBranch(ctx, branchID)
	if err != nil {
		return nil, err
	}

	cashiers := make([]domain.User, 0, len(rows))
	for _, r := range rows {
		cashiers = append(cashiers, domain.User{
			ID:            r.ID,
			FirstName:     r.FirstName,
			LastName:      r.LastName,
			Email:         r.Email.String,
			PhoneNumber:   r.PhoneNumber.String,
			Role:          domain.Role(r.Role),
			EmailVerified: r.EmailVerified,
			PhoneVerified: r.PhoneVerified,
			CreatedAt:     r.CreatedAt.Time,
			UpdatedAt:     r.UpdatedAt.Time,
			SuspendedAt:   r.SuspendedAt.Time,
			Suspended:     r.Suspended,
		})
	}
	return cashiers, nil
}
// SearchUserByNameOrPhone runs the SearchUserByNameOrPhone query with
// searchString. The optional role filter is applied only when role is
// non-nil, and the company filter only when companyID is valid. The exact
// matching rule lives in the SQL query, not here.
//
// Password and company fields are left unset in the results.
func (s *Store) SearchUserByNameOrPhone(ctx context.Context, searchString string, role *domain.Role, companyID domain.ValidInt64) ([]domain.User, error) {
	// A zero pgtype.Text (Valid == false) means "no role filter".
	var roleFilter pgtype.Text
	if role != nil {
		roleFilter = pgtype.Text{String: string(*role), Valid: true}
	}

	rows, err := s.queries.SearchUserByNameOrPhone(ctx, dbgen.SearchUserByNameOrPhoneParams{
		Column1:   pgtype.Text{String: searchString, Valid: true},
		CompanyID: pgtype.Int8{Int64: companyID.Value, Valid: companyID.Valid},
		Role:      roleFilter,
	})
	if err != nil {
		return nil, err
	}

	found := make([]domain.User, len(rows))
	for i := range rows {
		r := rows[i]
		found[i] = domain.User{
			ID:            r.ID,
			FirstName:     r.FirstName,
			LastName:      r.LastName,
			Email:         r.Email.String,
			PhoneNumber:   r.PhoneNumber.String,
			Role:          domain.Role(r.Role),
			EmailVerified: r.EmailVerified,
			PhoneVerified: r.PhoneVerified,
			CreatedAt:     r.CreatedAt.Time,
			UpdatedAt:     r.UpdatedAt.Time,
			Suspended:     r.Suspended,
			SuspendedAt:   r.SuspendedAt.Time,
		}
	}
	return found, nil
}
// UpdateUser passes the first name, last name and suspended flag from user to
// the UpdateUser query for the row identified by user.UserId.
//
// NOTE(review): only the .Value of each request field is forwarded. If those
// wrapper types also carry a "was provided" flag, it is ignored here, so an
// omitted field would be sent as "" / false. Confirm this against the
// UpdateUser SQL and the callers.
func (s *Store) UpdateUser(ctx context.Context, user domain.UpdateUserReq) error {
	// The repository no longer prints the request to stdout; logging belongs
	// to the caller's structured logger.
	return s.queries.UpdateUser(ctx, dbgen.UpdateUserParams{
		ID:        user.UserId,
		FirstName: user.FirstName.Value,
		LastName:  user.LastName.Value,
		Suspended: user.Suspended.Value,
	})
}
// UpdateUserCompany associates user id with companyID. The company id is always
// written as a non-NULL value.
func (s *Store) UpdateUserCompany(ctx context.Context, id int64, companyID int64) error {
	params := dbgen.UpdateUserCompanyParams{ID: id}
	params.CompanyID = pgtype.Int8{Int64: companyID, Valid: true}
	return s.queries.UpdateUserCompany(ctx, params)
}
// UpdateUserSuspend sets the suspended flag of user id to status and stamps
// suspended_at with the current time. The timestamp is written both when
// suspending (status == true) and when lifting a suspension (status == false).
func (s *Store) UpdateUserSuspend(ctx context.Context, id int64, status bool) error {
	params := dbgen.SuspendUserParams{ID: id, Suspended: status}
	params.SuspendedAt = pgtype.Timestamptz{Time: time.Now(), Valid: true}
	return s.queries.SuspendUser(ctx, params)
}
// DeleteUser removes the user with the given id via the DeleteUser query and
// returns its error unchanged.
func (s *Store) DeleteUser(ctx context.Context, id int64) error {
	return s.queries.DeleteUser(ctx, id)
}
// CheckPhoneEmailExist reports whether email and phoneNum are already in use.
//
// Note the result order: (emailExists, phoneExists, err), which is the reverse
// of the parameter order. An empty phoneNum or email is sent as NULL.
func (s *Store) CheckPhoneEmailExist(ctx context.Context, phoneNum, email string) (bool, bool, error) {
	params := dbgen.CheckPhoneEmailExistParams{
		PhoneNumber: pgtype.Text{String: phoneNum, Valid: len(phoneNum) > 0},
		Email:       pgtype.Text{String: email, Valid: len(email) > 0},
	}

	result, err := s.queries.CheckPhoneEmailExist(ctx, params)
	if err != nil {
		return false, false, err
	}

	emailExists, phoneExists := result.EmailExists, result.PhoneExists
	return emailExists, phoneExists, nil
}
// GetUserByEmail loads a user by email address. A sql.ErrNoRows result is
// mapped to domain.ErrUserNotFound; other errors are returned unchanged.
// Password and company fields are not populated in the returned value.
func (s *Store) GetUserByEmail(ctx context.Context, email string) (domain.User, error) {
	lookup := pgtype.Text{String: email, Valid: true}

	u, err := s.queries.GetUserByEmail(ctx, lookup)
	if errors.Is(err, sql.ErrNoRows) {
		return domain.User{}, domain.ErrUserNotFound
	}
	if err != nil {
		return domain.User{}, err
	}

	var out domain.User
	out.ID = u.ID
	out.FirstName = u.FirstName
	out.LastName = u.LastName
	out.Email = u.Email.String
	out.PhoneNumber = u.PhoneNumber.String
	out.Role = domain.Role(u.Role)
	out.EmailVerified = u.EmailVerified
	out.PhoneVerified = u.PhoneVerified
	out.CreatedAt = u.CreatedAt.Time
	out.UpdatedAt = u.UpdatedAt.Time
	out.Suspended = u.Suspended
	out.SuspendedAt = u.SuspendedAt.Time
	return out, nil
}
// GetUserByPhone loads a user by phone number. A sql.ErrNoRows result is
// mapped to domain.ErrUserNotFound; other errors are returned unchanged.
// Password and company fields are not populated in the returned value.
func (s *Store) GetUserByPhone(ctx context.Context, phoneNum string) (domain.User, error) {
	row, err := s.queries.GetUserByPhone(ctx, pgtype.Text{String: phoneNum, Valid: true})
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return domain.User{}, domain.ErrUserNotFound
	case err != nil:
		return domain.User{}, err
	}

	result := domain.User{
		ID:            row.ID,
		FirstName:     row.FirstName,
		LastName:      row.LastName,
		Email:         row.Email.String,
		PhoneNumber:   row.PhoneNumber.String,
		Role:          domain.Role(row.Role),
		EmailVerified: row.EmailVerified,
		PhoneVerified: row.PhoneVerified,
		CreatedAt:     row.CreatedAt.Time,
		UpdatedAt:     row.UpdatedAt.Time,
		Suspended:     row.Suspended,
		SuspendedAt:   row.SuspendedAt.Time,
	}
	return result, nil
}
// UpdatePassword consumes the OTP usedOtpId and then stores password for the
// user identified by identifier.
//
// identifier is passed as both the email and the phone-number parameter of the
// UpdatePassword query, so the SQL decides which column it matches. The OTP is
// marked used before the password update and the two statements are issued
// separately, so a failed update still leaves the OTP consumed.
func (s *Store) UpdatePassword(ctx context.Context, identifier string, password []byte, usedOtpId int64) error {
	consumed := dbgen.MarkOtpAsUsedParams{
		ID:     usedOtpId,
		UsedAt: pgtype.Timestamptz{Time: time.Now(), Valid: true},
	}
	if err := s.queries.MarkOtpAsUsed(ctx, consumed); err != nil {
		return err
	}

	target := pgtype.Text{String: identifier, Valid: true}
	return s.queries.UpdatePassword(ctx, dbgen.UpdatePasswordParams{
		Password:    password,
		Email:       target,
		PhoneNumber: target,
	})
}
// CreateUserWithoutOtp inserts user directly, without consuming an OTP.
//
// Empty Email / PhoneNumber values are stored as NULL. user.Suspended and
// user.CompanyID are persisted, and is_company is currently unused. The
// returned value mirrors the inserted row but never carries the password or
// company id.
func (s *Store) CreateUserWithoutOtp(ctx context.Context, user domain.User, is_company bool) (domain.User, error) {
	// A single timestamp keeps created_at and updated_at identical for a new row.
	stamp := pgtype.Timestamptz{Time: time.Now(), Valid: true}

	userRes, err := s.queries.CreateUser(ctx, dbgen.CreateUserParams{
		FirstName: user.FirstName,
		LastName:  user.LastName,
		Email: pgtype.Text{
			String: user.Email,
			Valid:  user.Email != "",
		},
		PhoneNumber: pgtype.Text{
			String: user.PhoneNumber,
			Valid:  user.PhoneNumber != "",
		},
		Password:      user.Password,
		Role:          string(user.Role),
		EmailVerified: user.EmailVerified,
		PhoneVerified: user.PhoneVerified,
		CreatedAt:     stamp,
		UpdatedAt:     stamp,
		Suspended:     user.Suspended,
		CompanyID: pgtype.Int8{
			Int64: user.CompanyID.Value,
			Valid: user.CompanyID.Valid,
		},
	})
	if err != nil {
		return domain.User{}, err
	}

	return domain.User{
		ID:            userRes.ID,
		FirstName:     userRes.FirstName,
		LastName:      userRes.LastName,
		Email:         userRes.Email.String,
		PhoneNumber:   userRes.PhoneNumber.String,
		Role:          domain.Role(userRes.Role),
		EmailVerified: userRes.EmailVerified,
		PhoneVerified: userRes.PhoneVerified,
		CreatedAt:     userRes.CreatedAt.Time,
		UpdatedAt:     userRes.UpdatedAt.Time,
		Suspended:     userRes.Suspended,
	}, nil
}
// GetCustomerCounts returns the total, active (not suspended) and inactive
// (suspended) number of users with role 'customer'. The set can optionally be
// narrowed by company, branch and a creation-time window from filter.
//
// The per-status figures use COUNT(CASE ...) rather than SUM(CASE ...).
// SUM over zero rows yields NULL, which cannot be scanned into an int64, so a
// filter matching no customers used to fail instead of returning zeros.
//
// NOTE(review): the BranchID filter matches customers through the
// branch_cashiers table — confirm customers are recorded there.
func (s *Store) GetCustomerCounts(ctx context.Context, filter domain.ReportFilter) (total, active, inactive int64, err error) {
	query := `SELECT
COUNT(*) as total,
COUNT(CASE WHEN suspended = false THEN 1 END) as active,
COUNT(CASE WHEN suspended = true THEN 1 END) as inactive
FROM users WHERE role = 'customer'`
	args := []interface{}{}

	// Each placeholder is numbered after its value is appended, so $n always
	// refers to the argument that was just added.
	if filter.CompanyID.Valid {
		args = append(args, filter.CompanyID.Value)
		query += fmt.Sprintf(" AND company_id = $%d", len(args))
	}
	if filter.BranchID.Valid {
		args = append(args, filter.BranchID.Value)
		query += fmt.Sprintf(" AND id IN (SELECT user_id FROM branch_cashiers WHERE branch_id = $%d)", len(args))
	}
	if filter.StartTime.Valid {
		args = append(args, filter.StartTime.Value)
		query += fmt.Sprintf(" AND created_at >= $%d", len(args))
	}
	if filter.EndTime.Valid {
		args = append(args, filter.EndTime.Value)
		query += fmt.Sprintf(" AND created_at <= $%d", len(args))
	}

	err = s.conn.QueryRow(ctx, query, args...).Scan(&total, &active, &inactive)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("failed to get customer counts: %w", err)
	}
	return total, active, inactive, nil
}
// GetCustomerDetails returns a map from user id to a domain.CustomerDetail whose
// Name is "<first_name> <last_name>", for every user with role 'customer'.
// The set can optionally be narrowed by company, branch (via branch_cashiers)
// and a creation-time window from filter.
func (s *Store) GetCustomerDetails(ctx context.Context, filter domain.ReportFilter) (map[int64]domain.CustomerDetail, error) {
	query := `SELECT id, first_name, last_name
FROM users WHERE role = 'customer'`
	args := []interface{}{}

	// where appends v and then a condition whose $n placeholder refers to it.
	where := func(format string, v interface{}) {
		args = append(args, v)
		query += fmt.Sprintf(format, len(args))
	}
	if filter.CompanyID.Valid {
		where(" AND company_id = $%d", filter.CompanyID.Value)
	}
	if filter.BranchID.Valid {
		where(" AND id IN (SELECT user_id FROM branch_cashiers WHERE branch_id = $%d)", filter.BranchID.Value)
	}
	if filter.StartTime.Valid {
		where(" AND created_at >= $%d", filter.StartTime.Value)
	}
	if filter.EndTime.Valid {
		where(" AND created_at <= $%d", filter.EndTime.Value)
	}

	rows, err := s.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query customer details: %w", err)
	}
	defer rows.Close()

	details := make(map[int64]domain.CustomerDetail)
	for rows.Next() {
		var (
			id                  int64
			firstName, lastName string
		)
		if scanErr := rows.Scan(&id, &firstName, &lastName); scanErr != nil {
			return nil, fmt.Errorf("failed to scan customer detail: %w", scanErr)
		}
		details[id] = domain.CustomerDetail{Name: firstName + " " + lastName}
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("rows error: %w", iterErr)
	}
	return details, nil
}
// GetBranchCustomerCounts returns, per branch_id in branch_cashiers, the number
// of distinct linked users whose role is 'customer'. The result can optionally
// be narrowed by company (branches of that company), a single branch, and the
// users' creation-time window from filter.
func (s *Store) GetBranchCustomerCounts(ctx context.Context, filter domain.ReportFilter) (map[int64]int64, error) {
	query := `SELECT branch_id, COUNT(DISTINCT user_id)
FROM branch_cashiers
JOIN users ON branch_cashiers.user_id = users.id
WHERE users.role = 'customer'`
	args := []interface{}{}

	// bind appends v and then a condition whose $n placeholder refers to it.
	bind := func(format string, v interface{}) {
		args = append(args, v)
		query += fmt.Sprintf(format, len(args))
	}
	if filter.CompanyID.Valid {
		bind(" AND branch_id IN (SELECT id FROM branches WHERE company_id = $%d)", filter.CompanyID.Value)
	}
	if filter.BranchID.Valid {
		bind(" AND branch_id = $%d", filter.BranchID.Value)
	}
	if filter.StartTime.Valid {
		bind(" AND users.created_at >= $%d", filter.StartTime.Value)
	}
	if filter.EndTime.Valid {
		bind(" AND users.created_at <= $%d", filter.EndTime.Value)
	}
	query += " GROUP BY branch_id"

	rows, err := s.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query branch customer counts: %w", err)
	}
	defer rows.Close()

	counts := make(map[int64]int64)
	for rows.Next() {
		var branchID, n int64
		if scanErr := rows.Scan(&branchID, &n); scanErr != nil {
			return nil, fmt.Errorf("failed to scan branch customer count: %w", scanErr)
		}
		counts[branchID] = n
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("rows error: %w", iterErr)
	}
	return counts, nil
}
// GetCustomerPreferences returns, per user id, the sport_id and market_name
// that appear most often among that user's bet outcomes.
//
// The optional filter (company, branch, user, bet creation window) is applied
// to the bets in both the sport and the market aggregation. Both favourites are
// therefore computed over the same set of bets. Ties are broken arbitrarily by
// ROW_NUMBER. Only users with at least one outcome carrying a sport_id appear
// in the result. A user with no qualifying market gets an empty market name.
//
// Previously the sport CTE had no GROUP BY and received no filters, and the
// market CTE was grouped by sport_id instead of market_name. The statement
// could not execute.
func (s *Store) GetCustomerPreferences(ctx context.Context, filter domain.ReportFilter) (map[int64]domain.CustomerPreferences, error) {
	// Build the shared bet filter once. Both CTEs reference the same $n
	// placeholders, which PostgreSQL allows.
	betFilter := ""
	args := []interface{}{}
	if filter.CompanyID.Valid {
		args = append(args, filter.CompanyID.Value)
		betFilter += fmt.Sprintf(" AND b.company_id = $%d", len(args))
	}
	if filter.BranchID.Valid {
		args = append(args, filter.BranchID.Value)
		betFilter += fmt.Sprintf(" AND b.branch_id = $%d", len(args))
	}
	if filter.UserID.Valid {
		args = append(args, filter.UserID.Value)
		betFilter += fmt.Sprintf(" AND b.user_id = $%d", len(args))
	}
	if filter.StartTime.Valid {
		args = append(args, filter.StartTime.Value)
		betFilter += fmt.Sprintf(" AND b.created_at >= $%d", len(args))
	}
	if filter.EndTime.Valid {
		args = append(args, filter.EndTime.Value)
		betFilter += fmt.Sprintf(" AND b.created_at <= $%d", len(args))
	}

	// COALESCE guards the LEFT JOIN: a user without a favourite market would
	// otherwise yield NULL, which a plain string scan rejects.
	// NOTE(review): assumes market_name is a text column.
	query := `WITH customer_sports AS (
SELECT
b.user_id,
bo.sport_id,
COUNT(*) as bet_count,
ROW_NUMBER() OVER (PARTITION BY b.user_id ORDER BY COUNT(*) DESC) as sport_rank
FROM bets b
JOIN bet_outcomes bo ON b.id = bo.bet_id
WHERE b.user_id IS NOT NULL AND bo.sport_id IS NOT NULL` + betFilter + `
GROUP BY b.user_id, bo.sport_id
),
customer_markets AS (
SELECT
b.user_id,
bo.market_name,
COUNT(*) as bet_count,
ROW_NUMBER() OVER (PARTITION BY b.user_id ORDER BY COUNT(*) DESC) as market_rank
FROM bets b
JOIN bet_outcomes bo ON b.id = bo.bet_id
WHERE b.user_id IS NOT NULL AND bo.market_name IS NOT NULL` + betFilter + `
GROUP BY b.user_id, bo.market_name
),
favorite_sports AS (
SELECT user_id, sport_id
FROM customer_sports
WHERE sport_rank = 1
),
favorite_markets AS (
SELECT user_id, market_name
FROM customer_markets
WHERE market_rank = 1
)
SELECT
fs.user_id,
fs.sport_id as favorite_sport,
COALESCE(fm.market_name, '') as favorite_market
FROM favorite_sports fs
LEFT JOIN favorite_markets fm ON fs.user_id = fm.user_id`

	rows, err := s.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query customer preferences: %w", err)
	}
	defer rows.Close()

	preferences := make(map[int64]domain.CustomerPreferences)
	for rows.Next() {
		var userID int64
		var pref domain.CustomerPreferences
		if err := rows.Scan(&userID, &pref.FavoriteSport, &pref.FavoriteMarket); err != nil {
			return nil, fmt.Errorf("failed to scan customer preference: %w", err)
		}
		preferences[userID] = pref
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("rows error: %w", err)
	}
	return preferences, nil
}
// GetRoleCounts returns the total, active (not suspended) and inactive
// (suspended) number of users with the given role. The set can optionally be
// narrowed by company and a creation-time window from filter. filter.BranchID
// is not applied.
//
// The StartTime condition used to be built with a helper that emitted an extra
// "AND" whenever a CompanyID condition preceded it, producing invalid SQL
// ("... AND  AND created_at >= $3"). Every condition now contributes exactly
// one " AND ..." clause.
func (s *Store) GetRoleCounts(ctx context.Context, role string, filter domain.ReportFilter) (total, active, inactive int64, err error) {
	query := `SELECT
COUNT(*) as total,
COUNT(CASE WHEN suspended = false THEN 1 END) as active,
COUNT(CASE WHEN suspended = true THEN 1 END) as inactive
FROM users WHERE role = $1`
	args := []interface{}{role}

	// Each placeholder is numbered after its value is appended; role is $1.
	if filter.CompanyID.Valid {
		args = append(args, filter.CompanyID.Value)
		query += fmt.Sprintf(" AND company_id = $%d", len(args))
	}
	if filter.StartTime.Valid {
		args = append(args, filter.StartTime.Value)
		query += fmt.Sprintf(" AND created_at >= $%d", len(args))
	}
	if filter.EndTime.Valid {
		args = append(args, filter.EndTime.Value)
		query += fmt.Sprintf(" AND created_at <= $%d", len(args))
	}

	err = s.conn.QueryRow(ctx, query, args...).Scan(&total, &active, &inactive)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("failed to get %s counts: %w", role, err)
	}
	return total, active, inactive, nil
}