Skip to content

Commit ae9a063

Browse files
committed
create abstract retry class to share methods between retry implementations
1 parent 56a251c commit ae9a063

File tree

3 files changed

+42
-64
lines changed

3 files changed

+42
-64
lines changed

redis/asyncio/retry.py

Lines changed: 8 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,20 @@
11
from asyncio import sleep
2-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
2+
from typing import Any, Awaitable, Callable, Tuple, Type, TypeVar
33

44
from redis.exceptions import ConnectionError, RedisError, TimeoutError
5-
6-
if TYPE_CHECKING:
7-
from redis.backoff import AbstractBackoff
8-
5+
from redis.retry import AbstractRetry
96

107
T = TypeVar("T")
118

129

13-
class Retry:
14-
"""Retry a specific number of times after a failure"""
15-
16-
__slots__ = "_backoff", "_retries", "_supported_errors"
17-
18-
def __init__(
19-
self,
20-
backoff: "AbstractBackoff",
21-
retries: int,
22-
supported_errors: Tuple[Type[RedisError], ...] = (
23-
ConnectionError,
24-
TimeoutError,
25-
),
26-
):
27-
"""
28-
Initialize a `Retry` object with a `Backoff` object
29-
that retries a maximum of `retries` times.
30-
`retries` can be negative to retry forever.
31-
You can specify the types of supported errors which trigger
32-
a retry with the `supported_errors` parameter.
33-
"""
34-
self._backoff = backoff
35-
self._retries = retries
36-
self._supported_errors = supported_errors
37-
38-
def update_supported_errors(self, specified_errors: list):
39-
"""
40-
Updates the supported errors with the specified error types
41-
"""
42-
self._supported_errors = tuple(
43-
set(self._supported_errors + tuple(specified_errors))
44-
)
45-
46-
def get_retries(self) -> int:
47-
"""
48-
Get the number of retries.
49-
"""
50-
return self._retries
51-
52-
def update_retries(self, value: int) -> None:
53-
"""
54-
Set the number of retries.
55-
"""
56-
self._retries = value
10+
class Retry(AbstractRetry):
11+
_supported_errors: Tuple[Type[RedisError], ...] = (
12+
ConnectionError,
13+
TimeoutError,
14+
)
5715

5816
async def call_with_retry(
59-
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
17+
self, do: Callable[[], Awaitable[T]], fail: Callable[[Exception], Any]
6018
) -> T:
6119
"""
6220
Execute an operation that might fail and returns its result, or

redis/retry.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import socket
22
from time import sleep
3-
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar
3+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Union
44

55
from redis.exceptions import ConnectionError, TimeoutError
66

@@ -10,18 +10,17 @@
1010
from redis.backoff import AbstractBackoff
1111

1212

13-
class Retry:
13+
class AbstractRetry:
1414
"""Retry a specific number of times after a failure"""
1515

16+
__slots__ = "_backoff", "_retries", "_supported_errors"
17+
_supported_errors: Tuple[Type[Exception], ...]
18+
1619
def __init__(
1720
self,
1821
backoff: "AbstractBackoff",
1922
retries: int,
20-
supported_errors: Tuple[Type[Exception], ...] = (
21-
ConnectionError,
22-
TimeoutError,
23-
socket.timeout,
24-
),
23+
supported_errors: Union[Tuple[Type[Exception], ...], None] = None,
2524
):
2625
"""
2726
Initialize a `Retry` object with a `Backoff` object
@@ -32,10 +31,11 @@ def __init__(
3231
"""
3332
self._backoff = backoff
3433
self._retries = retries
35-
self._supported_errors = supported_errors
34+
if supported_errors:
35+
self._supported_errors = supported_errors
3636

3737
def __eq__(self, other: Any) -> bool:
38-
if not isinstance(other, Retry):
38+
if not isinstance(other, AbstractRetry):
3939
return NotImplemented
4040

4141
return (
@@ -69,6 +69,14 @@ def update_retries(self, value: int) -> None:
6969
"""
7070
self._retries = value
7171

72+
73+
class Retry(AbstractRetry):
74+
_supported_errors: Tuple[Type[Exception], ...] = (
75+
ConnectionError,
76+
TimeoutError,
77+
socket.timeout,
78+
)
79+
7280
def call_with_retry(
7381
self,
7482
do: Callable[[], T],

tests/test_retry.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import patch
22

33
import pytest
4+
from redis.asyncio.retry import Retry as AsyncRetry
45
from redis.backoff import (
56
AbstractBackoff,
67
ConstantBackoff,
@@ -89,6 +90,7 @@ def test_retry_on_error_retry(self, Class, retries):
8990
assert c.retry._retries == retries
9091

9192

93+
@pytest.mark.parametrize("retry_class", [Retry, AsyncRetry])
9294
@pytest.mark.parametrize(
9395
"args",
9496
[
@@ -108,8 +110,8 @@ def test_retry_on_error_retry(self, Class, retries):
108110
for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5))
109111
],
110112
)
111-
def test_retry_eq_and_hashable(args):
112-
assert Retry(*args) == Retry(*args)
113+
def test_retry_eq_and_hashable(retry_class, args):
114+
assert retry_class(*args) == retry_class(*args)
113115

114116
# create another retry object with different parameters
115117
copy = list(args)
@@ -118,9 +120,19 @@ def test_retry_eq_and_hashable(args):
118120
else:
119121
copy[0] = ConstantBackoff(9000)
120122

121-
assert Retry(*args) != Retry(*copy)
122-
assert Retry(*copy) != Retry(*args)
123-
assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2
123+
assert retry_class(*args) != retry_class(*copy)
124+
assert retry_class(*copy) != retry_class(*args)
125+
assert (
126+
len(
127+
{
128+
retry_class(*args),
129+
retry_class(*args),
130+
retry_class(*copy),
131+
retry_class(*copy),
132+
}
133+
)
134+
== 2
135+
)
124136

125137

126138
class TestRetry:

0 commit comments

Comments
 (0)