diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 3a8d88c..547e072 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -110,9 +110,9 @@ func (s *AuthServiceImpl) ValidateToken(tokenString string) (*contract.UserRespo } func (s *AuthServiceImpl) RefreshToken(ctx context.Context, tokenString string) (*contract.LoginResponse, error) { - claims, err := s.parseToken(tokenString) + claims, err := s.parseRefreshToken(tokenString) if err != nil { - return nil, fmt.Errorf("invalid token: %w", err) + return nil, fmt.Errorf("invalid refresh token: %w", err) } userResponse, err := s.userProcessor.GetUserByID(ctx, claims.UserID) @@ -227,3 +227,26 @@ func (s *AuthServiceImpl) parseToken(tokenString string) (*Claims, error) { return nil, errors.New("invalid token") } + +func (s *AuthServiceImpl) parseRefreshToken(tokenString string) (*Claims, error) { + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(s.refreshSecret), nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*Claims); ok && token.Valid { + // Verify this is a refresh token by checking the issuer + if claims.Issuer != "apskel-pos-refresh" { + return nil, errors.New("not a valid refresh token") + } + return claims, nil + } + + return nil, errors.New("invalid refresh token") +}