package authentication

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base32"
	"errors"
	"time"

	"Yimaru-Backend/internal/domain"

	"golang.org/x/crypto/bcrypt"
)

// Sentinel errors returned by the authentication service.
var (
	// ErrInvalidPassword is returned when the supplied password does not
	// match the stored hash, and as the fallback when a login request
	// carries neither an email nor a phone number.
	ErrInvalidPassword = errors.New("incorrect password")
	// ErrUserNotFound is the sentinel for a missing user.
	ErrUserNotFound = errors.New("user not found")
	// ErrExpiredToken is returned when a refresh token is past its expiry.
	ErrExpiredToken = errors.New("token expired")
	// ErrRefreshTokenNotFound is returned when a refresh token is missing or
	// already revoked, i.e. the client has to log in again.
	ErrRefreshTokenNotFound = errors.New("refresh token not found")
	// ErrUserSuspended is returned when a suspended user attempts to log in.
	ErrUserSuspended = errors.New("user has been suspended")
)

// Login authenticates a user by email + password or by phone number + OTP.
//
// The user is looked up by email/phone first. Pending users are rejected with
// domain.ErrUserNotVerified and suspended users with ErrUserSuspended. When
// req.Email is set, the password is checked and the user's refresh token is
// rotated. Otherwise, when req.PhoneNumber is set, the request is delegated
// to VerifyOtp. If neither identifier is set, ErrInvalidPassword is returned.
func (s *Service) Login(
	ctx context.Context,
	req domain.LoginRequest,
) (domain.LoginSuccess, error) {
	user, err := s.userStore.GetUserByEmailPhone(ctx, req.Email, req.PhoneNumber)
	if err != nil {
		return domain.LoginSuccess{}, err
	}

	if user.Status == domain.UserStatusPending {
		return domain.LoginSuccess{}, domain.ErrUserNotVerified
	}
	if user.Status == domain.UserStatusSuspended {
		return domain.LoginSuccess{}, ErrUserSuspended
	}

	// Email + password login.
	if req.Email != "" {
		if err := matchPassword(req.Password, user.Password); err != nil {
			return domain.LoginSuccess{}, err
		}

		token, err := s.rotateRefreshToken(ctx, user.ID)
		if err != nil {
			return domain.LoginSuccess{}, err
		}

		return domain.LoginSuccess{
			UserId:  user.ID,
			Role:    user.Role,
			RfToken: token.Token,
		}, nil
	}

	// Phone + OTP login.
	if req.PhoneNumber != "" {
		return s.VerifyOtp(ctx, req.Email, req.PhoneNumber, req.OTPCode)
	}

	// Neither an email nor a phone number was supplied.
	return domain.LoginSuccess{}, ErrInvalidPassword
}

// VerifyOtp validates otpCode for the user identified by email/phone and, on
// success, logs the user in.
//
// The stored OTP must be unused, unexpired and equal to otpCode; otherwise
// domain.ErrOtpAlreadyUsed, domain.ErrOtpExpired or domain.ErrInvalidOtp is
// returned respectively. A valid OTP is marked as used, a pending user is
// switched to active, and the user's refresh token is rotated.
func (s *Service) VerifyOtp(
	ctx context.Context,
	email, phone, otpCode string,
) (domain.LoginSuccess, error) {
	user, err := s.userStore.GetUserByEmailPhone(ctx, email, phone)
	if err != nil {
		return domain.LoginSuccess{}, err
	}

	// 1. Retrieve the stored OTP.
	storedOtp, err := s.otpStore.GetOtp(ctx, user.ID)
	if err != nil {
		return domain.LoginSuccess{}, err
	}

	// 2. Reject an OTP that has already been used.
	if storedOtp.Used {
		return domain.LoginSuccess{}, domain.ErrOtpAlreadyUsed
	}

	// 3. Reject an expired OTP.
	if time.Now().After(storedOtp.ExpiresAt) {
		return domain.LoginSuccess{}, domain.ErrOtpExpired
	}

	// 4. Reject a mismatching OTP. The comparison is constant-time so the
	// response time does not reveal how many leading characters matched.
	if subtle.ConstantTimeCompare([]byte(storedOtp.Otp), []byte(otpCode)) != 1 {
		return domain.LoginSuccess{}, domain.ErrInvalidOtp
	}

	// 5. Mark the OTP as used.
	storedOtp.Used = true
	storedOtp.UsedAt = timePtr(time.Now())
	if err := s.otpStore.MarkOtpAsUsed(ctx, storedOtp); err != nil {
		return domain.LoginSuccess{}, err
	}

	// 6. Activate the user if still pending.
	if user.Status == domain.UserStatusPending {
		if err := s.userStore.UpdateUserStatus(ctx, domain.UpdateUserStatusReq{
			UserID: user.ID,
			Status: string(domain.UserStatusActive),
		}); err != nil {
			return domain.LoginSuccess{}, err
		}
	}

	// 7. Revoke any active refresh token and issue a new one.
	token, err := s.rotateRefreshToken(ctx, user.ID)
	if err != nil {
		return domain.LoginSuccess{}, err
	}

	// 8. Return the success payload.
	return domain.LoginSuccess{
		UserId:  user.ID,
		Role:    user.Role,
		RfToken: token.Token,
	}, nil
}

// rotateRefreshToken revokes the refresh token currently stored for userID,
// if there is one and it is not already revoked, then issues and persists a
// new one. A lookup that fails with ErrRefreshTokenNotFound is treated as
// "nothing to revoke".
func (s *Service) rotateRefreshToken(ctx context.Context, userID int64) (domain.RefreshToken, error) {
	old, err := s.tokenStore.GetRefreshTokenByUserID(ctx, userID)
	switch {
	case err == nil:
		if !old.Revoked {
			if err := s.tokenStore.RevokeRefreshToken(ctx, old.Token); err != nil {
				return domain.RefreshToken{}, err
			}
		}
	case !errors.Is(err, ErrRefreshTokenNotFound):
		return domain.RefreshToken{}, err
	}

	return s.issueRefreshToken(ctx, userID)
}

// issueRefreshToken generates a new refresh token for userID, persists it and
// returns it. The token expires s.RefreshExpiry seconds after creation.
func (s *Service) issueRefreshToken(ctx context.Context, userID int64) (domain.RefreshToken, error) {
	value, err := generateRefreshToken()
	if err != nil {
		return domain.RefreshToken{}, err
	}

	// Use one timestamp so ExpiresAt is exactly CreatedAt + expiry.
	now := time.Now()
	token := domain.RefreshToken{
		Token:     value,
		UserID:    userID,
		CreatedAt: now,
		ExpiresAt: now.Add(time.Duration(s.RefreshExpiry) * time.Second),
	}

	if err := s.tokenStore.CreateRefreshToken(ctx, token); err != nil {
		return domain.RefreshToken{}, err
	}

	return token, nil
}

// timePtr returns t unchanged. Despite its name, it returns the time value
// itself, not a pointer to it.
func timePtr(t time.Time) time.Time {
	return t
}

// RefreshToken exchanges the refresh token refToken for a new one.
//
// It returns ErrRefreshTokenNotFound if the token is revoked and
// ErrExpiredToken if it has expired. Otherwise the old token is revoked
// before the replacement is generated and persisted, so each refresh token
// can be exchanged only once.
func (s *Service) RefreshToken(
	ctx context.Context,
	refToken string,
) (domain.RefreshToken, error) {
	// 1. Load the presented refresh token.
	token, err := s.tokenStore.GetRefreshToken(ctx, refToken)
	if err != nil {
		return domain.RefreshToken{}, err
	}

	// 2. Validate it.
	if token.Revoked {
		return domain.RefreshToken{}, ErrRefreshTokenNotFound
	}
	if token.ExpiresAt.Before(time.Now()) {
		return domain.RefreshToken{}, ErrExpiredToken
	}

	// 3. Revoke the old token (single-use guarantee).
	if err := s.tokenStore.RevokeRefreshToken(ctx, refToken); err != nil {
		return domain.RefreshToken{}, err
	}

	// 4. Generate, persist and return the replacement.
	return s.issueRefreshToken(ctx, token.UserID)
}

// GetLastLogin returns the creation time of the refresh token stored for
// userID, which this service uses as the time of the last login.
func (s *Service) GetLastLogin(ctx context.Context, userID int64) (*time.Time, error) {
	refreshToken, err := s.tokenStore.GetRefreshTokenByUserID(ctx, userID)
	if err != nil {
		return nil, err
	}
	return &refreshToken.CreatedAt, nil
}

// Logout revokes the refresh token refToken.
//
// It returns ErrRefreshTokenNotFound if the token is already revoked and
// ErrExpiredToken if it has expired.
func (s *Service) Logout(ctx context.Context, refToken string) error {
	token, err := s.tokenStore.GetRefreshToken(ctx, refToken)
	if err != nil {
		return err
	}
	if token.Revoked {
		return ErrRefreshTokenNotFound
	}
	if token.ExpiresAt.Before(time.Now()) {
		return ErrExpiredToken
	}
	return s.tokenStore.RevokeRefreshToken(ctx, refToken)
}

// matchPassword compares plaintextPassword against the bcrypt hash. It
// returns nil on a match, ErrInvalidPassword on a mismatch, and the
// underlying bcrypt error for any other failure.
func matchPassword(plaintextPassword string, hash []byte) error {
	err := bcrypt.CompareHashAndPassword(hash, []byte(plaintextPassword))
	if err == nil {
		return nil
	}
	if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
		return ErrInvalidPassword
	}
	return err
}

// generateRefreshToken returns 32 bytes read from crypto/rand, encoded as
// unpadded standard base32.
func generateRefreshToken() (string, error) {
	randomBytes := make([]byte, 32)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", err
	}
	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes), nil
}