fix OTP verification by submitted code
Resolve false OTP already used/expired responses during registration by loading OTP rows using user_id plus submitted otp code and validating usage/expiry on the matched row. Made-with: Cursor
This commit is contained in:
parent
526426d9f9
commit
78f231f222
|
|
@ -25,6 +25,14 @@ WHERE user_id = $1
|
|||
ORDER BY id DESC
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetOtpByCode :one
|
||||
SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at
|
||||
FROM otps
|
||||
WHERE user_id = $1
|
||||
AND otp = $2
|
||||
ORDER BY id DESC
|
||||
LIMIT 1;
|
||||
|
||||
-- name: MarkOtpAsUsed :exec
|
||||
UPDATE otps
|
||||
SET used = TRUE, used_at = $2
|
||||
|
|
|
|||
|
|
@ -83,6 +83,51 @@ func (q *Queries) GetOtp(ctx context.Context, userID int64) (GetOtpRow, error) {
|
|||
return i, err
|
||||
}
|
||||
|
||||
const GetOtpByCode = `-- name: GetOtpByCode :one
|
||||
SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at
|
||||
FROM otps
|
||||
WHERE user_id = $1
|
||||
AND otp = $2
|
||||
ORDER BY id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
type GetOtpByCodeParams struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Otp string `json:"otp"`
|
||||
}
|
||||
|
||||
type GetOtpByCodeRow struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
SentTo string `json:"sent_to"`
|
||||
Medium string `json:"medium"`
|
||||
OtpFor string `json:"otp_for"`
|
||||
Otp string `json:"otp"`
|
||||
Used bool `json:"used"`
|
||||
UsedAt pgtype.Timestamptz `json:"used_at"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetOtpByCode(ctx context.Context, arg GetOtpByCodeParams) (GetOtpByCodeRow, error) {
|
||||
row := q.db.QueryRow(ctx, GetOtpByCode, arg.UserID, arg.Otp)
|
||||
var i GetOtpByCodeRow
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.SentTo,
|
||||
&i.Medium,
|
||||
&i.OtpFor,
|
||||
&i.Otp,
|
||||
&i.Used,
|
||||
&i.UsedAt,
|
||||
&i.CreatedAt,
|
||||
&i.ExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const MarkOtpAsUsed = `-- name: MarkOtpAsUsed :exec
|
||||
UPDATE otps
|
||||
SET used = TRUE, used_at = $2
|
||||
|
|
|
|||
|
|
@ -94,4 +94,5 @@ type OtpStore interface {
|
|||
MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error
|
||||
CreateOtp(ctx context.Context, otp domain.Otp) error
|
||||
GetOtp(ctx context.Context, userID int64) (domain.Otp, error)
|
||||
GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,6 +67,36 @@ func (s *Store) GetOtp(ctx context.Context, userID int64) (domain.Otp, error) {
|
|||
ExpiresAt: row.ExpiresAt.Time,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error) {
|
||||
row, err := s.queries.GetOtpByCode(ctx, dbgen.GetOtpByCodeParams{
|
||||
UserID: userID,
|
||||
Otp: otpCode,
|
||||
})
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return domain.Otp{}, domain.ErrOtpNotFound
|
||||
}
|
||||
return domain.Otp{}, err
|
||||
}
|
||||
if !row.ExpiresAt.Valid {
|
||||
return domain.Otp{}, domain.ErrOtpNotFound
|
||||
}
|
||||
|
||||
return domain.Otp{
|
||||
ID: row.ID,
|
||||
UserID: row.UserID,
|
||||
SentTo: row.SentTo,
|
||||
Medium: domain.OtpMedium(row.Medium),
|
||||
For: domain.OtpFor(row.OtpFor),
|
||||
Otp: row.Otp,
|
||||
Used: row.Used,
|
||||
UsedAt: row.UsedAt.Time,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
ExpiresAt: row.ExpiresAt.Time,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error {
|
||||
return s.queries.MarkOtpAsUsed(ctx, dbgen.MarkOtpAsUsedParams{
|
||||
ID: otp.ID,
|
||||
|
|
|
|||
|
|
@ -95,9 +95,13 @@ func (s *Service) VerifyOtp(
|
|||
return domain.LoginSuccess{}, err
|
||||
}
|
||||
|
||||
// 1. Retrieve OTP
|
||||
storedOtp, err := s.otpStore.GetOtp(ctx, user.ID)
|
||||
// 1. Retrieve OTP row matching submitted code.
|
||||
// This avoids false positives when another OTP row exists for the same user.
|
||||
storedOtp, err := s.otpStore.GetOtpByCode(ctx, user.ID, otpCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrOtpNotFound) {
|
||||
return domain.LoginSuccess{}, domain.ErrInvalidOtp
|
||||
}
|
||||
return domain.LoginSuccess{}, err
|
||||
}
|
||||
|
||||
|
|
@ -111,12 +115,7 @@ func (s *Service) VerifyOtp(
|
|||
return domain.LoginSuccess{}, domain.ErrOtpExpired
|
||||
}
|
||||
|
||||
// 4. Invalid
|
||||
if storedOtp.Otp != otpCode {
|
||||
return domain.LoginSuccess{}, domain.ErrInvalidOtp
|
||||
}
|
||||
|
||||
// 5. Mark OTP as used
|
||||
// 4. Mark OTP as used
|
||||
storedOtp.Used = true
|
||||
storedOtp.UsedAt = timePtr(time.Now())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user