Skip to content

add async Retry __eq__ and __hash__ & fix ExponentialWithJitterBackoff __eq__ #3668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 8 additions & 50 deletions redis/asyncio/retry.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,20 @@
from asyncio import sleep
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
from typing import Any, Awaitable, Callable, Tuple, Type, TypeVar

from redis.exceptions import ConnectionError, RedisError, TimeoutError

if TYPE_CHECKING:
from redis.backoff import AbstractBackoff

from redis.retry import AbstractRetry

T = TypeVar("T")


class Retry:
"""Retry a specific number of times after a failure"""

__slots__ = "_backoff", "_retries", "_supported_errors"

def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[RedisError], ...] = (
ConnectionError,
TimeoutError,
),
):
"""
Initialize a `Retry` object with a `Backoff` object
that retries a maximum of `retries` times.
`retries` can be negative to retry forever.
You can specify the types of supported errors which trigger
a retry with the `supported_errors` parameter.
"""
self._backoff = backoff
self._retries = retries
self._supported_errors = supported_errors

def update_supported_errors(self, specified_errors: list):
"""
Updates the supported errors with the specified error types
"""
self._supported_errors = tuple(
set(self._supported_errors + tuple(specified_errors))
)

def get_retries(self) -> int:
"""
Get the number of retries.
"""
return self._retries

def update_retries(self, value: int) -> None:
"""
Set the number of retries.
"""
self._retries = value
class Retry(AbstractRetry):
_supported_errors: Tuple[Type[RedisError], ...] = (
ConnectionError,
TimeoutError,
)

async def call_with_retry(
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
self, do: Callable[[], Awaitable[T]], fail: Callable[[Exception], Any]
) -> T:
"""
Execute an operation that might fail and returns its result, or
Expand Down
2 changes: 1 addition & 1 deletion redis/backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __hash__(self) -> int:
return hash((self._base, self._cap))

def __eq__(self, other) -> bool:
if not isinstance(other, EqualJitterBackoff):
if not isinstance(other, ExponentialWithJitterBackoff):
return NotImplemented

return self._base == other._base and self._cap == other._cap
Expand Down
26 changes: 17 additions & 9 deletions redis/retry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import socket
from time import sleep
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Union

from redis.exceptions import ConnectionError, TimeoutError

Expand All @@ -10,18 +10,17 @@
from redis.backoff import AbstractBackoff


class Retry:
class AbstractRetry:
"""Retry a specific number of times after a failure"""

__slots__ = "_backoff", "_retries", "_supported_errors"
_supported_errors: Tuple[Type[Exception], ...]

def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[Exception], ...] = (
ConnectionError,
TimeoutError,
socket.timeout,
),
supported_errors: Union[Tuple[Type[Exception], ...], None] = None,
):
"""
Initialize a `Retry` object with a `Backoff` object
Expand All @@ -32,10 +31,11 @@ def __init__(
"""
self._backoff = backoff
self._retries = retries
self._supported_errors = supported_errors
if supported_errors:
self._supported_errors = supported_errors

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
if not isinstance(other, AbstractRetry):
return NotImplemented

return (
Expand Down Expand Up @@ -69,6 +69,14 @@ def update_retries(self, value: int) -> None:
"""
self._retries = value


class Retry(AbstractRetry):
_supported_errors: Tuple[Type[Exception], ...] = (
ConnectionError,
TimeoutError,
socket.timeout,
)

def call_with_retry(
self,
do: Callable[[], T],
Expand Down
22 changes: 17 additions & 5 deletions tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import patch

import pytest
from redis.asyncio.retry import Retry as AsyncRetry
from redis.backoff import (
AbstractBackoff,
ConstantBackoff,
Expand Down Expand Up @@ -89,6 +90,7 @@ def test_retry_on_error_retry(self, Class, retries):
assert c.retry._retries == retries


@pytest.mark.parametrize("retry_class", [Retry, AsyncRetry])
@pytest.mark.parametrize(
"args",
[
Expand All @@ -108,8 +110,8 @@ def test_retry_on_error_retry(self, Class, retries):
for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5))
],
)
def test_retry_eq_and_hashable(args):
assert Retry(*args) == Retry(*args)
def test_retry_eq_and_hashable(retry_class, args):
assert retry_class(*args) == retry_class(*args)

# create another retry object with different parameters
copy = list(args)
Expand All @@ -118,9 +120,19 @@ def test_retry_eq_and_hashable(args):
else:
copy[0] = ConstantBackoff(9000)

assert Retry(*args) != Retry(*copy)
assert Retry(*copy) != Retry(*args)
assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2
assert retry_class(*args) != retry_class(*copy)
assert retry_class(*copy) != retry_class(*args)
assert (
len(
{
retry_class(*args),
retry_class(*args),
retry_class(*copy),
retry_class(*copy),
}
)
== 2
)


class TestRetry:
Expand Down
Loading