// Source: Yimaru-BackEnd/internal/repository/subscriptions.go
// (311 lines, 9.0 KiB, Go)


package repository
import (
dbgen "Yimaru-Backend/gen/db"
"Yimaru-Backend/internal/domain"
"context"
"time"
"github.com/jackc/pgx/v5/pgtype"
"github.com/shopspring/decimal"
)
// Helper functions for numeric and timestamp conversions
// toPgNumeric converts a float64 into a pgtype.Numeric by rendering the value
// as a decimal string and scanning that string into the numeric.
//
// The error returned by Scan is discarded (this function has no error return),
// so the caller receives whatever state Scan left in the value.
func toPgNumeric(val float64) pgtype.Numeric {
	var out pgtype.Numeric
	text := decimal.NewFromFloat(val).String()
	_ = out.Scan(text)
	return out
}
// fromPgNumeric converts a pgtype.Numeric into a float64.
//
// It returns 0 when num is not Valid. Otherwise it returns the Float64 field
// of the result of num.Float64Value(); the error from that call is discarded.
func fromPgNumeric(num pgtype.Numeric) float64 {
	if num.Valid {
		converted, _ := num.Float64Value()
		return converted.Float64
	}
	return 0
}
// toPgTimestamptz wraps t in a pgtype.Timestamptz that is marked Valid.
func toPgTimestamptz(t time.Time) pgtype.Timestamptz {
	var ts pgtype.Timestamptz
	ts.Time = t
	ts.Valid = true
	return ts
}
// toPgTimestamptzPtr converts an optional time into a pgtype.Timestamptz.
//
// A nil pointer yields a value with Valid == false; a non-nil pointer yields
// a Valid value holding the pointed-to time.
func toPgTimestamptzPtr(t *time.Time) pgtype.Timestamptz {
	var ts pgtype.Timestamptz
	if t != nil {
		ts.Time = *t
		ts.Valid = true
	}
	return ts
}
// =====================
// Subscription Plans
// =====================
// CreateSubscriptionPlan inserts a subscription plan built from input and
// returns the stored row converted to the domain type.
//
// Any error from the underlying query is returned unchanged, with a nil plan.
func (s *Store) CreateSubscriptionPlan(ctx context.Context, input domain.CreateSubscriptionPlanInput) (*domain.SubscriptionPlan, error) {
	params := dbgen.CreateSubscriptionPlanParams{
		Name:          input.Name,
		Description:   toPgText(input.Description),
		DurationValue: input.DurationValue,
		DurationUnit:  input.DurationUnit,
		Price:         toPgNumeric(input.Price),
		Currency:      input.Currency,
		// Column7 is the generated parameter name that receives the
		// is-active flag.
		Column7: input.IsActive,
	}

	created, err := s.queries.CreateSubscriptionPlan(ctx, params)
	if err != nil {
		return nil, err
	}
	return subscriptionPlanToDomain(created), nil
}
// GetSubscriptionPlanByID looks up a subscription plan by its ID and returns
// it as a domain value. Query errors are returned unchanged with a nil plan.
func (s *Store) GetSubscriptionPlanByID(ctx context.Context, id int64) (*domain.SubscriptionPlan, error) {
	row, err := s.queries.GetSubscriptionPlanByID(ctx, id)
	if err == nil {
		return subscriptionPlanToDomain(row), nil
	}
	return nil, err
}
// ListSubscriptionPlans returns subscription plans as domain values.
//
// When activeOnly is true the ListActiveSubscriptionPlans query is used;
// otherwise the ListSubscriptionPlans query is called with a false argument.
// NOTE(review): the meaning of that boolean argument is defined by the SQL
// query, which is not visible here — confirm against the query definition.
//
// On success the returned slice is always non-nil, even when it is empty.
func (s *Store) ListSubscriptionPlans(ctx context.Context, activeOnly bool) ([]domain.SubscriptionPlan, error) {
	var (
		rows []dbgen.SubscriptionPlan
		err  error
	)
	switch activeOnly {
	case true:
		rows, err = s.queries.ListActiveSubscriptionPlans(ctx)
	default:
		rows, err = s.queries.ListSubscriptionPlans(ctx, false)
	}
	if err != nil {
		return nil, err
	}

	plans := make([]domain.SubscriptionPlan, 0, len(rows))
	for _, row := range rows {
		plans = append(plans, *subscriptionPlanToDomain(row))
	}
	return plans, nil
}
// UpdateSubscriptionPlan runs the UpdateSubscriptionPlan query for the plan
// identified by id, using the optional fields of input.
//
// Nil pointers in input are flattened before the call: strings become "",
// the duration value becomes 0, the active flag becomes false and the price
// becomes an invalid (NULL) numeric.
// NOTE(review): whether the SQL treats those zero values as "leave unchanged"
// is not visible here — confirm against the query definition.
func (s *Store) UpdateSubscriptionPlan(ctx context.Context, id int64, input domain.UpdateSubscriptionPlanInput) error {
	params := dbgen.UpdateSubscriptionPlanParams{
		ID:            id,
		Name:          stringVal(input.Name),
		Description:   toPgText(input.Description),
		DurationValue: int32Val(input.DurationValue),
		DurationUnit:  stringVal(input.DurationUnit),
		Price:         numericPtrToNumeric(input.Price),
		Currency:      stringVal(input.Currency),
		IsActive:      boolPtrToBool(input.IsActive),
	}
	return s.queries.UpdateSubscriptionPlan(ctx, params)
}
// DeleteSubscriptionPlan runs the DeleteSubscriptionPlan query for the given
// plan ID and returns the query's error unchanged.
func (s *Store) DeleteSubscriptionPlan(ctx context.Context, id int64) error {
	err := s.queries.DeleteSubscriptionPlan(ctx, id)
	return err
}
// =====================
// User Subscriptions
// =====================
// CreateUserSubscription inserts a user subscription built from input and
// returns the stored row converted to the domain type (without plan details).
//
// Any error from the underlying query is returned unchanged, with a nil
// subscription.
func (s *Store) CreateUserSubscription(ctx context.Context, input domain.CreateUserSubscriptionInput) (*domain.UserSubscription, error) {
	params := dbgen.CreateUserSubscriptionParams{
		UserID: input.UserID,
		PlanID: input.PlanID,
		// Column3, Column5 and Column8 are the generated parameter names that
		// receive the start time, status and auto-renew flag respectively.
		Column3:          input.StartsAt,
		ExpiresAt:        toPgTimestamptz(input.ExpiresAt),
		Column5:          input.Status,
		PaymentReference: toPgText(input.PaymentReference),
		PaymentMethod:    toPgText(input.PaymentMethod),
		Column8:          input.AutoRenew,
	}

	created, err := s.queries.CreateUserSubscription(ctx, params)
	if err != nil {
		return nil, err
	}
	return userSubscriptionToDomain(created), nil
}
// GetUserSubscriptionByID looks up a user subscription by its ID and returns
// it, including the plan fields carried by the query row, as a domain value.
// Query errors are returned unchanged with a nil subscription.
func (s *Store) GetUserSubscriptionByID(ctx context.Context, id int64) (*domain.UserSubscription, error) {
	row, err := s.queries.GetUserSubscriptionByID(ctx, id)
	if err == nil {
		return userSubscriptionWithPlanToDomain(row), nil
	}
	return nil, err
}
// GetActiveSubscriptionByUserID runs the GetActiveSubscriptionByUserID query
// for userID and maps the resulting row, including its plan fields, to a
// domain.UserSubscription.
//
// Query errors are returned unchanged with a nil subscription.
func (s *Store) GetActiveSubscriptionByUserID(ctx context.Context, userID int64) (*domain.UserSubscription, error) {
	row, err := s.queries.GetActiveSubscriptionByUserID(ctx, userID)
	if err != nil {
		return nil, err
	}

	price := fromPgNumeric(row.Price)
	result := &domain.UserSubscription{
		// Subscription columns.
		ID:               row.ID,
		UserID:           row.UserID,
		PlanID:           row.PlanID,
		StartsAt:         row.StartsAt.Time,
		ExpiresAt:        row.ExpiresAt.Time,
		Status:           row.Status,
		PaymentReference: fromPgText(row.PaymentReference),
		PaymentMethod:    fromPgText(row.PaymentMethod),
		AutoRenew:        row.AutoRenew,
		CancelledAt:      timePtr(row.CancelledAt),
		CreatedAt:        row.CreatedAt.Time,
		UpdatedAt:        timePtr(row.UpdatedAt),

		// Plan columns carried by the same row, exposed as pointers.
		PlanName:      &row.PlanName,
		DurationValue: &row.DurationValue,
		DurationUnit:  &row.DurationUnit,
		Price:         &price,
		Currency:      &row.Currency,
	}
	return result, nil
}
// GetUserSubscriptionHistory runs the GetUserSubscriptionHistory query for
// userID with the given limit and offset and maps every row, including its
// plan fields, to a domain.UserSubscription.
//
// On success the returned slice is non-nil and has one entry per row. Query
// errors are returned unchanged with a nil slice.
func (s *Store) GetUserSubscriptionHistory(ctx context.Context, userID int64, limit, offset int32) ([]domain.UserSubscription, error) {
	rows, err := s.queries.GetUserSubscriptionHistory(ctx, dbgen.GetUserSubscriptionHistoryParams{
		UserID: userID,
		Limit:  pgtype.Int4{Int32: limit, Valid: true},
		Offset: pgtype.Int4{Int32: offset, Valid: true},
	})
	if err != nil {
		return nil, err
	}

	result := make([]domain.UserSubscription, len(rows))
	for i := range rows {
		// Take addresses from the slice element, not from a range variable:
		// on Go toolchains before 1.22 the range variable is shared by all
		// iterations, which would make every entry's plan pointers alias the
		// last row.
		row := &rows[i]
		result[i] = domain.UserSubscription{
			ID:               row.ID,
			UserID:           row.UserID,
			PlanID:           row.PlanID,
			StartsAt:         row.StartsAt.Time,
			ExpiresAt:        row.ExpiresAt.Time,
			Status:           row.Status,
			PaymentReference: fromPgText(row.PaymentReference),
			PaymentMethod:    fromPgText(row.PaymentMethod),
			AutoRenew:        row.AutoRenew,
			CancelledAt:      timePtr(row.CancelledAt),
			CreatedAt:        row.CreatedAt.Time,
			UpdatedAt:        timePtr(row.UpdatedAt),
			PlanName:         &row.PlanName,
			DurationValue:    &row.DurationValue,
			DurationUnit:     &row.DurationUnit,
			Price:            float64Ptr(fromPgNumeric(row.Price)),
			Currency:         &row.Currency,
		}
	}
	return result, nil
}
// HasActiveSubscription runs the HasActiveSubscription query for userID and
// returns its boolean result and error unchanged.
func (s *Store) HasActiveSubscription(ctx context.Context, userID int64) (bool, error) {
	active, err := s.queries.HasActiveSubscription(ctx, userID)
	return active, err
}
// CancelUserSubscription runs the CancelUserSubscription query for the given
// subscription ID and returns the query's error unchanged.
func (s *Store) CancelUserSubscription(ctx context.Context, id int64) error {
	err := s.queries.CancelUserSubscription(ctx, id)
	return err
}
// UpdateSubscriptionStatus sets the status of the user subscription
// identified by id via the UpdateUserSubscriptionStatus query and returns the
// query's error unchanged.
func (s *Store) UpdateSubscriptionStatus(ctx context.Context, id int64, status string) error {
	params := dbgen.UpdateUserSubscriptionStatusParams{
		ID:     id,
		Status: status,
	}
	return s.queries.UpdateUserSubscriptionStatus(ctx, params)
}
// UpdateAutoRenew sets the auto-renew flag of the user subscription
// identified by id via the UpdateAutoRenew query and returns the query's
// error unchanged.
func (s *Store) UpdateAutoRenew(ctx context.Context, id int64, autoRenew bool) error {
	params := dbgen.UpdateAutoRenewParams{
		ID:        id,
		AutoRenew: autoRenew,
	}
	return s.queries.UpdateAutoRenew(ctx, params)
}
// ExtendSubscription sets the expiry time of the user subscription identified
// by id to newExpiresAt via the ExtendSubscription query and returns the
// query's error unchanged.
func (s *Store) ExtendSubscription(ctx context.Context, id int64, newExpiresAt time.Time) error {
	params := dbgen.ExtendSubscriptionParams{
		ID:        id,
		ExpiresAt: pgtype.Timestamptz{Time: newExpiresAt, Valid: true},
	}
	return s.queries.ExtendSubscription(ctx, params)
}
// Helper conversion functions
// subscriptionPlanToDomain maps a generated SubscriptionPlan row to a newly
// allocated domain.SubscriptionPlan, unwrapping the pgtype description,
// price and timestamp fields.
func subscriptionPlanToDomain(p dbgen.SubscriptionPlan) *domain.SubscriptionPlan {
	var plan domain.SubscriptionPlan

	plan.ID = p.ID
	plan.Name = p.Name
	plan.Description = fromPgText(p.Description)
	plan.DurationValue = p.DurationValue
	plan.DurationUnit = p.DurationUnit
	plan.Price = fromPgNumeric(p.Price)
	plan.Currency = p.Currency
	plan.IsActive = p.IsActive
	plan.CreatedAt = p.CreatedAt.Time
	plan.UpdatedAt = timePtr(p.UpdatedAt)

	return &plan
}
// userSubscriptionToDomain maps a generated UserSubscription row to a newly
// allocated domain.UserSubscription. Only subscription columns are set; the
// plan-related fields of the domain type are left at their zero values.
func userSubscriptionToDomain(s dbgen.UserSubscription) *domain.UserSubscription {
	var sub domain.UserSubscription

	sub.ID = s.ID
	sub.UserID = s.UserID
	sub.PlanID = s.PlanID
	sub.StartsAt = s.StartsAt.Time
	sub.ExpiresAt = s.ExpiresAt.Time
	sub.Status = s.Status
	sub.PaymentReference = fromPgText(s.PaymentReference)
	sub.PaymentMethod = fromPgText(s.PaymentMethod)
	sub.AutoRenew = s.AutoRenew
	sub.CancelledAt = timePtr(s.CancelledAt)
	sub.CreatedAt = s.CreatedAt.Time
	sub.UpdatedAt = timePtr(s.UpdatedAt)

	return &sub
}
// userSubscriptionWithPlanToDomain maps a GetUserSubscriptionByID row to a
// newly allocated domain.UserSubscription, filling both the subscription
// columns and the plan columns (exposed as pointers) carried by the row.
func userSubscriptionWithPlanToDomain(s dbgen.GetUserSubscriptionByIDRow) *domain.UserSubscription {
	price := fromPgNumeric(s.Price)

	sub := &domain.UserSubscription{
		// Subscription columns.
		ID:               s.ID,
		UserID:           s.UserID,
		PlanID:           s.PlanID,
		StartsAt:         s.StartsAt.Time,
		ExpiresAt:        s.ExpiresAt.Time,
		Status:           s.Status,
		PaymentReference: fromPgText(s.PaymentReference),
		PaymentMethod:    fromPgText(s.PaymentMethod),
		AutoRenew:        s.AutoRenew,
		CancelledAt:      timePtr(s.CancelledAt),
		CreatedAt:        s.CreatedAt.Time,
		UpdatedAt:        timePtr(s.UpdatedAt),

		// Plan columns.
		PlanName:      &s.PlanName,
		DurationValue: &s.DurationValue,
		DurationUnit:  &s.DurationUnit,
		Price:         &price,
		Currency:      &s.Currency,
	}
	return sub
}
// stringVal dereferences s, returning the empty string when s is nil.
func stringVal(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
// int32Val dereferences i, returning 0 when i is nil.
func int32Val(i *int32) int32 {
	if i != nil {
		return *i
	}
	return 0
}
// numericPtrToNumeric converts an optional float64 into a pgtype.Numeric.
//
// A nil pointer yields a numeric with Valid == false; otherwise the
// pointed-to value is converted with toPgNumeric.
func numericPtrToNumeric(val *float64) pgtype.Numeric {
	if val != nil {
		return toPgNumeric(*val)
	}
	return pgtype.Numeric{}
}
// boolPtrToBool dereferences b, returning false when b is nil.
func boolPtrToBool(b *bool) bool {
	if b != nil {
		return *b
	}
	return false
}
// float64Ptr returns a pointer to a fresh copy of f.
func float64Ptr(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}