diff --git a/internal/repository/lms_progress_tx.go b/internal/repository/lms_progress_tx.go index a6da6ae..906da8c 100644 --- a/internal/repository/lms_progress_tx.go +++ b/internal/repository/lms_progress_tx.go @@ -32,7 +32,7 @@ func (s *Store) CompleteLessonForUser(ctx context.Context, userID, lessonID int6 return err } - if err := s.cascadeLMSCompletion(ctx, q, userID, mod.ID, crs.ID, crs.ProgramID); err != nil { + if err := s.cascadeLMSCompletion(ctx, q, userID, &mod.ID, crs.ID, crs.ProgramID); err != nil { return err } @@ -62,21 +62,43 @@ func (s *Store) CompletePracticeForUser(ctx context.Context, userID, questionSet if err != nil { return err } - if !scope.ModuleID.Valid { - return fmt.Errorf("practice %d is not linked to a module", questionSetID) + var ( + moduleID *int64 + courseID int64 + ) + switch { + case scope.ModuleID.Valid: + mid := scope.ModuleID.Int64 + moduleID = &mid + mod, err := q.GetModuleByID(ctx, mid) + if err != nil { + return err + } + courseID = mod.CourseID + case scope.LessonID.Valid: + lesson, err := q.GetLessonByID(ctx, scope.LessonID.Int64) + if err != nil { + return err + } + mid := lesson.ModuleID + moduleID = &mid + mod, err := q.GetModuleByID(ctx, mid) + if err != nil { + return err + } + courseID = mod.CourseID + case scope.CourseID.Valid: + courseID = scope.CourseID.Int64 + default: + return fmt.Errorf("practice %d is not linked to lesson/module/course", questionSetID) } - mod, err := q.GetModuleByID(ctx, scope.ModuleID.Int64) + crs, err := q.GetCourseByID(ctx, courseID) if err != nil { return err } - crs, err := q.GetCourseByID(ctx, mod.CourseID) - if err != nil { - return err - } - - if err := s.cascadeLMSCompletion(ctx, q, userID, mod.ID, crs.ID, crs.ProgramID); err != nil { + if err := s.cascadeLMSCompletion(ctx, q, userID, moduleID, crs.ID, crs.ProgramID); err != nil { return err } @@ -86,38 +108,40 @@ func (s *Store) CompletePracticeForUser(ctx context.Context, userID, questionSet return nil } -func (s *Store) cascadeLMSCompletion(ctx context.Context, q *dbgen.Queries, userID, moduleID, courseID, programID int64) error { - moduleLessonsTotal, err := q.CountLessonsInModule(ctx, moduleID) - if err != nil { - return err - } - moduleLessonsDone, err := q.CountUserCompletedLessonsInModule(ctx, dbgen.CountUserCompletedLessonsInModuleParams{ - ModuleID: moduleID, - UserID: userID, - }) - if err != nil { - return err - } - modulePracticesTotal, err := q.CountPublishedPracticesInModule(ctx, toPgInt8(&moduleID)) - if err != nil { - return err - } - modulePracticesDone, err := q.CountUserCompletedPublishedPracticesInModule(ctx, dbgen.CountUserCompletedPublishedPracticesInModuleParams{ - ModuleID: toPgInt8(&moduleID), - UserID: userID, - }) - if err != nil { - return err - } +func (s *Store) cascadeLMSCompletion(ctx context.Context, q *dbgen.Queries, userID int64, moduleID *int64, courseID, programID int64) error { + if moduleID != nil { + moduleLessonsTotal, err := q.CountLessonsInModule(ctx, *moduleID) + if err != nil { + return err + } + moduleLessonsDone, err := q.CountUserCompletedLessonsInModule(ctx, dbgen.CountUserCompletedLessonsInModuleParams{ + ModuleID: *moduleID, + UserID: userID, + }) + if err != nil { + return err + } + modulePracticesTotal, err := q.CountPublishedPracticesInModule(ctx, toPgInt8(moduleID)) + if err != nil { + return err + } + modulePracticesDone, err := q.CountUserCompletedPublishedPracticesInModule(ctx, dbgen.CountUserCompletedPublishedPracticesInModuleParams{ + ModuleID: toPgInt8(moduleID), + UserID: userID, + }) + if err != nil { + return err + } - moduleLessonsComplete := moduleLessonsTotal > 0 && moduleLessonsDone >= moduleLessonsTotal - modulePracticesComplete := modulePracticesDone >= modulePracticesTotal - if !moduleLessonsComplete || !modulePracticesComplete { - return nil - } + moduleLessonsComplete := moduleLessonsTotal > 0 && moduleLessonsDone >= moduleLessonsTotal + modulePracticesComplete := modulePracticesDone >= modulePracticesTotal + if !moduleLessonsComplete || !modulePracticesComplete { + return nil + } - if err := q.InsertUserModuleProgress(ctx, dbgen.InsertUserModuleProgressParams{UserID: userID, ModuleID: moduleID}); err != nil { - return err + if err := q.InsertUserModuleProgress(ctx, dbgen.InsertUserModuleProgressParams{UserID: userID, ModuleID: *moduleID}); err != nil { + return err + } } nMods, err := q.CountModulesInCourse(ctx, courseID)