Yimaru-BackEnd/internal/repository/company.go

140 lines
3.7 KiB
Go

package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/SamuelTariku/FortuneBet-Backend/internal/domain"
"github.com/SamuelTariku/FortuneBet-Backend/internal/pkgs/helpers"
"github.com/jackc/pgx/v5/pgtype"
)
// CreateCompany inserts a new company and returns it as a domain.Company.
//
// A slug is derived from company.Name via helpers.GenerateSlug. If that slug
// is already taken, numeric suffixes are tried in order ("<slug>-1",
// "<slug>-2", ...) until GetCompanyIDUsingSlug reports no matching row. Any
// other lookup error aborts creation and is returned unchanged.
//
// NOTE(review): the availability check and the insert are two separate calls
// with no visible transaction, so a concurrent insert can still claim the
// same slug in between.
// NOTE(review): sqlc over pgx/v5 normally surfaces pgx.ErrNoRows; it only
// satisfies errors.Is(err, sql.ErrNoRows) from pgx v5.6.0 onward. Confirm the
// module's pgx version, otherwise the first free slug is reported as an error.
func (s *Store) CreateCompany(ctx context.Context, company domain.CreateCompany) (domain.Company, error) {
	base := helpers.GenerateSlug(company.Name)
	candidate := base

	for suffix := 1; ; suffix++ {
		_, lookupErr := s.queries.GetCompanyIDUsingSlug(ctx, candidate)
		if errors.Is(lookupErr, sql.ErrNoRows) {
			// No company owns this slug yet, so it can be used.
			break
		}
		if lookupErr != nil {
			return domain.Company{}, lookupErr
		}
		// The slug is taken; try the next numbered variant.
		candidate = fmt.Sprintf("%s-%d", base, suffix)
	}

	created, err := s.queries.CreateCompany(ctx, domain.ConvertCreateCompany(company, candidate))
	if err != nil {
		return domain.Company{}, err
	}
	return domain.ConvertDBCompany(created), nil
}
// GetAllCompanies lists the companies selected by filter, converting each
// database row with domain.ConvertDBCompanyDetails.
//
// On success the returned slice is never nil; it is empty when no rows match.
// Query errors are returned unchanged together with a nil slice.
func (s *Store) GetAllCompanies(ctx context.Context, filter domain.CompanyFilter) ([]domain.GetCompany, error) {
	params := domain.ConvertGetAllCompaniesParams(filter)
	rows, err := s.queries.GetAllCompanies(ctx, params)
	if err != nil {
		return nil, err
	}

	result := make([]domain.GetCompany, len(rows))
	for idx := range rows {
		result[idx] = domain.ConvertDBCompanyDetails(rows[idx])
	}
	return result, nil
}
// SearchCompanyByName runs the SearchCompanyByName query with name and
// converts each resulting row with domain.ConvertDBCompanyDetails.
//
// name is always sent as a non-NULL text value (Valid: true), even when it is
// empty. The matching rules themselves live in the SQL query, not here.
// On success the returned slice is never nil.
func (s *Store) SearchCompanyByName(ctx context.Context, name string) ([]domain.GetCompany, error) {
	term := pgtype.Text{Valid: true, String: name}
	rows, err := s.queries.SearchCompanyByName(ctx, term)
	if err != nil {
		return nil, err
	}

	out := make([]domain.GetCompany, len(rows))
	for idx := range rows {
		out[idx] = domain.ConvertDBCompanyDetails(rows[idx])
	}
	return out, nil
}
// GetCompanyByID loads the company with the given id and converts it with
// domain.ConvertDBCompanyDetails. Any query error is returned as-is alongside
// a zero-value domain.GetCompany.
func (s *Store) GetCompanyByID(ctx context.Context, id int64) (domain.GetCompany, error) {
	row, err := s.queries.GetCompanyByID(ctx, id)
	if err == nil {
		return domain.ConvertDBCompanyDetails(row), nil
	}
	return domain.GetCompany{}, err
}
// GetCompanyIDBySlug resolves a company slug to its ID. When the lookup
// fails, the ID is always reported as 0 and the query error is passed through
// unchanged.
func (s *Store) GetCompanyIDBySlug(ctx context.Context, slug string) (int64, error) {
	var id int64
	found, err := s.queries.GetCompanyIDUsingSlug(ctx, slug)
	if err == nil {
		id = found
	}
	return id, err
}
// UpdateCompany converts company into query parameters, runs the
// UpdateCompany query, and returns the row it yields as a domain.Company.
// Query errors are returned unchanged with a zero-value domain.Company.
func (s *Store) UpdateCompany(ctx context.Context, company domain.UpdateCompany) (domain.Company, error) {
	params := domain.ConvertUpdateCompany(company)
	updated, err := s.queries.UpdateCompany(ctx, params)
	if err == nil {
		return domain.ConvertDBCompany(updated), nil
	}
	return domain.Company{}, err
}
// DeleteCompany runs the DeleteCompany query for the given id and returns
// its error (if any) unchanged.
func (s *Store) DeleteCompany(ctx context.Context, id int64) error {
	if err := s.queries.DeleteCompany(ctx, id); err != nil {
		return err
	}
	return nil
}
// GetCompanyCounts counts companies joined to their wallet
// (companies.wallet_id = wallets.id) and splits the count by the wallet's
// is_active flag.
//
// Because of the inner join, companies with no matching wallet are not
// counted. A wallet whose is_active is NULL contributes to total but to
// neither active nor inactive.
//
// filter.StartTime and filter.EndTime, when Valid, bound the company's
// created_at inclusively. Each bound may be supplied independently of the
// other.
//
// Errors from the query or scan are wrapped with context via %w.
func (s *Store) GetCompanyCounts(ctx context.Context, filter domain.ReportFilter) (total, active, inactive int64, err error) {
	query := `SELECT
COUNT(*) as total,
COUNT(CASE WHEN w.is_active = true THEN 1 END) as active,
COUNT(CASE WHEN w.is_active = false THEN 1 END) as inactive
FROM companies c
JOIN wallets w ON c.wallet_id = w.id`
	args := []any{}

	// The first condition appended opens the WHERE clause and each later one
	// is joined with AND. The previous code emitted a bare "AND" when only
	// EndTime was set, which is invalid SQL. Placeholder numbers are taken from
	// len(args) after appending, so they always match the argument positions.
	//
	// created_at is qualified with the companies alias: an unqualified name is
	// ambiguous if wallets also has a created_at column.
	// NOTE(review): confirm companies.created_at is the intended column.
	keyword := " WHERE "
	if filter.StartTime.Valid {
		args = append(args, filter.StartTime.Value)
		query += fmt.Sprintf("%sc.created_at >= $%d", keyword, len(args))
		keyword = " AND "
	}
	if filter.EndTime.Valid {
		args = append(args, filter.EndTime.Value)
		query += fmt.Sprintf("%sc.created_at <= $%d", keyword, len(args))
	}

	row := s.conn.QueryRow(ctx, query, args...)
	if err = row.Scan(&total, &active, &inactive); err != nil {
		return 0, 0, 0, fmt.Errorf("failed to get company counts: %w", err)
	}
	return total, active, inactive, nil
}