Skip to content

Commit 827fc57

Browse files
authored
SansIOHttpPolicy can return awaitable objects (Azure#7497)
* Make BearerTokenCredentialPolicy a SansIOHTTPPolicy * Allow SansIOHTTPPolicy to return awaitables * SansIOHttpPolicy doc * typo * Awaitable not backported in typing for 2.7 + black * Fix typing * Feedbacks from Anna * Check sync runner doesn't get a coroutine * pylint
1 parent 652e872 commit 827fc57

File tree

5 files changed

+58
-30
lines changed

5 files changed

+58
-30
lines changed

sdk/core/azure-core/azure/core/pipeline/base.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@
3737
PoliciesType = List[Union[HTTPPolicy, SansIOHTTPPolicy]]
3838

3939

40+
def _await_result(func, *args, **kwargs):
41+
"""If func returns an awaitable, raise that this runner can't handle it."""
42+
result = func(*args, **kwargs)
43+
if hasattr(result, '__await__'):
44+
raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func))
45+
return result
46+
47+
4048
class _SansIOHTTPPolicyRunner(HTTPPolicy, Generic[HTTPRequestType, HTTPResponseType]):
4149
"""Sync implementation of the SansIO policy.
4250
@@ -60,14 +68,14 @@ def send(self, request):
6068
:return: The PipelineResponse object.
6169
:rtype: ~azure.core.pipeline.PipelineResponse
6270
"""
63-
self._policy.on_request(request)
71+
_await_result(self._policy.on_request, request)
6472
try:
6573
response = self.next.send(request)
6674
except Exception: #pylint: disable=broad-except
67-
if not self._policy.on_exception(request):
75+
if not _await_result(self._policy.on_exception, request):
6876
raise
6977
else:
70-
self._policy.on_response(request, response)
78+
_await_result(self._policy.on_response, request, response)
7179
return response
7280

7381

sdk/core/azure-core/azure/core/pipeline/base_async.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ async def __aexit__(self, exc_type, exc_value, traceback):
4949
return None
5050

5151

52+
async def _await_result(func, *args, **kwargs):
53+
"""If func returns an awaitable, await it."""
54+
result = func(*args, **kwargs)
55+
if hasattr(result, '__await__'):
56+
# type ignore on await: https://github.com/python/mypy/issues/7587
57+
return await result # type: ignore
58+
return result
59+
60+
5261
class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): #pylint: disable=unsubscriptable-object
5362
"""Async implementation of the SansIO policy.
5463
@@ -62,22 +71,22 @@ def __init__(self, policy: SansIOHTTPPolicy) -> None:
6271
super(_SansIOAsyncHTTPPolicyRunner, self).__init__()
6372
self._policy = policy
6473

65-
async def send(self, request: PipelineRequest):
74+
async def send(self, request: PipelineRequest) -> PipelineResponse:
6675
"""Modifies the request and sends to the next policy in the chain.
6776
6877
:param request: The PipelineRequest object.
6978
:type request: ~azure.core.pipeline.PipelineRequest
7079
:return: The PipelineResponse object.
7180
:rtype: ~azure.core.pipeline.PipelineResponse
7281
"""
73-
self._policy.on_request(request)
82+
await _await_result(self._policy.on_request, request)
7483
try:
7584
response = await self.next.send(request) # type: ignore
7685
except Exception: #pylint: disable=broad-except
77-
if not self._policy.on_exception(request):
86+
if not await _await_result(self._policy.on_exception, request):
7887
raise
7988
else:
80-
self._policy.on_response(request, response)
89+
await _await_result(self._policy.on_response, request, response)
8190
return response
8291

8392

sdk/core/azure-core/azure/core/pipeline/policies/authentication.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# -------------------------------------------------------------------------
66
import time
77

8-
from . import HTTPPolicy
8+
from . import SansIOHTTPPolicy
99

1010
try:
1111
from typing import TYPE_CHECKING # pylint:disable=unused-import
@@ -16,7 +16,7 @@
1616
# pylint:disable=unused-import
1717
from typing import Any, Dict, Mapping, Optional
1818
from azure.core.credentials import AccessToken, TokenCredential
19-
from azure.core.pipeline import PipelineRequest, PipelineResponse
19+
from azure.core.pipeline import PipelineRequest
2020

2121

2222
# pylint:disable=too-few-public-methods
@@ -51,24 +51,21 @@ def _need_new_token(self):
5151
return not self._token or self._token.expires_on - time.time() < 300
5252

5353

54-
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
54+
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
5555
"""Adds a bearer token Authorization header to requests.
5656
5757
:param credential: The credential.
5858
:type credential: ~azure.core.TokenCredential
5959
:param str scopes: Lets you specify the type of access needed.
6060
"""
6161

62-
def send(self, request):
63-
# type: (PipelineRequest) -> PipelineResponse
62+
def on_request(self, request):
63+
# type: (PipelineRequest) -> None
6464
"""Adds a bearer token Authorization header to request and sends request to next policy.
6565
6666
:param request: The pipeline request object
6767
:type request: ~azure.core.pipeline.PipelineRequest
68-
:return: The pipeline response object
69-
:rtype: ~azure.core.pipeline.PipelineResponse
7068
"""
7169
if self._need_new_token:
7270
self._token = self._credential.get_token(*self._scopes)
7371
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
74-
return self.next.send(request)

sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
# -------------------------------------------------------------------------
66
import threading
77

8-
from azure.core.pipeline import PipelineRequest, PipelineResponse
9-
from azure.core.pipeline.policies import AsyncHTTPPolicy
8+
from azure.core.pipeline import PipelineRequest
9+
from azure.core.pipeline.policies import SansIOHTTPPolicy
1010
from azure.core.pipeline.policies.authentication import _BearerTokenCredentialPolicyBase
1111

1212

13-
class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, AsyncHTTPPolicy):
13+
class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
1414
# pylint:disable=too-few-public-methods
1515
"""Adds a bearer token Authorization header to requests.
1616
@@ -23,16 +23,13 @@ def __init__(self, credential, *scopes, **kwargs):
2323
super().__init__(credential, *scopes, **kwargs)
2424
self._lock = threading.Lock()
2525

26-
async def send(self, request: PipelineRequest) -> PipelineResponse:
26+
async def on_request(self, request: PipelineRequest):
2727
"""Adds a bearer token Authorization header to request and sends request to next policy.
2828
2929
:param request: The pipeline request object to be modified.
3030
:type request: ~azure.core.pipeline.PipelineRequest
31-
:return: The pipeline response object
32-
:rtype: ~azure.core.pipeline.PipelineResponse
3331
"""
3432
with self._lock:
3533
if self._need_new_token:
3634
self._token = await self._credential.get_token(*self._scopes) # type: ignore
3735
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
38-
return await self.next.send(request) # type: ignore

sdk/core/azure-core/azure/core/pipeline/policies/base.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,30 @@
2828
import copy
2929
import logging
3030

31-
from typing import (TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, # pylint: disable=unused-import
32-
Tuple, Callable, Iterator)
31+
from typing import (
32+
Generic,
33+
TypeVar,
34+
Union,
35+
Any,
36+
Dict,
37+
Optional,
38+
) # pylint: disable=unused-import
39+
40+
try:
41+
from typing import Awaitable # pylint: disable=unused-import
42+
except ImportError:
43+
pass
3344

3445
from azure.core.pipeline import ABC, PipelineRequest, PipelineResponse
3546

47+
3648
HTTPResponseType = TypeVar("HTTPResponseType")
3749
HTTPRequestType = TypeVar("HTTPRequestType")
3850

3951
_LOGGER = logging.getLogger(__name__)
4052

4153

42-
class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore
54+
class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore
4355
"""An HTTP policy ABC.
4456
4557
Use with a synchronous pipeline.
@@ -48,6 +60,7 @@ class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignor
4860
instantiated and all policies chained.
4961
:type next: ~azure.core.pipeline.policies.HTTPPolicy or ~azure.core.pipeline.transport.HTTPTransport
5062
"""
63+
5164
def __init__(self):
5265
self.next = None
5366

@@ -64,6 +77,7 @@ def send(self, request):
6477
:rtype: ~azure.core.pipeline.PipelineResponse
6578
"""
6679

80+
6781
class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
6882
"""Represents a sans I/O policy.
6983
@@ -72,18 +86,20 @@ class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
7286
on the specifics of any particular transport. SansIOHTTPPolicy
7387
subclasses will function in either a Pipeline or an AsyncPipeline,
7488
and can act either before the request is done, or after.
89+
You can optionally make these methods coroutines (or return awaitable objects)
90+
but they will then be tied to AsyncPipeline usage.
7591
"""
7692

7793
def on_request(self, request):
78-
# type: (PipelineRequest) -> None
94+
# type: (PipelineRequest) -> Union[None, Awaitable[None]]
7995
"""Is executed before sending the request from next policy.
8096
8197
:param request: Request to be modified before sent from next policy.
8298
:type request: ~azure.core.pipeline.PipelineRequest
8399
"""
84100

85101
def on_response(self, request, response):
86-
# type: (PipelineRequest, PipelineResponse) -> None
102+
# type: (PipelineRequest, PipelineResponse) -> Union[None, Awaitable[None]]
87103
"""Is executed after the request comes back from the policy.
88104
89105
:param request: Request to be modified after returning from the policy.
@@ -92,9 +108,9 @@ def on_response(self, request, response):
92108
:type response: ~azure.core.pipeline.PipelineResponse
93109
"""
94110

95-
#pylint: disable=no-self-use
96-
def on_exception(self, _request): #pylint: disable=unused-argument
97-
# type: (PipelineRequest) -> bool
111+
# pylint: disable=no-self-use
112+
def on_exception(self, _request): # pylint: disable=unused-argument
113+
# type: (PipelineRequest) -> Union[bool, Awaitable[bool]]
98114
"""Is executed if an exception is raised while executing the next policy.
99115
100116
Developer can optionally implement this method to return True
@@ -129,6 +145,7 @@ class RequestHistory(object):
129145
:param Exception error: An error encountered during the request, or None if the response was received successfully.
130146
:param dict context: The pipeline context.
131147
"""
148+
132149
def __init__(self, http_request, http_response=None, error=None, context=None):
133150
# type: (PipelineRequest, Optional[PipelineResponse], Exception, Optional[Dict[str, Any]]) -> None
134151
self.http_request = copy.deepcopy(http_request)

0 commit comments

Comments
 (0)