Skip to content
17 changes: 13 additions & 4 deletions google/api_core/retry/retry_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import time

from enum import Enum
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Iterator, TYPE_CHECKING

import requests.exceptions

Expand Down Expand Up @@ -174,7 +174,7 @@ def build_retry_error(
def _retry_error_helper(
exc: Exception,
deadline: float | None,
next_sleep: float,
sleep_iterator: Iterator[float],
error_list: list[Exception],
predicate_fn: Callable[[Exception], bool],
on_error_fn: Callable[[Exception], None] | None,
Expand All @@ -183,7 +183,7 @@ def _retry_error_helper(
tuple[Exception, Exception | None],
],
original_timeout: float | None,
):
) -> float:
"""
Shared logic for handling an error for all retry implementations

Expand All @@ -194,13 +194,15 @@ def _retry_error_helper(
Args:
- exc: the exception that was raised
- deadline: the deadline for the retry, calculated as a diff from time.monotonic()
- next_sleep: the next sleep interval
- sleep_iterator: iterator to draw the next backoff value from
- error_list: the list of exceptions that have been raised so far
- predicate_fn: takes `exc` and returns true if the operation should be retried
- on_error_fn: callback to execute when a retryable error occurs
- exc_factory_fn: callback used to build the exception to be raised on terminal failure
- original_timeout_val: the original timeout value for the retry (in seconds),
to be passed to the exception factory for building an error message
Returns:
- the sleep value chosen before the next attempt
"""
error_list.append(exc)
if not predicate_fn(exc):
Expand All @@ -212,6 +214,12 @@ def _retry_error_helper(
raise final_exc from source_exc
if on_error_fn is not None:
on_error_fn(exc)
# next_sleep is fetched after the on_error callback, to allow clients
# to update sleep_iterator values dynamically in response to errors
try:
next_sleep = next(sleep_iterator)
except StopIteration:
raise ValueError("Sleep generator stopped yielding sleep values.") from exc
if deadline is not None and time.monotonic() + next_sleep > deadline:
final_exc, source_exc = exc_factory_fn(
error_list,
Expand All @@ -222,6 +230,7 @@ def _retry_error_helper(
_LOGGER.debug(
"Retrying due to {}, sleeping {:.1f}s ...".format(error_list[-1], next_sleep)
)
return next_sleep


class _BaseRetry(object):
Expand Down
13 changes: 7 additions & 6 deletions google/api_core/retry/retry_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ def retry_target_stream(
time.monotonic() + timeout if timeout is not None else None
)
error_list: list[Exception] = []
sleep_iter = iter(sleep_generator)

for sleep in sleep_generator:
# continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper
# TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535
while True:
# Start a new retry loop
try:
# Note: in the future, we can add a ResumptionStrategy object
Expand All @@ -121,20 +124,18 @@ def retry_target_stream(
# This function explicitly must deal with broad exceptions.
except Exception as exc:
# defer to shared logic for handling errors
_retry_error_helper(
next_sleep = _retry_error_helper(
exc,
deadline,
sleep,
sleep_iter,
error_list,
predicate,
on_error,
exception_factory,
timeout,
)
# if exception not raised, sleep before next attempt
time.sleep(sleep)

raise ValueError("Sleep generator stopped yielding sleep values.")
time.sleep(next_sleep)


class StreamingRetry(_BaseRetry):
Expand Down
13 changes: 8 additions & 5 deletions google/api_core/retry/retry_streaming_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,12 @@ async def retry_target_stream(
deadline = time.monotonic() + timeout if timeout else None
# keep track of retryable exceptions we encounter to pass in to exception_factory
error_list: list[Exception] = []
sleep_iter = iter(sleep_generator)
target_is_generator: bool | None = None

for sleep in sleep_generator:
# continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper
# TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535
while True:
# Start a new retry loop
try:
# Note: in the future, we can add a ResumptionStrategy object
Expand Down Expand Up @@ -174,22 +177,22 @@ async def retry_target_stream(
# This function explicitly must deal with broad exceptions.
except Exception as exc:
# defer to shared logic for handling errors
_retry_error_helper(
next_sleep = _retry_error_helper(
exc,
deadline,
sleep,
sleep_iter,
error_list,
predicate,
on_error,
exception_factory,
timeout,
)
# if exception not raised, sleep before next attempt
await asyncio.sleep(sleep)
await asyncio.sleep(next_sleep)

finally:
if target_is_generator and target_iterator is not None:
await cast(AsyncGenerator["_Y", None], target_iterator).aclose()
raise ValueError("Sleep generator stopped yielding sleep values.")


class AsyncStreamingRetry(_BaseRetry):
Expand Down
13 changes: 7 additions & 6 deletions google/api_core/retry/retry_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,11 @@ def retry_target(

deadline = time.monotonic() + timeout if timeout is not None else None
error_list: list[Exception] = []
sleep_iter = iter(sleep_generator)

for sleep in sleep_generator:
# continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper
# TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535
while True:
try:
result = target()
if inspect.isawaitable(result):
Expand All @@ -150,20 +153,18 @@ def retry_target(
# This function explicitly must deal with broad exceptions.
except Exception as exc:
# defer to shared logic for handling errors
_retry_error_helper(
next_sleep = _retry_error_helper(
exc,
deadline,
sleep,
sleep_iter,
error_list,
predicate,
on_error,
exception_factory,
timeout,
)
# if exception not raised, sleep before next attempt
time.sleep(sleep)

raise ValueError("Sleep generator stopped yielding sleep values.")
time.sleep(next_sleep)


class Retry(_BaseRetry):
Expand Down
13 changes: 7 additions & 6 deletions google/api_core/retry/retry_unary_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,28 +149,29 @@ async def retry_target(

deadline = time.monotonic() + timeout if timeout is not None else None
error_list: list[Exception] = []
sleep_iter = iter(sleep_generator)

for sleep in sleep_generator:
# continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper
# TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535
while True:
try:
return await target()
# pylint: disable=broad-except
# This function explicitly must deal with broad exceptions.
except Exception as exc:
# defer to shared logic for handling errors
_retry_error_helper(
next_sleep = _retry_error_helper(
exc,
deadline,
sleep,
sleep_iter,
error_list,
predicate,
on_error,
exception_factory,
timeout,
)
# if exception not raised, sleep before next attempt
await asyncio.sleep(sleep)

raise ValueError("Sleep generator stopped yielding sleep values.")
await asyncio.sleep(next_sleep)


class AsyncRetry(_BaseRetry):
Expand Down
35 changes: 32 additions & 3 deletions tests/asyncio/retry/test_retry_streaming_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,36 @@ async def test_retry_streaming_target_bad_sleep_generator():
from google.api_core.retry.retry_streaming_async import retry_target_stream

with pytest.raises(ValueError, match="Sleep generator"):
await retry_target_stream(None, None, [], None).__anext__()
await retry_target_stream(None, lambda x: True, [], None).__anext__()


@mock.patch("asyncio.sleep", autospec=True)
@pytest.mark.asyncio
async def test_retry_streaming_target_dynamic_backoff(sleep):
"""
sleep_generator should be iterated after on_error, to support dynamic backoff
"""
from functools import partial
from google.api_core.retry.retry_streaming_async import retry_target_stream

sleep.side_effect = RuntimeError("stop after sleep")
# start with empty sleep generator; values are added after exception in push_sleep_value
sleep_values = []
error_target = partial(TestAsyncStreamingRetry._generator_mock, error_on=0)
inserted_sleep = 99

def push_sleep_value(err):
sleep_values.append(inserted_sleep)

with pytest.raises(RuntimeError):
await retry_target_stream(
error_target,
predicate=lambda x: True,
sleep_generator=sleep_values,
on_error=push_sleep_value,
).__anext__()
assert sleep.call_count == 1
sleep.assert_called_once_with(inserted_sleep)


class TestAsyncStreamingRetry(Test_BaseRetry):
Expand Down Expand Up @@ -66,8 +95,8 @@ def if_exception_type(exc):
str(retry_),
)

@staticmethod
async def _generator_mock(
self,
num=5,
error_on=None,
exceptions_seen=None,
Expand All @@ -87,7 +116,7 @@ async def _generator_mock(
for i in range(num):
if sleep_time:
await asyncio.sleep(sleep_time)
if error_on and i == error_on:
if error_on is not None and i == error_on:
raise ValueError("generator mock error")
yield i
except (Exception, BaseException, GeneratorExit) as e:
Expand Down
27 changes: 26 additions & 1 deletion tests/asyncio/retry/test_retry_unary_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,34 @@ async def test_retry_target_timeout_exceeded(monotonic, sleep, use_deadline_arg)
@pytest.mark.asyncio
async def test_retry_target_bad_sleep_generator():
with pytest.raises(ValueError, match="Sleep generator"):
await retry_async.retry_target(mock.sentinel.target, lambda x: True, [], None)


@mock.patch("asyncio.sleep", autospec=True)
@pytest.mark.asyncio
async def test_retry_target_dynamic_backoff(sleep):
"""
sleep_generator should be iterated after on_error, to support dynamic backoff
"""
sleep.side_effect = RuntimeError("stop after sleep")
# start with empty sleep generator; values are added after exception in push_sleep_value
sleep_values = []
exception = ValueError("trigger retry")
error_target = mock.Mock(side_effect=exception)
inserted_sleep = 99

def push_sleep_value(err):
sleep_values.append(inserted_sleep)

with pytest.raises(RuntimeError):
await retry_async.retry_target(
mock.sentinel.target, mock.sentinel.predicate, [], None
error_target,
predicate=lambda x: True,
sleep_generator=sleep_values,
on_error=push_sleep_value,
)
assert sleep.call_count == 1
sleep.assert_called_once_with(inserted_sleep)


class TestAsyncRetry(Test_BaseRetry):
Expand Down
35 changes: 32 additions & 3 deletions tests/unit/retry/test_retry_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,36 @@ def test_retry_streaming_target_bad_sleep_generator():
with pytest.raises(
ValueError, match="Sleep generator stopped yielding sleep values"
):
next(retry_streaming.retry_target_stream(None, None, [], None))
next(retry_streaming.retry_target_stream(None, lambda x: True, [], None))


@mock.patch("time.sleep", autospec=True)
def test_retry_streaming_target_dynamic_backoff(sleep):
"""
sleep_generator should be iterated after on_error, to support dynamic backoff
"""
from functools import partial

sleep.side_effect = RuntimeError("stop after sleep")
# start with empty sleep generator; values are added after exception in push_sleep_value
sleep_values = []
error_target = partial(TestStreamingRetry._generator_mock, error_on=0)
inserted_sleep = 99

def push_sleep_value(err):
sleep_values.append(inserted_sleep)

with pytest.raises(RuntimeError):
next(
retry_streaming.retry_target_stream(
error_target,
predicate=lambda x: True,
sleep_generator=sleep_values,
on_error=push_sleep_value,
)
)
assert sleep.call_count == 1
sleep.assert_called_once_with(inserted_sleep)


class TestStreamingRetry(Test_BaseRetry):
Expand Down Expand Up @@ -63,8 +92,8 @@ def if_exception_type(exc):
str(retry_),
)

@staticmethod
def _generator_mock(
self,
num=5,
error_on=None,
return_val=None,
Expand All @@ -82,7 +111,7 @@ def _generator_mock(
"""
try:
for i in range(num):
if error_on and i == error_on:
if error_on is not None and i == error_on:
raise ValueError("generator mock error")
yield i
return return_val
Expand Down
Loading