Yimaru-BackEnd/internal/services/authentication/impl.go

177 lines
4.4 KiB
Go

package authentication
import (
"context"
"crypto/rand"
"encoding/base32"
"errors"
"time"
"Yimaru-Backend/internal/domain"
"golang.org/x/crypto/bcrypt"
)
// Sentinel errors produced by the authentication service. Compare against
// them with errors.Is.
var (
	// ErrInvalidPassword reports that the supplied password did not match the
	// stored bcrypt hash.
	ErrInvalidPassword = errors.New("incorrect password")

	// ErrUserNotFound reports that no user could be found.
	ErrUserNotFound = errors.New("user not found")

	// ErrExpiredToken reports a refresh token whose expiry time has passed.
	ErrExpiredToken = errors.New("token expired")

	// ErrRefreshTokenNotFound reports a refresh token that is missing or has
	// been revoked; the client is expected to log in again.
	ErrRefreshTokenNotFound = errors.New("refresh token not found")

	// ErrUserSuspended reports an attempt to log in to a suspended account.
	ErrUserSuspended = errors.New("user has been suspended")
)
// LoginSuccess is the result of a successful Login call.
type LoginSuccess struct {
	// UserId is the authenticated user's ID.
	UserId int64
	// Role is the authenticated user's role as stored on the user record.
	Role domain.Role
	// RfToken is the newly issued, plaintext refresh token.
	RfToken string
}
// LoginRequest carries the credentials submitted to Login. An email login is
// authenticated with Password; a phone-number login is authenticated with
// OTPCode.
type LoginRequest struct {
	// Email identifies the user for a password login.
	Email string `json:"email"`
	// PhoneNumber identifies the user for an OTP login.
	PhoneNumber string `json:"phone_number"`
	// Password is checked against the stored hash when Email is set.
	Password string `json:"password"`
	// OTPCode is the one-time code checked when logging in by phone number.
	OTPCode string `json:"otp_code"`
}
// Login authenticates a user and issues a new refresh token.
//
// The user is looked up by req.Email and req.PhoneNumber. When req.Email is
// set, req.Password is verified against the stored bcrypt hash. Otherwise
// req.OTPCode is verified for req.PhoneNumber through UserSvc.VerifyOtp.
//
// A request that carries neither an email nor a phone number is rejected with
// ErrUserNotFound before any lookup. Without this guard, neither credential
// check would run, and a store lookup that matched empty columns would yield
// a login with no credentials at all.
//
// Errors:
//   - domain.ErrUserNotVerified for accounts in the pending state.
//   - ErrUserSuspended for suspended accounts.
//   - ErrInvalidPassword for a wrong password.
//   - Errors from the user store, OTP verification, and token store are
//     returned unchanged.
//
// If the token store returns an existing, non-revoked refresh token for the
// user, it is revoked before the new token is stored.
//
// NOTE(review): the account-status checks run before the credential check,
// so an unauthenticated caller can learn whether an account is pending or
// suspended. Consider reordering if that disclosure matters.
func (s *Service) Login(
	ctx context.Context,
	req LoginRequest,
) (LoginSuccess, error) {
	// With no identifier there is nothing to authenticate against; refuse
	// before touching the store.
	if req.Email == "" && req.PhoneNumber == "" {
		return LoginSuccess{}, ErrUserNotFound
	}

	user, err := s.userStore.GetUserByEmailPhone(ctx, req.Email, req.PhoneNumber)
	if err != nil {
		return LoginSuccess{}, err
	}

	if user.Status == domain.UserStatusPending {
		return LoginSuccess{}, domain.ErrUserNotVerified
	}
	if user.Status == domain.UserStatusSuspended {
		return LoginSuccess{}, ErrUserSuspended
	}

	// Exactly one credential check always runs: the guard above ensures that
	// when Email is empty, PhoneNumber is set.
	if req.Email != "" {
		if err := matchPassword(req.Password, user.Password); err != nil {
			return LoginSuccess{}, err
		}
	} else {
		if err := s.UserSvc.VerifyOtp(ctx, req.Email, req.PhoneNumber, req.OTPCode); err != nil {
			return LoginSuccess{}, err
		}
	}

	// A missing previous token is expected (e.g. first login); any other
	// store error aborts the login.
	oldRefreshToken, err := s.tokenStore.GetRefreshTokenByUserID(ctx, user.ID)
	if err != nil && !errors.Is(err, ErrRefreshTokenNotFound) {
		return LoginSuccess{}, err
	}
	if err == nil && !oldRefreshToken.Revoked {
		if err := s.tokenStore.RevokeRefreshToken(ctx, oldRefreshToken.Token); err != nil {
			return LoginSuccess{}, err
		}
	}

	refreshToken, err := generateRefreshToken()
	if err != nil {
		return LoginSuccess{}, err
	}

	// Use one timestamp so CreatedAt and the expiry base agree exactly.
	now := time.Now()
	if err := s.tokenStore.CreateRefreshToken(ctx, domain.RefreshToken{
		Token:     refreshToken,
		UserID:    user.ID,
		CreatedAt: now,
		ExpiresAt: now.Add(time.Duration(s.RefreshExpiry) * time.Second),
	}); err != nil {
		return LoginSuccess{}, err
	}

	return LoginSuccess{
		UserId:  user.ID,
		Role:    user.Role,
		RfToken: refreshToken,
	}, nil
}
// RefreshToken loads the stored record for refToken and returns it if it is
// still usable.
//
// Errors:
//   - Store errors are returned unchanged.
//   - A revoked token yields ErrRefreshTokenNotFound.
//   - A token whose ExpiresAt lies in the past yields ErrExpiredToken.
//
// The token is not rotated; the same stored record is handed back to the
// caller.
func (s *Service) RefreshToken(ctx context.Context, refToken string) (domain.RefreshToken, error) {
	stored, err := s.tokenStore.GetRefreshToken(ctx, refToken)
	switch {
	case err != nil:
		return domain.RefreshToken{}, err
	case stored.Revoked:
		return domain.RefreshToken{}, ErrRefreshTokenNotFound
	case stored.ExpiresAt.Before(time.Now()):
		return domain.RefreshToken{}, ErrExpiredToken
	}
	return stored, nil
}
// GetLastLogin returns the CreatedAt time of the refresh token the token store
// returns for userID. Login writes a new refresh token stamped with the
// current time on each successful login, so this serves as the last-login
// time. Token-store errors are returned unchanged.
func (s *Service) GetLastLogin(ctx context.Context, userID int64) (*time.Time, error) {
	tok, err := s.tokenStore.GetRefreshTokenByUserID(ctx, userID)
	if err != nil {
		return nil, err
	}
	lastLogin := tok.CreatedAt
	return &lastLogin, nil
}
// Logout revokes refToken in the token store.
//
// Errors:
//   - Store errors from the lookup are returned unchanged.
//   - A token that is already revoked yields ErrRefreshTokenNotFound.
//   - A token past its expiry yields ErrExpiredToken.
//
// Tokens rejected for these reasons are left as they are.
func (s *Service) Logout(ctx context.Context, refToken string) error {
	stored, err := s.tokenStore.GetRefreshToken(ctx, refToken)
	switch {
	case err != nil:
		return err
	case stored.Revoked:
		return ErrRefreshTokenNotFound
	case time.Now().After(stored.ExpiresAt):
		return ErrExpiredToken
	}
	return s.tokenStore.RevokeRefreshToken(ctx, refToken)
}
// matchPassword compares plaintextPassword with the bcrypt hash. It returns
// nil on a match and ErrInvalidPassword on a mismatch. Any other bcrypt error
// (e.g. a malformed hash) is passed through unchanged.
func matchPassword(plaintextPassword string, hash []byte) error {
	cmpErr := bcrypt.CompareHashAndPassword(hash, []byte(plaintextPassword))
	if errors.Is(cmpErr, bcrypt.ErrMismatchedHashAndPassword) {
		return ErrInvalidPassword
	}
	return cmpErr
}
// generateRefreshToken returns a new random refresh token. The token is 32
// bytes from crypto/rand, encoded as unpadded standard base32 (52 characters
// from A-Z and 2-7). A failure to read random bytes is returned to the
// caller.
func generateRefreshToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	enc := base32.StdEncoding.WithPadding(base32.NoPadding)
	return enc.EncodeToString(raw[:]), nil
}