fix OTP verification by submitted code
Resolve false OTP already used/expired responses during registration by loading OTP rows using user_id plus submitted otp code and validating usage/expiry on the matched row. Made-with: Cursor
This commit is contained in:
parent
526426d9f9
commit
78f231f222
|
|
@ -25,6 +25,14 @@ WHERE user_id = $1
|
||||||
ORDER BY id DESC
|
ORDER BY id DESC
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
|
|
||||||
|
-- name: GetOtpByCode :one
|
||||||
|
SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at
|
||||||
|
FROM otps
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND otp = $2
|
||||||
|
ORDER BY id DESC
|
||||||
|
LIMIT 1;
|
||||||
|
|
||||||
-- name: MarkOtpAsUsed :exec
|
-- name: MarkOtpAsUsed :exec
|
||||||
UPDATE otps
|
UPDATE otps
|
||||||
SET used = TRUE, used_at = $2
|
SET used = TRUE, used_at = $2
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,51 @@ func (q *Queries) GetOtp(ctx context.Context, userID int64) (GetOtpRow, error) {
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const GetOtpByCode = `-- name: GetOtpByCode :one
|
||||||
|
SELECT id, user_id, sent_to, medium, otp_for, otp, used, used_at, created_at, expires_at
|
||||||
|
FROM otps
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND otp = $2
|
||||||
|
ORDER BY id DESC
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetOtpByCodeParams struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Otp string `json:"otp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetOtpByCodeRow struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
SentTo string `json:"sent_to"`
|
||||||
|
Medium string `json:"medium"`
|
||||||
|
OtpFor string `json:"otp_for"`
|
||||||
|
Otp string `json:"otp"`
|
||||||
|
Used bool `json:"used"`
|
||||||
|
UsedAt pgtype.Timestamptz `json:"used_at"`
|
||||||
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
|
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetOtpByCode(ctx context.Context, arg GetOtpByCodeParams) (GetOtpByCodeRow, error) {
|
||||||
|
row := q.db.QueryRow(ctx, GetOtpByCode, arg.UserID, arg.Otp)
|
||||||
|
var i GetOtpByCodeRow
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.SentTo,
|
||||||
|
&i.Medium,
|
||||||
|
&i.OtpFor,
|
||||||
|
&i.Otp,
|
||||||
|
&i.Used,
|
||||||
|
&i.UsedAt,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const MarkOtpAsUsed = `-- name: MarkOtpAsUsed :exec
|
const MarkOtpAsUsed = `-- name: MarkOtpAsUsed :exec
|
||||||
UPDATE otps
|
UPDATE otps
|
||||||
SET used = TRUE, used_at = $2
|
SET used = TRUE, used_at = $2
|
||||||
|
|
|
||||||
|
|
@ -94,4 +94,5 @@ type OtpStore interface {
|
||||||
MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error
|
MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error
|
||||||
CreateOtp(ctx context.Context, otp domain.Otp) error
|
CreateOtp(ctx context.Context, otp domain.Otp) error
|
||||||
GetOtp(ctx context.Context, userID int64) (domain.Otp, error)
|
GetOtp(ctx context.Context, userID int64) (domain.Otp, error)
|
||||||
|
GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,36 @@ func (s *Store) GetOtp(ctx context.Context, userID int64) (domain.Otp, error) {
|
||||||
ExpiresAt: row.ExpiresAt.Time,
|
ExpiresAt: row.ExpiresAt.Time,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOtpByCode(ctx context.Context, userID int64, otpCode string) (domain.Otp, error) {
|
||||||
|
row, err := s.queries.GetOtpByCode(ctx, dbgen.GetOtpByCodeParams{
|
||||||
|
UserID: userID,
|
||||||
|
Otp: otpCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return domain.Otp{}, domain.ErrOtpNotFound
|
||||||
|
}
|
||||||
|
return domain.Otp{}, err
|
||||||
|
}
|
||||||
|
if !row.ExpiresAt.Valid {
|
||||||
|
return domain.Otp{}, domain.ErrOtpNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return domain.Otp{
|
||||||
|
ID: row.ID,
|
||||||
|
UserID: row.UserID,
|
||||||
|
SentTo: row.SentTo,
|
||||||
|
Medium: domain.OtpMedium(row.Medium),
|
||||||
|
For: domain.OtpFor(row.OtpFor),
|
||||||
|
Otp: row.Otp,
|
||||||
|
Used: row.Used,
|
||||||
|
UsedAt: row.UsedAt.Time,
|
||||||
|
CreatedAt: row.CreatedAt.Time,
|
||||||
|
ExpiresAt: row.ExpiresAt.Time,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error {
|
func (s *Store) MarkOtpAsUsed(ctx context.Context, otp domain.Otp) error {
|
||||||
return s.queries.MarkOtpAsUsed(ctx, dbgen.MarkOtpAsUsedParams{
|
return s.queries.MarkOtpAsUsed(ctx, dbgen.MarkOtpAsUsedParams{
|
||||||
ID: otp.ID,
|
ID: otp.ID,
|
||||||
|
|
|
||||||
|
|
@ -95,9 +95,13 @@ func (s *Service) VerifyOtp(
|
||||||
return domain.LoginSuccess{}, err
|
return domain.LoginSuccess{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. Retrieve OTP
|
// 1. Retrieve OTP row matching submitted code.
|
||||||
storedOtp, err := s.otpStore.GetOtp(ctx, user.ID)
|
// This avoids false positives when another OTP row exists for the same user.
|
||||||
|
storedOtp, err := s.otpStore.GetOtpByCode(ctx, user.ID, otpCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, domain.ErrOtpNotFound) {
|
||||||
|
return domain.LoginSuccess{}, domain.ErrInvalidOtp
|
||||||
|
}
|
||||||
return domain.LoginSuccess{}, err
|
return domain.LoginSuccess{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -111,12 +115,7 @@ func (s *Service) VerifyOtp(
|
||||||
return domain.LoginSuccess{}, domain.ErrOtpExpired
|
return domain.LoginSuccess{}, domain.ErrOtpExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Invalid
|
// 4. Mark OTP as used
|
||||||
if storedOtp.Otp != otpCode {
|
|
||||||
return domain.LoginSuccess{}, domain.ErrInvalidOtp
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. Mark OTP as used
|
|
||||||
storedOtp.Used = true
|
storedOtp.Used = true
|
||||||
storedOtp.UsedAt = timePtr(time.Now())
|
storedOtp.UsedAt = timePtr(time.Now())
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user