From eafd68d3c2c100d8ed14f93206bd2df70f40ac5b Mon Sep 17 00:00:00 2001 From: Samuel Tariku Date: Fri, 23 May 2025 21:43:34 +0300 Subject: [PATCH] fix: restricting search --- db/query/user.sql | 16 +++++++++--- db/query/wallet.sql | 4 +-- gen/db/models.go | 1 - gen/db/user.sql.go | 26 ++++++++++++++++---- gen/db/wallet.sql.go | 27 ++++----------------- internal/repository/user.go | 27 +++++++++++++++++---- internal/repository/wallet.go | 7 ++---- internal/services/bet/service.go | 3 ++- internal/services/odds/service.go | 1 - internal/services/user/port.go | 2 +- internal/services/user/user.go | 4 +-- internal/services/wallet/port.go | 2 +- internal/services/wallet/wallet.go | 4 +-- internal/web_server/handlers/bet_handler.go | 7 +++++- internal/web_server/handlers/user.go | 6 +++-- 15 files changed, 82 insertions(+), 55 deletions(-) diff --git a/db/query/user.sql b/db/query/user.sql index 3341656..91ddccb 100644 --- a/db/query/user.sql +++ b/db/query/user.sql @@ -93,9 +93,19 @@ SELECT id, suspended_at, company_id FROM users -WHERE first_name ILIKE '%' || $1 || '%' - OR last_name ILIKE '%' || $1 || '%' - OR phone_number LIKE '%' || $1 || '%'; +WHERE ( + first_name ILIKE '%' || $1 || '%' + OR last_name ILIKE '%' || $1 || '%' + OR phone_number LIKE '%' || $1 || '%' + ) + AND ( + role = sqlc.narg('role') + OR sqlc.narg('role') IS NULL + ) + AND ( + company_id = sqlc.narg('company_id') + OR sqlc.narg('company_id') IS NULL + ); -- name: UpdateUser :exec UPDATE users SET first_name = $1, diff --git a/db/query/wallet.sql b/db/query/wallet.sql index 9030134..e825653 100644 --- a/db/query/wallet.sql +++ b/db/query/wallet.sql @@ -29,7 +29,6 @@ WHERE user_id = $1; -- name: GetCustomerWallet :one SELECT cw.id, cw.customer_id, - cw.company_id, rw.id AS regular_id, rw.balance AS regular_balance, sw.id AS static_id, @@ -40,8 +39,7 @@ SELECT cw.id, FROM customer_wallets cw JOIN wallets rw ON cw.regular_wallet_id = rw.id JOIN wallets sw ON cw.static_wallet_id = sw.id -WHERE cw.customer_id = $1 - AND cw.company_id = $2; +WHERE cw.customer_id = $1; -- name: GetAllBranchWallets :many SELECT wallets.id, wallets.balance, diff --git a/gen/db/models.go b/gen/db/models.go index 9b27432..7c1695f 100644 --- a/gen/db/models.go +++ b/gen/db/models.go @@ -168,7 +168,6 @@ type Company struct { type CustomerWallet struct { ID int64 `json:"id"` CustomerID int64 `json:"customer_id"` - CompanyID int64 `json:"company_id"` RegularWalletID int64 `json:"regular_wallet_id"` StaticWalletID int64 `json:"static_wallet_id"` CreatedAt pgtype.Timestamp `json:"created_at"` diff --git a/gen/db/user.sql.go b/gen/db/user.sql.go index dd2f985..e0860c6 100644 --- a/gen/db/user.sql.go +++ b/gen/db/user.sql.go @@ -427,11 +427,27 @@ SELECT id, suspended_at, company_id FROM users -WHERE first_name ILIKE '%' || $1 || '%' - OR last_name ILIKE '%' || $1 || '%' - OR phone_number LIKE '%' || $1 || '%' +WHERE ( + first_name ILIKE '%' || $1 || '%' + OR last_name ILIKE '%' || $1 || '%' + OR phone_number LIKE '%' || $1 || '%' + ) + AND ( + role = $2 + OR $2 IS NULL + ) + AND ( + company_id = $3 + OR $3 IS NULL + ) ` +type SearchUserByNameOrPhoneParams struct { + Column1 pgtype.Text `json:"column_1"` + Role pgtype.Text `json:"role"` + CompanyID pgtype.Int8 `json:"company_id"` +} + type SearchUserByNameOrPhoneRow struct { ID int64 `json:"id"` FirstName string `json:"first_name"` @@ -448,8 +464,8 @@ type SearchUserByNameOrPhoneRow struct { CompanyID pgtype.Int8 `json:"company_id"` } -func (q *Queries) SearchUserByNameOrPhone(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUserByNameOrPhoneRow, error) { - rows, err := q.db.Query(ctx, SearchUserByNameOrPhone, dollar_1) +func (q *Queries) SearchUserByNameOrPhone(ctx context.Context, arg SearchUserByNameOrPhoneParams) ([]SearchUserByNameOrPhoneRow, error) { + rows, err := q.db.Query(ctx, SearchUserByNameOrPhone, arg.Column1, arg.Role, arg.CompanyID) if err != nil { return nil, err } diff --git a/gen/db/wallet.sql.go b/gen/db/wallet.sql.go index b3637f8..64c3359 100644 --- a/gen/db/wallet.sql.go +++ b/gen/db/wallet.sql.go @@ -14,33 +14,25 @@ import ( const CreateCustomerWallet = `-- name: CreateCustomerWallet :one INSERT INTO customer_wallets ( customer_id, - company_id, regular_wallet_id, static_wallet_id ) -VALUES ($1, $2, $3, $4) -RETURNING id, customer_id, company_id, regular_wallet_id, static_wallet_id, created_at, updated_at +VALUES ($1, $2, $3) +RETURNING id, customer_id, regular_wallet_id, static_wallet_id, created_at, updated_at ` type CreateCustomerWalletParams struct { CustomerID int64 `json:"customer_id"` - CompanyID int64 `json:"company_id"` RegularWalletID int64 `json:"regular_wallet_id"` StaticWalletID int64 `json:"static_wallet_id"` } func (q *Queries) CreateCustomerWallet(ctx context.Context, arg CreateCustomerWalletParams) (CustomerWallet, error) { - row := q.db.QueryRow(ctx, CreateCustomerWallet, - arg.CustomerID, - arg.CompanyID, - arg.RegularWalletID, - arg.StaticWalletID, - ) + row := q.db.QueryRow(ctx, CreateCustomerWallet, arg.CustomerID, arg.RegularWalletID, arg.StaticWalletID) var i CustomerWallet err := row.Scan( &i.ID, &i.CustomerID, - &i.CompanyID, &i.RegularWalletID, &i.StaticWalletID, &i.CreatedAt, @@ -190,7 +182,6 @@ func (q *Queries) GetAllWallets(ctx context.Context) ([]Wallet, error) { const GetCustomerWallet = `-- name: GetCustomerWallet :one SELECT cw.id, cw.customer_id, - cw.company_id, rw.id AS regular_id, rw.balance AS regular_balance, sw.id AS static_id, @@ -202,18 +193,11 @@ FROM customer_wallets cw JOIN wallets rw ON cw.regular_wallet_id = rw.id JOIN wallets sw ON cw.static_wallet_id = sw.id WHERE cw.customer_id = $1 - AND cw.company_id = $2 ` -type GetCustomerWalletParams struct { - CustomerID int64 `json:"customer_id"` - CompanyID int64 `json:"company_id"` -} - type GetCustomerWalletRow struct { ID int64 `json:"id"` CustomerID int64 `json:"customer_id"` - CompanyID int64 `json:"company_id"` RegularID int64 `json:"regular_id"` RegularBalance int64 `json:"regular_balance"` StaticID int64 `json:"static_id"` @@ -223,13 +207,12 @@ type GetCustomerWalletRow struct { CreatedAt pgtype.Timestamp `json:"created_at"` } -func (q *Queries) GetCustomerWallet(ctx context.Context, arg GetCustomerWalletParams) (GetCustomerWalletRow, error) { - row := q.db.QueryRow(ctx, GetCustomerWallet, arg.CustomerID, arg.CompanyID) +func (q *Queries) GetCustomerWallet(ctx context.Context, customerID int64) (GetCustomerWalletRow, error) { + row := q.db.QueryRow(ctx, GetCustomerWallet, customerID) var i GetCustomerWalletRow err := row.Scan( &i.ID, &i.CustomerID, - &i.CompanyID, &i.RegularID, &i.RegularBalance, &i.StaticID, diff --git a/internal/repository/user.go b/internal/repository/user.go index df82a40..7405542 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -202,11 +202,28 @@ func (s *Store) GetCashiersByBranch(ctx context.Context, branchID int64) ([]doma return userList, nil } -func (s *Store) SearchUserByNameOrPhone(ctx context.Context, searchString string) ([]domain.User, error) { - users, err := s.queries.SearchUserByNameOrPhone(ctx, pgtype.Text{ - String: searchString, - Valid: true, - }) +func (s *Store) SearchUserByNameOrPhone(ctx context.Context, searchString string, role *domain.Role, companyID domain.ValidInt64) ([]domain.User, error) { + + query := dbgen.SearchUserByNameOrPhoneParams{ + Column1: pgtype.Text{ + String: searchString, + Valid: true, + }, + CompanyID: pgtype.Int8{ + Int64: companyID.Value, + Valid: companyID.Valid, + }, + } + + if role != nil { + + query.Role = pgtype.Text{ + String: string(*role), + Valid: true, + } + } + + users, err := s.queries.SearchUserByNameOrPhone(ctx, query) if err != nil { return nil, err } diff --git a/internal/repository/wallet.go b/internal/repository/wallet.go index e61fb74..54fd077 100644 --- a/internal/repository/wallet.go +++ b/internal/repository/wallet.go @@ -114,11 +114,8 @@ func (s *Store) GetWalletsByUser(ctx context.Context, userID int64) ([]domain.Wa return result, nil } -func (s *Store) GetCustomerWallet(ctx context.Context, customerID int64, companyID int64) (domain.GetCustomerWallet, error) { - customerWallet, err := s.queries.GetCustomerWallet(ctx, dbgen.GetCustomerWalletParams{ - CustomerID: customerID, - CompanyID: companyID, - }) +func (s *Store) GetCustomerWallet(ctx context.Context, customerID int64) (domain.GetCustomerWallet, error) { + customerWallet, err := s.queries.GetCustomerWallet(ctx, customerID) if err != nil { return domain.GetCustomerWallet{}, err diff --git a/internal/services/bet/service.go b/internal/services/bet/service.go index 8f98433..4e3f9bf 100644 --- a/internal/services/bet/service.go +++ b/internal/services/bet/service.go @@ -23,6 +23,7 @@ var ( ErrNoEventsAvailable = errors.New("Not enough events available with the given filters") ErrGenerateRandomOutcome = errors.New("Failed to generate any random outcome for events") ErrOutcomesNotCompleted = errors.New("Some bet outcomes are still pending") + ErrEventHasBeenRemoved = errors.New("Event has been removed") ) type Service struct { @@ -75,7 +76,7 @@ func (s *Service) GenerateBetOutcome(ctx context.Context, eventID int64, marketI event, err := s.eventSvc.GetUpcomingEventByID(ctx, eventIDStr) if err != nil { - return domain.CreateBetOutcome{}, err + return domain.CreateBetOutcome{}, ErrEventHasBeenRemoved } currentTime := time.Now() diff --git a/internal/services/odds/service.go b/internal/services/odds/service.go index a2c4016..36f3a8a 100644 --- a/internal/services/odds/service.go +++ b/internal/services/odds/service.go @@ -98,7 +98,6 @@ func (s *ServiceImpl) FetchNonLiveOdds(ctx context.Context) error { s.logger.Error("Error while inserting ice hockey odd") errs = append(errs, err) } - } // result := oddsData.Results[0] diff --git a/internal/services/user/port.go b/internal/services/user/port.go index f6adec0..6a09597 100644 --- a/internal/services/user/port.go +++ b/internal/services/user/port.go @@ -21,7 +21,7 @@ type UserStore interface { CheckPhoneEmailExist(ctx context.Context, phoneNum, email string) (bool, bool, error) GetUserByEmail(ctx context.Context, email string) (domain.User, error) GetUserByPhone(ctx context.Context, phoneNum string) (domain.User, error) - SearchUserByNameOrPhone(ctx context.Context, searchString string) ([]domain.User, error) + SearchUserByNameOrPhone(ctx context.Context, searchString string, role *domain.Role, companyID domain.ValidInt64) ([]domain.User, error) UpdatePassword(ctx context.Context, identifier string, password []byte, usedOtpId int64) error // identifier verified email or phone } type SmsGateway interface { diff --git a/internal/services/user/user.go b/internal/services/user/user.go index a9d303e..6529c16 100644 --- a/internal/services/user/user.go +++ b/internal/services/user/user.go @@ -6,9 +6,9 @@ import ( "github.com/SamuelTariku/FortuneBet-Backend/internal/domain" ) -func (s *Service) SearchUserByNameOrPhone(ctx context.Context, searchString string) ([]domain.User, error) { +func (s *Service) SearchUserByNameOrPhone(ctx context.Context, searchString string, role *domain.Role, companyID domain.ValidInt64) ([]domain.User, error) { // Search user - return s.userStore.SearchUserByNameOrPhone(ctx, searchString) + return s.userStore.SearchUserByNameOrPhone(ctx, searchString, role, companyID) } func (s *Service) UpdateUser(ctx context.Context, user domain.UpdateUserReq) error { diff --git a/internal/services/wallet/port.go b/internal/services/wallet/port.go index 9271039..9c3fcb9 100644 --- a/internal/services/wallet/port.go +++ b/internal/services/wallet/port.go @@ -12,7 +12,7 @@ type WalletStore interface { GetWalletByID(ctx context.Context, id int64) (domain.Wallet, error) GetAllWallets(ctx context.Context) ([]domain.Wallet, error) GetWalletsByUser(ctx context.Context, id int64) ([]domain.Wallet, error) - GetCustomerWallet(ctx context.Context, customerID int64, companyID int64) (domain.GetCustomerWallet, error) + GetCustomerWallet(ctx context.Context, customerID int64) (domain.GetCustomerWallet, error) GetAllBranchWallets(ctx context.Context) ([]domain.BranchWallet, error) UpdateBalance(ctx context.Context, id int64, balance domain.Currency) error UpdateWalletActive(ctx context.Context, id int64, isActive bool) error diff --git a/internal/services/wallet/wallet.go b/internal/services/wallet/wallet.go index 4749af2..feb29d0 100644 --- a/internal/services/wallet/wallet.go +++ b/internal/services/wallet/wallet.go @@ -56,8 +56,8 @@ func (s *Service) GetWalletsByUser(ctx context.Context, id int64) ([]domain.Wall return s.walletStore.GetWalletsByUser(ctx, id) } -func (s *Service) GetCustomerWallet(ctx context.Context, customerID int64, companyID int64) (domain.GetCustomerWallet, error) { - return s.walletStore.GetCustomerWallet(ctx, customerID, companyID) +func (s *Service) GetCustomerWallet(ctx context.Context, customerID int64) (domain.GetCustomerWallet, error) { + return s.walletStore.GetCustomerWallet(ctx, customerID) } func (s *Service) GetAllBranchWallets(ctx context.Context) ([]domain.BranchWallet, error) { diff --git a/internal/web_server/handlers/bet_handler.go b/internal/web_server/handlers/bet_handler.go index da85139..b01fbd3 100644 --- a/internal/web_server/handlers/bet_handler.go +++ b/internal/web_server/handlers/bet_handler.go @@ -6,6 +6,7 @@ import ( "github.com/SamuelTariku/FortuneBet-Backend/internal/domain" "github.com/SamuelTariku/FortuneBet-Backend/internal/services/bet" + "github.com/SamuelTariku/FortuneBet-Backend/internal/services/wallet" "github.com/SamuelTariku/FortuneBet-Backend/internal/web_server/response" "github.com/gofiber/fiber/v2" ) @@ -42,6 +43,10 @@ func (h *Handler) CreateBet(c *fiber.Ctx) error { if err != nil { h.logger.Error("PlaceBet failed", "error", err) + switch err { + case bet.ErrEventHasBeenRemoved, bet.ErrEventHasNotEnded, bet.ErrRawOddInvalid, wallet.ErrBalanceInsufficient: + return fiber.NewError(fiber.StatusBadGateway, err.Error()) + } return fiber.NewError(fiber.StatusInternalServerError, "Unable to create bet") } @@ -180,7 +185,7 @@ func (h *Handler) GetBetByID(c *fiber.Ctx) error { bet, err := h.betSvc.GetBetByID(c.Context(), id) if err != nil { - // TODO: handle all the errors types + // TODO: handle all the errors types h.logger.Error("Failed to get bet by ID", "betID", id, "error", err) return fiber.NewError(fiber.StatusNotFound, "Failed to retrieve bet") } diff --git a/internal/web_server/handlers/user.go b/internal/web_server/handlers/user.go index 4d6efcf..a0121f9 100644 --- a/internal/web_server/handlers/user.go +++ b/internal/web_server/handlers/user.go @@ -381,7 +381,8 @@ func getMedium(email, phoneNumber string) (domain.OtpMedium, error) { } type SearchUserByNameOrPhoneReq struct { - SearchString string + SearchString string `json:"query"` + Role *domain.Role `json:"role,omitempty"` } // SearchUserByNameOrPhone godoc @@ -409,7 +410,8 @@ func (h *Handler) SearchUserByNameOrPhone(c *fiber.Ctx) error { response.WriteJSON(c, fiber.StatusBadRequest, "Invalid request", valErrs, nil) return nil } - users, err := h.userSvc.SearchUserByNameOrPhone(c.Context(), req.SearchString) + companyID := c.Locals("company_id").(domain.ValidInt64) + users, err := h.userSvc.SearchUserByNameOrPhone(c.Context(), req.SearchString, req.Role, companyID) if err != nil { h.logger.Error("SearchUserByNameOrPhone failed", "error", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{