Skip to content

Commit 461a3f5

Browse files
authored
chore: Modify exponential backoff implementation to have no initial sleep (#1547)
chore: No sleep on initial attempt in exponential backoff implementation It is unintuitive that the initial attempt in the exponential backoff loop sleeps. This can lead to subtle bugs in future call sites. This patch refactors the exponential backoff to begin sleeping on the 2nd iteration so requests can be done in a single for loop.
1 parent 8338594 commit 461a3f5

File tree

7 files changed

+56
-43
lines changed

7 files changed

+56
-43
lines changed

google/auth/_exponential_backoff.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import random
1616
import time
1717

18+
from google.auth import exceptions
19+
1820
# The default amount of retry attempts
1921
_DEFAULT_RETRY_TOTAL_ATTEMPTS = 3
2022

@@ -68,6 +70,11 @@ def __init__(
6870
randomization_factor=_DEFAULT_RANDOMIZATION_FACTOR,
6971
multiplier=_DEFAULT_MULTIPLIER,
7072
):
73+
if total_attempts < 1:
74+
raise exceptions.InvalidValue(
75+
f"total_attempts must be greater than or equal to 1 but was {total_attempts}"
76+
)
77+
7178
self._total_attempts = total_attempts
7279
self._initial_wait_seconds = initial_wait_seconds
7380

@@ -87,6 +94,9 @@ def __next__(self):
8794
raise StopIteration
8895
self._backoff_count += 1
8996

97+
if self._backoff_count <= 1:
98+
return self._backoff_count
99+
90100
jitter_variance = self._current_wait_in_seconds * self._randomization_factor
91101
jitter = random.uniform(
92102
self._current_wait_in_seconds - jitter_variance,

google/oauth2/_client.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,11 @@ def _token_endpoint_request_no_throw(
183183
if headers:
184184
headers_to_use.update(headers)
185185

186-
def _perform_request():
186+
response_data = {}
187+
retryable_error = False
188+
189+
retries = _exponential_backoff.ExponentialBackoff()
190+
for _ in retries:
187191
response = request(
188192
method="POST", url=token_uri, headers=headers_to_use, body=body, **kwargs
189193
)
@@ -192,7 +196,7 @@ def _perform_request():
192196
if hasattr(response.data, "decode")
193197
else response.data
194198
)
195-
response_data = ""
199+
196200
try:
197201
# response_body should be a JSON
198202
response_data = json.loads(response_body)
@@ -206,18 +210,8 @@ def _perform_request():
206210
status_code=response.status, response_data=response_data
207211
)
208212

209-
return False, response_data, retryable_error
210-
211-
request_succeeded, response_data, retryable_error = _perform_request()
212-
213-
if request_succeeded or not retryable_error or not can_retry:
214-
return request_succeeded, response_data, retryable_error
215-
216-
retries = _exponential_backoff.ExponentialBackoff()
217-
for _ in retries:
218-
request_succeeded, response_data, retryable_error = _perform_request()
219-
if request_succeeded or not retryable_error:
220-
return request_succeeded, response_data, retryable_error
213+
if not can_retry or not retryable_error:
214+
return False, response_data, retryable_error
221215

222216
return False, response_data, retryable_error
223217

google/oauth2/_client_async.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ async def _token_endpoint_request_no_throw(
6767
if access_token:
6868
headers["Authorization"] = "Bearer {}".format(access_token)
6969

70-
async def _perform_request():
70+
response_data = {}
71+
retryable_error = False
72+
73+
retries = _exponential_backoff.ExponentialBackoff()
74+
for _ in retries:
7175
response = await request(
7276
method="POST", url=token_uri, headers=headers, body=body
7377
)
@@ -93,18 +97,8 @@ async def _perform_request():
9397
status_code=response.status, response_data=response_data
9498
)
9599

96-
return False, response_data, retryable_error
97-
98-
request_succeeded, response_data, retryable_error = await _perform_request()
99-
100-
if request_succeeded or not retryable_error or not can_retry:
101-
return request_succeeded, response_data, retryable_error
102-
103-
retries = _exponential_backoff.ExponentialBackoff()
104-
for _ in retries:
105-
request_succeeded, response_data, retryable_error = await _perform_request()
106-
if request_succeeded or not retryable_error:
107-
return request_succeeded, response_data, retryable_error
100+
if not can_retry or not retryable_error:
101+
return False, response_data, retryable_error
108102

109103
return False, response_data, retryable_error
110104

system_tests/secrets.tar.enc

0 Bytes
Binary file not shown.

tests/oauth2/test__client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ def test__token_endpoint_request_internal_failure_error():
194194
_client._token_endpoint_request(
195195
request, "http://example.com", {"error_description": "internal_failure"}
196196
)
197-
# request should be called once and then with 3 retries
198-
assert request.call_count == 4
197+
# request with 2 retries
198+
assert request.call_count == 3
199199

200200
request = make_request(
201201
{"error": "internal_failure"}, status=http_client.BAD_REQUEST
@@ -205,8 +205,8 @@ def test__token_endpoint_request_internal_failure_error():
205205
_client._token_endpoint_request(
206206
request, "http://example.com", {"error": "internal_failure"}
207207
)
208-
# request should be called once and then with 3 retries
209-
assert request.call_count == 4
208+
# request with 2 retries
209+
assert request.call_count == 3
210210

211211

212212
def test__token_endpoint_request_internal_failure_and_retry_failure_error():
@@ -625,6 +625,6 @@ def test__token_endpoint_request_no_throw_with_retry(can_retry):
625625
)
626626

627627
if can_retry:
628-
assert mock_request.call_count == 4
628+
assert mock_request.call_count == 3
629629
else:
630630
assert mock_request.call_count == 1

tests/test__exponential_backoff.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414

1515
import mock
16+
import pytest # type: ignore
1617

1718
from google.auth import _exponential_backoff
19+
from google.auth import exceptions
1820

1921

2022
@mock.patch("time.sleep", return_value=None)
@@ -24,18 +26,31 @@ def test_exponential_backoff(mock_time):
2426
iteration_count = 0
2527

2628
for attempt in eb:
27-
backoff_interval = mock_time.call_args[0][0]
28-
jitter = curr_wait * eb._randomization_factor
29-
30-
assert (curr_wait - jitter) <= backoff_interval <= (curr_wait + jitter)
31-
assert attempt == iteration_count + 1
32-
assert eb.backoff_count == iteration_count + 1
33-
assert eb._current_wait_in_seconds == eb._multiplier ** (iteration_count + 1)
34-
35-
curr_wait = eb._current_wait_in_seconds
29+
if attempt == 1:
30+
assert mock_time.call_count == 0
31+
else:
32+
backoff_interval = mock_time.call_args[0][0]
33+
jitter = curr_wait * eb._randomization_factor
34+
35+
assert (curr_wait - jitter) <= backoff_interval <= (curr_wait + jitter)
36+
assert attempt == iteration_count + 1
37+
assert eb.backoff_count == iteration_count + 1
38+
assert eb._current_wait_in_seconds == eb._multiplier ** iteration_count
39+
40+
curr_wait = eb._current_wait_in_seconds
3641
iteration_count += 1
3742

3843
assert eb.total_attempts == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
3944
assert eb.backoff_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
4045
assert iteration_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
41-
assert mock_time.call_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS
46+
assert (
47+
mock_time.call_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS - 1
48+
)
49+
50+
51+
def test_minimum_total_attempts():
52+
with pytest.raises(exceptions.InvalidValue):
53+
_exponential_backoff.ExponentialBackoff(total_attempts=0)
54+
with pytest.raises(exceptions.InvalidValue):
55+
_exponential_backoff.ExponentialBackoff(total_attempts=-1)
56+
_exponential_backoff.ExponentialBackoff(total_attempts=1)

tests_async/oauth2/test__client_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,6 @@ async def test__token_endpoint_request_no_throw_with_retry(can_retry):
492492
)
493493

494494
if can_retry:
495-
assert mock_request.call_count == 4
495+
assert mock_request.call_count == 3
496496
else:
497497
assert mock_request.call_count == 1

0 commit comments

Comments
 (0)