Skip to content

Commit 461f2cc

Browse files
authored
Firestore: Add 'should_terminate' predicate for clean BiDi shutdown. (#8650)
Closes #7826.
1 parent 495e5cb commit 461f2cc

File tree

2 files changed

+123
-7
lines changed

2 files changed

+123
-7
lines changed

google/api_core/bidi.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,11 @@ def pending_requests(self):
349349
return self._request_queue.qsize()
350350

351351

352+
def _never_terminate(future_or_error):
353+
"""By default, no errors cause BiDi termination."""
354+
return False
355+
356+
352357
class ResumableBidiRpc(BidiRpc):
353358
"""A :class:`BidiRpc` that can automatically resume the stream on errors.
354359
@@ -391,6 +396,9 @@ def should_recover(exc):
391396
should_recover (Callable[[Exception], bool]): A function that returns
392397
True if the stream should be recovered. This will be called
393398
whenever an error is encountered on the stream.
399+
should_terminate (Callable[[Exception], bool]): A function that returns
400+
True if the stream should be terminated. This will be called
401+
whenever an error is encountered on the stream.
394402
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
395403
the request.
396404
throttle_reopen (bool): If ``True``, throttling will be applied to
@@ -401,12 +409,14 @@ def __init__(
401409
self,
402410
start_rpc,
403411
should_recover,
412+
should_terminate=_never_terminate,
404413
initial_request=None,
405414
metadata=None,
406415
throttle_reopen=False,
407416
):
408417
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
409418
self._should_recover = should_recover
419+
self._should_terminate = should_terminate
410420
self._operational_lock = threading.RLock()
411421
self._finalized = False
412422
self._finalize_lock = threading.Lock()
@@ -433,7 +443,9 @@ def _on_call_done(self, future):
433443
# error, not for errors that we can recover from. Note that grpc's
434444
# "future" here is also a grpc.RpcError.
435445
with self._operational_lock:
436-
if not self._should_recover(future):
446+
if self._should_terminate(future):
447+
self._finalize(future)
448+
elif not self._should_recover(future):
437449
self._finalize(future)
438450
else:
439451
_LOGGER.debug("Re-opening stream from gRPC callback.")
@@ -496,6 +508,12 @@ def _recoverable(self, method, *args, **kwargs):
496508
with self._operational_lock:
497509
_LOGGER.debug("Call to retryable %r caused %s.", method, exc)
498510

511+
if self._should_terminate(exc):
512+
self.close()
513+
_LOGGER.debug("Terminating %r due to %s.", method, exc)
514+
self._finalize(exc)
515+
break
516+
499517
if not self._should_recover(exc):
500518
self.close()
501519
_LOGGER.debug("Not retrying %r due to %s.", method, exc)

tests/unit/test_bidi.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -370,33 +370,111 @@ def cancel(self):
370370

371371

372372
class TestResumableBidiRpc(object):
373-
def test_initial_state(self):
374-
callback = mock.Mock()
375-
callback.return_value = True
376-
bidi_rpc = bidi.ResumableBidiRpc(None, callback)
373+
def test_ctor_defaults(self):
374+
start_rpc = mock.Mock()
375+
should_recover = mock.Mock()
376+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
377+
378+
assert bidi_rpc.is_active is False
379+
assert bidi_rpc._finalized is False
380+
assert bidi_rpc._start_rpc is start_rpc
381+
assert bidi_rpc._should_recover is should_recover
382+
assert bidi_rpc._should_terminate is bidi._never_terminate
383+
assert bidi_rpc._initial_request is None
384+
assert bidi_rpc._rpc_metadata is None
385+
assert bidi_rpc._reopen_throttle is None
386+
387+
def test_ctor_explicit(self):
388+
start_rpc = mock.Mock()
389+
should_recover = mock.Mock()
390+
should_terminate = mock.Mock()
391+
initial_request = mock.Mock()
392+
metadata = {"x-foo": "bar"}
393+
bidi_rpc = bidi.ResumableBidiRpc(
394+
start_rpc,
395+
should_recover,
396+
should_terminate=should_terminate,
397+
initial_request=initial_request,
398+
metadata=metadata,
399+
throttle_reopen=True,
400+
)
377401

378402
assert bidi_rpc.is_active is False
403+
assert bidi_rpc._finalized is False
404+
assert bidi_rpc._should_recover is should_recover
405+
assert bidi_rpc._should_terminate is should_terminate
406+
assert bidi_rpc._initial_request is initial_request
407+
assert bidi_rpc._rpc_metadata == metadata
408+
assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle)
409+
410+
def test_done_callbacks_terminate(self):
411+
cancellation = mock.Mock()
412+
start_rpc = mock.Mock()
413+
should_recover = mock.Mock(spec=["__call__"], return_value=True)
414+
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
415+
bidi_rpc = bidi.ResumableBidiRpc(
416+
start_rpc, should_recover, should_terminate=should_terminate
417+
)
418+
callback = mock.Mock(spec=["__call__"])
419+
420+
bidi_rpc.add_done_callback(callback)
421+
bidi_rpc._on_call_done(cancellation)
422+
423+
should_terminate.assert_called_once_with(cancellation)
424+
should_recover.assert_not_called()
425+
callback.assert_called_once_with(cancellation)
426+
assert not bidi_rpc.is_active
379427

380428
def test_done_callbacks_recoverable(self):
381429
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
382-
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
430+
should_recover = mock.Mock(spec=["__call__"], return_value=True)
431+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
383432
callback = mock.Mock(spec=["__call__"])
384433

385434
bidi_rpc.add_done_callback(callback)
386435
bidi_rpc._on_call_done(mock.sentinel.future)
387436

388437
callback.assert_not_called()
389438
start_rpc.assert_called_once()
439+
should_recover.assert_called_once_with(mock.sentinel.future)
390440
assert bidi_rpc.is_active
391441

392442
def test_done_callbacks_non_recoverable(self):
393-
bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
443+
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
444+
should_recover = mock.Mock(spec=["__call__"], return_value=False)
445+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
394446
callback = mock.Mock(spec=["__call__"])
395447

396448
bidi_rpc.add_done_callback(callback)
397449
bidi_rpc._on_call_done(mock.sentinel.future)
398450

399451
callback.assert_called_once_with(mock.sentinel.future)
452+
should_recover.assert_called_once_with(mock.sentinel.future)
453+
assert not bidi_rpc.is_active
454+
455+
def test_send_terminate(self):
456+
cancellation = ValueError()
457+
call_1 = CallStub([cancellation], active=False)
458+
call_2 = CallStub([])
459+
start_rpc = mock.create_autospec(
460+
grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
461+
)
462+
should_recover = mock.Mock(spec=["__call__"], return_value=False)
463+
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
464+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
465+
466+
bidi_rpc.open()
467+
468+
bidi_rpc.send(mock.sentinel.request)
469+
470+
assert bidi_rpc.pending_requests == 1
471+
assert bidi_rpc._request_queue.get() is None
472+
473+
should_recover.assert_not_called()
474+
should_terminate.assert_called_once_with(cancellation)
475+
assert bidi_rpc.call == call_1
476+
assert bidi_rpc.is_active is False
477+
assert call_1.cancelled is True
400478

401479
def test_send_recover(self):
402480
error = ValueError()
@@ -441,6 +519,26 @@ def test_send_failure(self):
441519
assert bidi_rpc.pending_requests == 1
442520
assert bidi_rpc._request_queue.get() is None
443521

522+
def test_recv_terminate(self):
523+
cancellation = ValueError()
524+
call = CallStub([cancellation])
525+
start_rpc = mock.create_autospec(
526+
grpc.StreamStreamMultiCallable, instance=True, return_value=call
527+
)
528+
should_recover = mock.Mock(spec=["__call__"], return_value=False)
529+
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
530+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
531+
532+
bidi_rpc.open()
533+
534+
bidi_rpc.recv()
535+
536+
should_recover.assert_not_called()
537+
should_terminate.assert_called_once_with(cancellation)
538+
assert bidi_rpc.call == call
539+
assert bidi_rpc.is_active is False
540+
assert call.cancelled is True
541+
444542
def test_recv_recover(self):
445543
error = ValueError()
446544
call_1 = CallStub([1, error])

0 commit comments

Comments
 (0)