// Package authentication implements the login, refresh-token and logout
// flows on top of the user and token stores held by Service.
package authentication

import (
	"context"
	"crypto/rand"
	"encoding/base32"
	"errors"
	"time"

	"Yimaru-Backend/internal/domain"
	"golang.org/x/crypto/bcrypt"
)

var (
	// ErrInvalidPassword is returned when the supplied password does not
	// match the stored bcrypt hash.
	ErrInvalidPassword = errors.New("incorrect password")
	// ErrUserNotFound reports a missing user.
	ErrUserNotFound = errors.New("user not found")
	// ErrExpiredToken is returned when a refresh token's ExpiresAt is in the past.
	ErrExpiredToken = errors.New("token expired")
	// ErrRefreshTokenNotFound is returned for missing and for revoked refresh
	// tokens; the client is expected to log in again.
	ErrRefreshTokenNotFound = errors.New("refresh token not found")
	// ErrUserSuspended is returned when a suspended user tries to log in.
	ErrUserSuspended = errors.New("user has been suspended")
	// ErrMissingCredentials is returned by Login when the request carries
	// neither an email nor a phone number.
	ErrMissingCredentials = errors.New("email or phone number is required")
)

// LoginSuccess is the result of a successful Login.
type LoginSuccess struct {
	UserId  int64       // ID of the authenticated user.
	Role    domain.Role // Role of the authenticated user.
	RfToken string      // Newly issued refresh token.
}

// LoginRequest carries the identifier and secret for a login attempt.
// Email logins are checked against Password; phone-number logins (no email)
// are checked against OTPCode.
type LoginRequest struct {
	Email       string `json:"email"`
	PhoneNumber string `json:"phone_number"`
	Password    string `json:"password"`
	OTPCode     string `json:"otp_code"`
}

// Login authenticates a user and issues a new refresh token.
//
// The user is looked up by req.Email / req.PhoneNumber. When req.Email is set
// the password is verified; otherwise, when req.PhoneNumber is set, the OTP
// code is verified through the user service. A request with neither
// identifier is rejected with ErrMissingCredentials before any lookup.
// Pending users get domain.ErrUserNotVerified and suspended users get
// ErrUserSuspended. Any existing unrevoked refresh token of the user is
// revoked before the new one is stored.
func (s *Service) Login(
	ctx context.Context,
	req LoginRequest,
) (LoginSuccess, error) {
	// Without an identifier the credential check below would have nothing to
	// verify, so refuse instead of authenticating on the lookup result alone.
	if req.Email == "" && req.PhoneNumber == "" {
		return LoginSuccess{}, ErrMissingCredentials
	}

	user, err := s.userStore.GetUserByEmailPhone(ctx, req.Email, req.PhoneNumber)
	if err != nil {
		return LoginSuccess{}, err
	}

	if user.Status == domain.UserStatusPending {
		return LoginSuccess{}, domain.ErrUserNotVerified
	}
	if user.Status == domain.UserStatusSuspended {
		return LoginSuccess{}, ErrUserSuspended
	}

	if err := s.verifyCredentials(ctx, req, user.Password); err != nil {
		return LoginSuccess{}, err
	}

	refreshToken, err := s.rotateRefreshToken(ctx, user.ID)
	if err != nil {
		return LoginSuccess{}, err
	}

	return LoginSuccess{
		UserId:  user.ID,
		Role:    user.Role,
		RfToken: refreshToken,
	}, nil
}

// verifyCredentials checks the secret that belongs to the identifier in req:
// the password (against passwordHash) when an email is given, otherwise the
// OTP code for the phone number. It returns ErrMissingCredentials when
// neither identifier is set, so no path authenticates without a check.
func (s *Service) verifyCredentials(ctx context.Context, req LoginRequest, passwordHash []byte) error {
	switch {
	case req.Email != "":
		return matchPassword(req.Password, passwordHash)
	case req.PhoneNumber != "":
		if err := s.UserSvc.VerifyOtp(ctx, req.Email, req.PhoneNumber, req.OTPCode); err != nil {
			return err
		}
		return nil
	default:
		return ErrMissingCredentials
	}
}

// rotateRefreshToken revokes the user's current refresh token (when the
// store has one that is not already revoked), then generates and stores a
// new token expiring s.RefreshExpiry seconds from now. It returns the new
// token string.
func (s *Service) rotateRefreshToken(ctx context.Context, userID int64) (string, error) {
	old, err := s.tokenStore.GetRefreshTokenByUserID(ctx, userID)
	switch {
	case err == nil:
		if !old.Revoked {
			if err := s.tokenStore.RevokeRefreshToken(ctx, old.Token); err != nil {
				return "", err
			}
		}
	case !errors.Is(err, ErrRefreshTokenNotFound):
		return "", err
	}

	token, err := generateRefreshToken()
	if err != nil {
		return "", err
	}

	// One timestamp so the token lives exactly RefreshExpiry seconds.
	now := time.Now()
	if err := s.tokenStore.CreateRefreshToken(ctx, domain.RefreshToken{
		Token:     token,
		UserID:    userID,
		CreatedAt: now,
		ExpiresAt: now.Add(time.Duration(s.RefreshExpiry) * time.Second),
	}); err != nil {
		return "", err
	}
	return token, nil
}

// RefreshToken validates refToken and returns its stored record. It returns
// ErrRefreshTokenNotFound for revoked tokens and ErrExpiredToken for expired
// ones. The token is not rotated.
func (s *Service) RefreshToken(ctx context.Context, refToken string) (domain.RefreshToken, error) {
	return s.activeRefreshToken(ctx, refToken)
}

// GetLastLogin returns the CreatedAt time of the refresh token the token
// store returns for userID. Login stores a new token on every successful
// login, so this is presumably the last login time — it depends on the store
// returning the most recent token (TODO confirm).
func (s *Service) GetLastLogin(ctx context.Context, userID int64) (*time.Time, error) {
	refreshToken, err := s.tokenStore.GetRefreshTokenByUserID(ctx, userID)
	if err != nil {
		return nil, err
	}
	return &refreshToken.CreatedAt, nil
}

// Logout revokes refToken. It fails with ErrRefreshTokenNotFound if the token
// is already revoked and with ErrExpiredToken if it has expired.
func (s *Service) Logout(ctx context.Context, refToken string) error {
	if _, err := s.activeRefreshToken(ctx, refToken); err != nil {
		return err
	}
	return s.tokenStore.RevokeRefreshToken(ctx, refToken)
}

// activeRefreshToken loads refToken from the token store and rejects it when
// revoked (ErrRefreshTokenNotFound) or expired (ErrExpiredToken).
func (s *Service) activeRefreshToken(ctx context.Context, refToken string) (domain.RefreshToken, error) {
	token, err := s.tokenStore.GetRefreshToken(ctx, refToken)
	if err != nil {
		return domain.RefreshToken{}, err
	}
	if token.Revoked {
		// Revoked tokens are reported as missing so the client logs in again.
		return domain.RefreshToken{}, ErrRefreshTokenNotFound
	}
	if token.ExpiresAt.Before(time.Now()) {
		return domain.RefreshToken{}, ErrExpiredToken
	}
	return token, nil
}

// matchPassword compares plaintextPassword with a bcrypt hash. A mismatch is
// mapped to ErrInvalidPassword; other bcrypt errors are returned unchanged.
func matchPassword(plaintextPassword string, hash []byte) error {
	err := bcrypt.CompareHashAndPassword(hash, []byte(plaintextPassword))
	if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
		return ErrInvalidPassword
	}
	return err
}

// generateRefreshToken returns 32 bytes from crypto/rand encoded as
// unpadded standard base32 (52 characters).
func generateRefreshToken() (string, error) {
	randomBytes := make([]byte, 32)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", err
	}
	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes), nil
}