diff --git a/internal/repository/lms_user_progress_snapshot.go b/internal/repository/lms_user_progress_snapshot.go index d255b6e..8c833d6 100644 --- a/internal/repository/lms_user_progress_snapshot.go +++ b/internal/repository/lms_user_progress_snapshot.go @@ -31,7 +31,7 @@ func (s *Store) GetLMSUserProgressSnapshot(ctx context.Context, userID int64) (d LessonIDs: pgInt8IDsToInt64(lessons), ModuleIDs: pgInt8IDsToInt64(mods), CourseIDs: pgInt8IDsToInt64(courses), - ProgramIDs: programs, + ProgramIDs: int64IDsOrEmpty(programs), }, nil } @@ -46,6 +46,13 @@ func pgInt8IDsToInt64(items []pgtype.Int8) []int64 { return out } +func int64IDsOrEmpty(items []int64) []int64 { + if items == nil { + return []int64{} + } + return items +} + // ListUserLMSFlatLearningActivity returns flattened LMS activity rows for admin reporting (lesson + practice completions). func (s *Store) ListUserLMSFlatLearningActivity(ctx context.Context, userID int64) ([]dbgen.ListUserLMSFlatLearningActivityByUserRow, error) { return s.queries.ListUserLMSFlatLearningActivityByUser(ctx, userID) diff --git a/internal/repository/lms_user_progress_snapshot_test.go b/internal/repository/lms_user_progress_snapshot_test.go new file mode 100644 index 0000000..668e823 --- /dev/null +++ b/internal/repository/lms_user_progress_snapshot_test.go @@ -0,0 +1,41 @@ +package repository + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +func TestPgInt8IDsToInt64ReturnsEmptySlice(t *testing.T) { + got := pgInt8IDsToInt64(nil) + if got == nil { + t.Fatal("expected empty slice, got nil") + } + if len(got) != 0 { + t.Fatalf("expected empty slice, got len=%d", len(got)) + } +} + +func TestPgInt8IDsToInt64FiltersInvalidIDs(t *testing.T) { + got := pgInt8IDsToInt64([]pgtype.Int8{ + {Int64: 10, Valid: true}, + {Valid: false}, + {Int64: 20, Valid: true}, + }) + if len(got) != 2 { + t.Fatalf("expected 2 ids, got %d", len(got)) + } + if got[0] != 10 || got[1] != 20 { + t.Fatalf("unexpected ids: %#v", got) + } +} + +func TestInt64IDsOrEmptyReturnsEmptySlice(t *testing.T) { + got := int64IDsOrEmpty(nil) + if got == nil { + t.Fatal("expected empty slice, got nil") + } + if len(got) != 0 { + t.Fatalf("expected empty slice, got len=%d", len(got)) + } +}