Yimaru-BackEnd/internal/repository/notification.go

286 lines
6.4 KiB
Go

package repository
import (
"context"
"encoding/json"
"strconv"
dbgen "Yimaru-Backend/gen/db"
"Yimaru-Backend/internal/domain"
"Yimaru-Backend/internal/ports"
"github.com/jackc/pgx/v5/pgtype"
)
// NewNotificationStore exposes the shared *Store through the
// ports.NotificationStore interface. No additional state is created; the
// returned value is s itself.
func NewNotificationStore(s *Store) ports.NotificationStore {
	var store ports.NotificationStore = s
	return store
}
/* =========================
Create
========================= */
// CreateNotification persists n and returns the stored row mapped back to the
// domain model.
//
// Defaults and conventions:
//   - An empty n.ReceiverType is stored as domain.ReceiverTypeUser.
//   - An empty n.DeliveryChannel is stored as SQL NULL rather than "", which is
//     the same "empty string means NULL" convention used when filtering by
//     channel and when mapping rows back to the domain (NULL -> "").
//   - The full payload is stored as JSON; Headline and Message are also copied
//     into the dedicated title/message columns.
//
// Errors from the underlying query are returned unchanged.
func (r *Store) CreateNotification(
	ctx context.Context,
	n *domain.Notification,
) (*domain.Notification, error) {
	receiverType := string(n.ReceiverType)
	if receiverType == "" {
		receiverType = string(domain.ReceiverTypeUser)
	}

	// Only mark the channel as valid when one was actually supplied so an
	// unset channel becomes NULL instead of an empty string.
	channel := string(n.DeliveryChannel)

	params := dbgen.CreateNotificationParams{
		UserID:       n.RecipientID,
		ReceiverType: receiverType,
		Type:         string(n.Type),
		Level:        string(n.Level),
		Channel:      pgtype.Text{String: channel, Valid: channel != ""},
		Title:        n.Payload.Headline,
		Message:      n.Payload.Message,
		Payload:      marshalPayload(n.Payload),
	}

	dbNotif, err := r.queries.CreateNotification(ctx, params)
	if err != nil {
		return nil, err
	}
	return mapDBToDomain(&dbNotif), nil
}
/* =========================
Read
========================= */
// GetUserNotifications returns one page (limit/offset) of userID's
// notifications together with the user's total notification count. The total
// comes from a separate count query that only receives userID, so it is not
// affected by paging.
//
// On any query error it returns (nil, 0, err).
func (r *Store) GetUserNotifications(
	ctx context.Context,
	userID int64,
	limit, offset int,
) ([]domain.Notification, int64, error) {
	page, err := r.queries.GetUserNotifications(ctx, dbgen.GetUserNotificationsParams{
		UserID: userID,
		Limit:  int32(limit),
		Offset: int32(offset),
	})
	if err != nil {
		return nil, 0, err
	}

	count, err := r.queries.GetUserNotificationCount(ctx, userID)
	if err != nil {
		return nil, 0, err
	}

	notifications := make([]domain.Notification, len(page))
	for i := range page {
		notifications[i] = *mapDBToDomain(&page[i])
	}
	return notifications, count, nil
}
// GetAllNotifications returns one page (limit/offset) of notifications without
// any per-user filter. On a query error it returns (nil, err).
func (r *Store) GetAllNotifications(ctx context.Context, limit, offset int) ([]domain.Notification, error) {
	paging := dbgen.GetAllNotificationsParams{
		Limit:  int32(limit),
		Offset: int32(offset),
	}

	page, err := r.queries.GetAllNotifications(ctx, paging)
	if err != nil {
		return nil, err
	}

	notifications := make([]domain.Notification, len(page))
	for i := range page {
		notifications[i] = *mapDBToDomain(&page[i])
	}
	return notifications, nil
}
// GetFilteredNotifications returns one page of notifications matching filter,
// plus the total number of matches ignoring paging.
//
// Every filter field is optional. An empty Channel/Type string and nil
// pointer fields (UserID, IsRead, After, Before) are passed as invalid (NULL)
// parameters. Whether NULL means "no constraint" is decided by the generated
// SQL, not here. The page query and the count query receive identical filter
// values so the total is consistent with the page.
func (r *Store) GetFilteredNotifications(
	ctx context.Context,
	filter domain.NotificationFilter,
) ([]domain.Notification, int64, error) {
	channel := pgtype.Text{String: filter.Channel, Valid: filter.Channel != ""}
	notifType := pgtype.Text{String: filter.Type, Valid: filter.Type != ""}

	pageArgs := dbgen.GetFilteredNotificationsParams{
		FilterChannel: channel,
		FilterType:    notifType,
		PageLimit:     int32(filter.Limit),
		PageOffset:    int32(filter.Offset),
	}
	countArgs := dbgen.GetFilteredNotificationCountParams{
		FilterChannel: channel,
		FilterType:    notifType,
	}

	// Optional filters are set on both argument structs at once so the two
	// queries can never drift apart.
	if id := filter.UserID; id != nil {
		userID := pgtype.Int8{Int64: *id, Valid: true}
		pageArgs.FilterUserID, countArgs.FilterUserID = userID, userID
	}
	if read := filter.IsRead; read != nil {
		isRead := pgtype.Bool{Bool: *read, Valid: true}
		pageArgs.FilterIsRead, countArgs.FilterIsRead = isRead, isRead
	}
	if t := filter.After; t != nil {
		after := pgtype.Timestamptz{Time: *t, Valid: true}
		pageArgs.FilterAfter, countArgs.FilterAfter = after, after
	}
	if t := filter.Before; t != nil {
		before := pgtype.Timestamptz{Time: *t, Valid: true}
		pageArgs.FilterBefore, countArgs.FilterBefore = before, before
	}

	page, err := r.queries.GetFilteredNotifications(ctx, pageArgs)
	if err != nil {
		return nil, 0, err
	}

	count, err := r.queries.GetFilteredNotificationCount(ctx, countArgs)
	if err != nil {
		return nil, 0, err
	}

	notifications := make([]domain.Notification, len(page))
	for i := range page {
		notifications[i] = *mapDBToDomain(&page[i])
	}
	return notifications, count, nil
}
// CountUnreadNotifications returns the result of the generated
// CountUnreadNotifications query for userID, passing its error through
// unchanged.
func (r *Store) CountUnreadNotifications(ctx context.Context, userID int64) (int64, error) {
	return r.queries.CountUnreadNotifications(ctx, userID)
}
/* =========================
Update
========================= */
// MarkNotificationAsRead runs the generated MarkNotificationAsRead query for
// the notification with the given id and returns the resulting row mapped to
// the domain model. On error it returns (nil, err).
func (r *Store) MarkNotificationAsRead(ctx context.Context, id int64) (*domain.Notification, error) {
	updated, err := r.queries.MarkNotificationAsRead(ctx, id)
	if err != nil {
		return nil, err
	}
	notification := mapDBToDomain(&updated)
	return notification, nil
}
// MarkAllUserNotificationsAsRead runs the generated
// MarkAllUserNotificationsAsRead query for userID and returns its error, if any.
func (r *Store) MarkAllUserNotificationsAsRead(ctx context.Context, userID int64) error {
	return r.queries.MarkAllUserNotificationsAsRead(ctx, userID)
}
// MarkNotificationAsUnread runs the generated MarkNotificationAsUnread query
// for the notification with the given id and returns the resulting row mapped
// to the domain model. On error it returns (nil, err).
func (r *Store) MarkNotificationAsUnread(ctx context.Context, id int64) (*domain.Notification, error) {
	updated, err := r.queries.MarkNotificationAsUnread(ctx, id)
	if err != nil {
		return nil, err
	}
	notification := mapDBToDomain(&updated)
	return notification, nil
}
// MarkAllUserNotificationsAsUnread runs the generated
// MarkAllUserNotificationsAsUnread query for userID and returns its error, if
// any.
func (r *Store) MarkAllUserNotificationsAsUnread(ctx context.Context, userID int64) error {
	return r.queries.MarkAllUserNotificationsAsUnread(ctx, userID)
}
/* =========================
Delete
========================= */
// DeleteUserNotifications runs the generated DeleteUserNotifications query for
// userID and returns its error, if any. NOTE(review): presumably removes every
// notification owned by userID — confirm against the SQL.
func (r *Store) DeleteUserNotifications(ctx context.Context, userID int64) error {
	return r.queries.DeleteUserNotifications(ctx, userID)
}
/* =========================
Mapping
========================= */
// mapDBToDomain converts a generated dbgen.Notification row into a
// *domain.Notification.
//
// The JSON payload column is decoded best-effort. If it cannot be parsed, the
// row is still returned with an empty payload. Headline and Message then fall
// back to the dedicated Title and Message columns whenever the decoded payload
// leaves them empty. All other decoded payload fields (e.g. Tags) are carried
// over as-is, so whatever CreateNotification stored round-trips intact.
//
// A NULL channel maps to the empty DeliveryChannel.
//
// NOTE(review): DeliveryStatus is always reported as DeliveryStatusPending
// regardless of the row, and ReadAt is not mapped (see the commented-out
// field). Confirm whether the schema tracks either value.
func mapDBToDomain(db *dbgen.Notification) *domain.Notification {
	payload, err := unmarshalPayload(db.Payload)
	if err != nil {
		// A malformed payload must not hide the notification; discard any
		// partially decoded data and rely on the dedicated columns below.
		payload = domain.NotificationPayload{}
	}
	if payload.Headline == "" {
		payload.Headline = db.Title
	}
	if payload.Message == "" {
		payload.Message = db.Message
	}

	var channel domain.DeliveryChannel
	if db.Channel.Valid {
		channel = domain.DeliveryChannel(db.Channel.String)
	}

	return &domain.Notification{
		ID:              strconv.FormatInt(db.ID, 10),
		RecipientID:     db.UserID,
		ReceiverType:    domain.ReceiverType(db.ReceiverType),
		Type:            domain.NotificationType(db.Type),
		Level:           domain.NotificationLevel(db.Level),
		DeliveryChannel: channel,
		DeliveryStatus:  domain.DeliveryStatusPending,
		Payload:         payload,
		IsRead:          db.IsRead,
		Timestamp:       db.CreatedAt.Time,
		// ReadAt: db.ReadAt.Time,
	}
}
/* =========================
JSON Helpers
========================= */
// marshalPayload encodes p as JSON for storage in the payload column. If
// encoding fails it returns nil. json.Marshal itself returns a nil slice on
// error, so this makes that outcome explicit rather than discarding the error
// with a blank identifier.
func marshalPayload(p domain.NotificationPayload) []byte {
	encoded, err := json.Marshal(p)
	if err != nil {
		return nil
	}
	return encoded
}
// unmarshalPayload decodes a JSON payload column into a NotificationPayload.
//
// A nil or empty byte slice yields the zero payload and a nil error. If
// decoding fails, the (possibly partially populated) payload is returned
// together with the json.Unmarshal error; callers decide whether to discard it.
func unmarshalPayload(b []byte) (domain.NotificationPayload, error) {
	var decoded domain.NotificationPayload
	if len(b) > 0 {
		err := json.Unmarshal(b, &decoded)
		return decoded, err
	}
	return decoded, nil
}