diff --git a/db/query/otp.sql b/db/query/otp.sql index 8d181a2..c36c18d 100644 --- a/db/query/otp.sql +++ b/db/query/otp.sql @@ -25,6 +25,14 @@ WHERE user_id = $1 ORDER BY id DESC LIMIT 1; +-- name: GetOtpByCode :one +SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at +FROM otps +WHERE user_id = $1 + AND otp = $2 +ORDER BY id DESC +LIMIT 1; + -- name: MarkOtpAsUsed :exec UPDATE otps SET used = TRUE, used_at = $2 diff --git a/gen/db/otp.sql.go b/gen/db/otp.sql.go index 3a9abb7..e35eb8c 100644 --- a/gen/db/otp.sql.go +++ b/gen/db/otp.sql.go @@ -83,6 +83,51 @@ func (q *Queries) GetOtp(ctx context.Context, userID int64) (GetOtpRow, error) { return i, err } +const GetOtpByCode = `-- name: GetOtpByCode :one +SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at +FROM otps +WHERE user_id = $1 + AND otp = $2 +ORDER BY id DESC +LIMIT 1 +` + +type GetOtpByCodeParams struct { + UserID int64 `json:"user_id"` + Otp string `json:"otp"` +} + +type GetOtpByCodeRow struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + SentTo string `json:"sent_to"` + Medium string `json:"medium"` + OtpFor string `json:"otp_for"` + Otp string `json:"otp"` + Used bool `json:"used"` + UsedAt pgtype.Timestamptz `json:"used_at"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` +} + +func (q *Queries) GetOtpByCode(ctx context.Context, arg GetOtpByCodeParams) (GetOtpByCodeRow, error) { + row := q.db.QueryRow(ctx, GetOtpByCode, arg.UserID, arg.Otp) + var i GetOtpByCodeRow + err := row.Scan( + &i.ID, + &i.UserID, + &i.SentTo, + &i.Medium, + &i.OtpFor, + &i.Otp, + &i.Used, + &i.UsedAt, + &i.CreatedAt, + &i.ExpiresAt, + ) + return i, err +} + const MarkOtpAsUsed = `-- name: MarkOtpAsUsed :exec UPDATE otps SET used = TRUE, used_at = $2 diff --git a/internal/ports/user.go b/internal/ports/user.go index e5e067a..8ad82d7 100644 --- a/internal/ports/user.go +++ b/internal/ports/user.go @@ -94,4 +94,5 @@ type OtpStore interface { MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error CreateOtp(ctx context.Context, otp domain.Otp) error GetOtp(ctx context.Context, userID int64) (domain.Otp, error) + GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error) } diff --git a/internal/repository/otp.go b/internal/repository/otp.go index 37f831e..851cafe 100644 --- a/internal/repository/otp.go +++ b/internal/repository/otp.go @@ -67,6 +67,36 @@ func (s *Store) GetOtp(ctx context.Context, userID int64) (domain.Otp, error) { ExpiresAt: row.ExpiresAt.Time, }, nil } + +func (s *Store) GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error) { + row, err := s.queries.GetOtpByCode(ctx, dbgen.GetOtpByCodeParams{ + UserID: userID, + Otp: otpCode, + }) + if err != nil { + if err == sql.ErrNoRows { + return domain.Otp{}, domain.ErrOtpNotFound + } + return domain.Otp{}, err + } + if !row.ExpiresAt.Valid { + return domain.Otp{}, domain.ErrOtpNotFound + } + + return domain.Otp{ + ID: row.ID, + UserID: row.UserID, + SentTo: row.SentTo, + Medium: domain.OtpMedium(row.Medium), + For: domain.OtpFor(row.OtpFor), + Otp: row.Otp, + Used: row.Used, + UsedAt: row.UsedAt.Time, + CreatedAt: row.CreatedAt.Time, + ExpiresAt: row.ExpiresAt.Time, + }, nil +} + func (s *Store) MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error { return s.queries.MarkOtpAsUsed(ctx, dbgen.MarkOtpAsUsedParams{ ID: otp.ID, diff --git a/internal/services/authentication/impl.go b/internal/services/authentication/impl.go index 605fb43..5675acf 100644 --- a/internal/services/authentication/impl.go +++ b/internal/services/authentication/impl.go @@ -95,9 +95,13 @@ func (s *Service) VerifyOtp( return domain.LoginSuccess{}, err } - // 1. Retrieve OTP - storedOtp, err := s.otpStore.GetOtp(ctx, user.ID) + // 1. Retrieve OTP row matching submitted code. + // This avoids false positives when another OTP row exists for the same user. + storedOtp, err := s.otpStore.GetOtpByCode(ctx, user.ID, otpCode) if err != nil { + if errors.Is(err, domain.ErrOtpNotFound) { + return domain.LoginSuccess{}, domain.ErrInvalidOtp + } return domain.LoginSuccess{}, err } @@ -111,12 +115,7 @@ func (s *Service) VerifyOtp( return domain.LoginSuccess{}, domain.ErrOtpExpired } - // 4. Invalid - if storedOtp.Otp != otpCode { - return domain.LoginSuccess{}, domain.ErrInvalidOtp - } - - // 5. Mark OTP as used + // 4. Mark OTP as used storedOtp.Used = true storedOtp.UsedAt = timePtr(time.Now())