fix OTP verification by submitted code

Resolve false OTP already used/expired responses during registration by loading OTP rows using user_id plus submitted otp code and validating usage/expiry on the matched row.

Made-with: Cursor
This commit is contained in:
Yared Yemane 2026-04-25 05:07:19 -07:00
parent 526426d9f9
commit 78f231f222
5 changed files with 91 additions and 8 deletions

View File

@ -25,6 +25,14 @@ WHERE user_id = $1
ORDER BY id DESC ORDER BY id DESC
LIMIT 1; LIMIT 1;
-- name: GetOtpByCode :one
SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at
FROM otps
WHERE user_id = $1
AND otp = $2
ORDER BY id DESC
LIMIT 1;
-- name: MarkOtpAsUsed :exec -- name: MarkOtpAsUsed :exec
UPDATE otps UPDATE otps
SET used = TRUE, used_at = $2 SET used = TRUE, used_at = $2

View File

@ -83,6 +83,51 @@ func (q *Queries) GetOtp(ctx context.Context, userID int64) (GetOtpRow, error) {
return i, err return i, err
} }
const GetOtpByCode = `-- name: GetOtpByCode :one
SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at
FROM otps
WHERE user_id = $1
AND otp = $2
ORDER BY id DESC
LIMIT 1
`
type GetOtpByCodeParams struct {
UserID int64 `json:"user_id"`
Otp string `json:"otp"`
}
type GetOtpByCodeRow struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
SentTo string `json:"sent_to"`
Medium string `json:"medium"`
OtpFor string `json:"otp_for"`
Otp string `json:"otp"`
Used bool `json:"used"`
UsedAt pgtype.Timestamptz `json:"used_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
func (q *Queries) GetOtpByCode(ctx context.Context, arg GetOtpByCodeParams) (GetOtpByCodeRow, error) {
row := q.db.QueryRow(ctx, GetOtpByCode, arg.UserID, arg.Otp)
var i GetOtpByCodeRow
err := row.Scan(
&i.ID,
&i.UserID,
&i.SentTo,
&i.Medium,
&i.OtpFor,
&i.Otp,
&i.Used,
&i.UsedAt,
&i.CreatedAt,
&i.ExpiresAt,
)
return i, err
}
const MarkOtpAsUsed = `-- name: MarkOtpAsUsed :exec const MarkOtpAsUsed = `-- name: MarkOtpAsUsed :exec
UPDATE otps UPDATE otps
SET used = TRUE, used_at = $2 SET used = TRUE, used_at = $2

View File

@ -94,4 +94,5 @@ type OtpStore interface {
MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error
CreateOtp(ctx context.Context, otp domain.Otp) error CreateOtp(ctx context.Context, otp domain.Otp) error
GetOtp(ctx context.Context, userID int64) (domain.Otp, error) GetOtp(ctx context.Context, userID int64) (domain.Otp, error)
GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error)
} }

View File

@ -67,6 +67,36 @@ func (s *Store) GetOtp(ctx context.Context, userID int64) (domain.Otp, error) {
ExpiresAt: row.ExpiresAt.Time, ExpiresAt: row.ExpiresAt.Time,
}, nil }, nil
} }
func (s *Store) GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error) {
row, err := s.queries.GetOtpByCode(ctx, dbgen.GetOtpByCodeParams{
UserID: userID,
Otp: otpCode,
})
if err != nil {
if err == sql.ErrNoRows {
return domain.Otp{}, domain.ErrOtpNotFound
}
return domain.Otp{}, err
}
if !row.ExpiresAt.Valid {
return domain.Otp{}, domain.ErrOtpNotFound
}
return domain.Otp{
ID: row.ID,
UserID: row.UserID,
SentTo: row.SentTo,
Medium: domain.OtpMedium(row.Medium),
For: domain.OtpFor(row.OtpFor),
Otp: row.Otp,
Used: row.Used,
UsedAt: row.UsedAt.Time,
CreatedAt: row.CreatedAt.Time,
ExpiresAt: row.ExpiresAt.Time,
}, nil
}
func (s *Store) MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error { func (s *Store) MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error {
return s.queries.MarkOtpAsUsed(ctx, dbgen.MarkOtpAsUsedParams{ return s.queries.MarkOtpAsUsed(ctx, dbgen.MarkOtpAsUsedParams{
ID: otp.ID, ID: otp.ID,

View File

@ -95,9 +95,13 @@ func (s *Service) VerifyOtp(
return domain.LoginSuccess{}, err return domain.LoginSuccess{}, err
} }
// 1. Retrieve OTP // 1. Retrieve OTP row matching submitted code.
storedOtp, err := s.otpStore.GetOtp(ctx, user.ID) // This avoids false positives when another OTP row exists for the same user.
storedOtp, err := s.otpStore.GetOtpByCode(ctx, user.ID, otpCode)
if err != nil { if err != nil {
if errors.Is(err, domain.ErrOtpNotFound) {
return domain.LoginSuccess{}, domain.ErrInvalidOtp
}
return domain.LoginSuccess{}, err return domain.LoginSuccess{}, err
} }
@ -111,12 +115,7 @@ func (s *Service) VerifyOtp(
return domain.LoginSuccess{}, domain.ErrOtpExpired return domain.LoginSuccess{}, domain.ErrOtpExpired
} }
// 4. Invalid // 4. Mark OTP as used
if storedOtp.Otp != otpCode {
return domain.LoginSuccess{}, domain.ErrInvalidOtp
}
// 5. Mark OTP as used
storedOtp.Used = true storedOtp.Used = true
storedOtp.UsedAt = timePtr(time.Now()) storedOtp.UsedAt = timePtr(time.Now())