272 lines
6.5 KiB
Go
272 lines
6.5 KiB
Go
package authentication
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base32"
	"errors"
	"time"

	"Yimaru-Backend/internal/domain"

	"golang.org/x/crypto/bcrypt"
)
|
|
|
|
var (
|
|
ErrInvalidPassword = errors.New("incorrect password")
|
|
ErrUserNotFound = errors.New("user not found")
|
|
ErrExpiredToken = errors.New("token expired")
|
|
ErrRefreshTokenNotFound = errors.New("refresh token not found") // i.e login again
|
|
ErrUserSuspended = errors.New("user has been suspended")
|
|
)
|
|
|
|
func (s *Service) Login(
|
|
ctx context.Context,
|
|
req domain.LoginRequest,
|
|
) (domain.LoginSuccess, error) {
|
|
|
|
user, err := s.userStore.GetUserByEmailPhone(ctx, req.Email, req.PhoneNumber)
|
|
if err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
if user.Status == domain.UserStatusPending {
|
|
return domain.LoginSuccess{}, domain.ErrUserNotVerified
|
|
}
|
|
|
|
if user.Status == domain.UserStatusSuspended {
|
|
return domain.LoginSuccess{}, ErrUserSuspended
|
|
}
|
|
|
|
// Email + password login
|
|
if req.Email != "" {
|
|
if err := matchPassword(req.Password, user.Password); err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
oldRefreshToken, err := s.tokenStore.GetRefreshTokenByUserID(ctx, user.ID)
|
|
if err != nil && !errors.Is(err, ErrRefreshTokenNotFound) {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
if err == nil && !oldRefreshToken.Revoked {
|
|
if err := s.tokenStore.RevokeRefreshToken(ctx, oldRefreshToken.Token); err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
}
|
|
|
|
refreshToken, err := generateRefreshToken()
|
|
if err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
if err := s.tokenStore.CreateRefreshToken(ctx, domain.RefreshToken{
|
|
Token: refreshToken,
|
|
UserID: user.ID,
|
|
CreatedAt: time.Now(),
|
|
ExpiresAt: time.Now().Add(time.Duration(s.RefreshExpiry) * time.Second),
|
|
}); err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
return domain.LoginSuccess{
|
|
UserId: user.ID,
|
|
Role: user.Role,
|
|
RfToken: refreshToken,
|
|
}, nil
|
|
}
|
|
|
|
// Phone + OTP login
|
|
if req.PhoneNumber != "" {
|
|
return s.VerifyOtp(ctx, req.Email, req.PhoneNumber, req.OTPCode)
|
|
}
|
|
|
|
// ❗ Mandatory fallback return
|
|
return domain.LoginSuccess{}, ErrInvalidPassword
|
|
}
|
|
|
|
func (s *Service) VerifyOtp(
|
|
ctx context.Context,
|
|
email, phone, otpCode string,
|
|
) (domain.LoginSuccess, error) {
|
|
|
|
user, err := s.userStore.GetUserByEmailPhone(ctx, email, phone)
|
|
if err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
// 1. Retrieve OTP
|
|
storedOtp, err := s.otpStore.GetOtp(ctx, user.ID)
|
|
if err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
// 2. Already used
|
|
if storedOtp.Used {
|
|
return domain.LoginSuccess{}, domain.ErrOtpAlreadyUsed
|
|
}
|
|
|
|
// 3. Expired
|
|
if time.Now().After(storedOtp.ExpiresAt) {
|
|
return domain.LoginSuccess{}, domain.ErrOtpExpired
|
|
}
|
|
|
|
// 4. Invalid
|
|
if storedOtp.Otp != otpCode {
|
|
return domain.LoginSuccess{}, domain.ErrInvalidOtp
|
|
}
|
|
|
|
// 5. Mark OTP as used
|
|
storedOtp.Used = true
|
|
storedOtp.UsedAt = timePtr(time.Now())
|
|
|
|
if err := s.otpStore.MarkOtpAsUsed(ctx, storedOtp); err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
// 6. Activate user if still pending
|
|
if user.Status == domain.UserStatusPending {
|
|
if err := s.userStore.UpdateUserStatus(ctx, domain.UpdateUserStatusReq{
|
|
UserID: user.ID,
|
|
Status: string(domain.UserStatusActive),
|
|
}); err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
}
|
|
|
|
// 7. Handle existing refresh token
|
|
oldRefreshToken, err := s.tokenStore.GetRefreshTokenByUserID(ctx, user.ID)
|
|
if err != nil && !errors.Is(err, ErrRefreshTokenNotFound) {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
if err == nil && !oldRefreshToken.Revoked {
|
|
if err := s.tokenStore.RevokeRefreshToken(ctx, oldRefreshToken.Token); err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
}
|
|
|
|
// 8. Generate new refresh token
|
|
refreshToken, err := generateRefreshToken()
|
|
if err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
if err := s.tokenStore.CreateRefreshToken(ctx, domain.RefreshToken{
|
|
Token: refreshToken,
|
|
UserID: user.ID,
|
|
CreatedAt: time.Now(),
|
|
ExpiresAt: time.Now().Add(time.Duration(s.RefreshExpiry) * time.Second),
|
|
}); err != nil {
|
|
return domain.LoginSuccess{}, err
|
|
}
|
|
|
|
// 9. Return success payload
|
|
return domain.LoginSuccess{
|
|
UserId: user.ID,
|
|
Role: user.Role,
|
|
RfToken: refreshToken,
|
|
}, nil
|
|
}
|
|
|
|
// helper function to get a pointer to time.Time
|
|
func timePtr(t time.Time) time.Time {
|
|
return t
|
|
}
|
|
|
|
func (s *Service) RefreshToken(
|
|
ctx context.Context,
|
|
refToken string,
|
|
) (domain.RefreshToken, error) {
|
|
|
|
// 1. Load refresh token
|
|
token, err := s.tokenStore.GetRefreshToken(ctx, refToken)
|
|
if err != nil {
|
|
return domain.RefreshToken{}, err
|
|
}
|
|
|
|
// 2. Validate token
|
|
if token.Revoked {
|
|
return domain.RefreshToken{}, ErrRefreshTokenNotFound
|
|
}
|
|
|
|
if token.ExpiresAt.Before(time.Now()) {
|
|
return domain.RefreshToken{}, ErrExpiredToken
|
|
}
|
|
|
|
// 3. Revoke old token (single-use guarantee)
|
|
if err := s.tokenStore.RevokeRefreshToken(ctx, refToken); err != nil {
|
|
return domain.RefreshToken{}, err
|
|
}
|
|
|
|
// 4. Generate new refresh token
|
|
newRefreshToken, err := generateRefreshToken()
|
|
if err != nil {
|
|
return domain.RefreshToken{}, err
|
|
}
|
|
|
|
newToken := domain.RefreshToken{
|
|
Token: newRefreshToken,
|
|
UserID: token.UserID,
|
|
CreatedAt: time.Now(),
|
|
ExpiresAt: time.Now().Add(time.Duration(s.RefreshExpiry) * time.Second),
|
|
}
|
|
|
|
// 5. Persist new token
|
|
if err := s.tokenStore.CreateRefreshToken(ctx, newToken); err != nil {
|
|
return domain.RefreshToken{}, err
|
|
}
|
|
|
|
// 6. Return new token
|
|
return newToken, nil
|
|
}
|
|
|
|
func (s *Service) GetLastLogin(ctx context.Context, user_id int64) (*time.Time, error) {
|
|
refreshToken, err := s.tokenStore.GetRefreshTokenByUserID(ctx, user_id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &refreshToken.CreatedAt, nil
|
|
|
|
}
|
|
|
|
func (s *Service) Logout(ctx context.Context, refToken string) error {
|
|
token, err := s.tokenStore.GetRefreshToken(ctx, refToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if token.Revoked {
|
|
return ErrRefreshTokenNotFound
|
|
}
|
|
if token.ExpiresAt.Before(time.Now()) {
|
|
return ErrExpiredToken
|
|
}
|
|
|
|
return s.tokenStore.RevokeRefreshToken(ctx, refToken)
|
|
}
|
|
|
|
func matchPassword(plaintextPassword string, hash []byte) error {
|
|
err := bcrypt.CompareHashAndPassword(hash, []byte(plaintextPassword))
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
|
|
return ErrInvalidPassword
|
|
default:
|
|
return err
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// generateRefreshToken returns a new opaque refresh token: 32 bytes from
// crypto/rand, encoded as unpadded standard base32 (52 characters).
func generateRefreshToken() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}

	encoder := base32.StdEncoding.WithPadding(base32.NoPadding)
	return encoder.EncodeToString(entropy[:]), nil
}
|