|
3 | 3 | # Licensed under the MIT License. See LICENSE.txt in the project root for
|
4 | 4 | # license information.
|
5 | 5 | # -------------------------------------------------------------------------
|
| 6 | +import calendar |
6 | 7 | import time
|
7 | 8 |
|
8 | 9 | from azure.core import Configuration, HttpRequest
|
9 | 10 | from azure.core.credentials import AccessToken
|
10 |
| -from azure.core.pipeline import Pipeline, PipelineRequest |
| 11 | +from azure.core.pipeline import Pipeline |
11 | 12 | from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy
|
12 | 13 | from azure.core.pipeline.transport import HttpTransport, RequestsTransport
|
13 | 14 | from msal import TokenCache
|
|
19 | 20 | except ImportError:
|
20 | 21 | TYPE_CHECKING = False
|
21 | 22 | if TYPE_CHECKING:
|
| 23 | + from time import struct_time |
22 | 24 | from typing import Any, Dict, Iterable, Mapping, Optional
|
23 | 25 | from azure.core.pipeline import PipelineResponse
|
24 | 26 | from azure.core.pipeline.policies import HTTPPolicy
|
@@ -46,39 +48,61 @@ def get_cached_token(self, scopes):
|
46 | 48 | return None
|
47 | 49 |
|
48 | 50 | def _deserialize_and_cache_token(self, response, scopes, request_time):
|
49 |
| - # type: (PipelineResponse, Iterable[str], int) -> str |
| 51 | + # type: (PipelineResponse, Iterable[str], int) -> AccessToken |
50 | 52 | try:
|
51 |
| - if "deserialized_data" in response.context: |
52 |
| - payload = response.context["deserialized_data"] |
53 |
| - else: |
54 |
| - payload = response.http_response.text() |
| 53 | + # ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response |
| 54 | + payload = response.context["deserialized_data"] |
55 | 55 | token = payload["access_token"]
|
56 | 56 |
|
57 |
| - # these values are strings in IMDS responses but msal.TokenCache requires they be integers |
| 57 | + # these values are strings in some responses but msal.TokenCache requires they be integers |
58 | 58 | # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
|
59 |
| - expires_in = int(payload.get("expires_in", 0)) |
60 |
| - if expires_in != 0: |
61 |
| - payload["expires_in"] = expires_in |
| 59 | + if "expires_in" in payload: |
| 60 | + payload["expires_in"] = int(payload["expires_in"]) |
62 | 61 | if "ext_expires_in" in payload:
|
63 | 62 | payload["ext_expires_in"] = int(payload["ext_expires_in"])
|
64 | 63 |
|
| 64 | + # this will raise if the payload has neither expires_on nor expires_in |
| 65 | + # (which is fine because that's unexpected, especially considering the payload contains access_token) |
| 66 | + expires_on = payload.get("expires_on") or payload["expires_in"] + request_time |
| 67 | + |
| 68 | + # ensure expires_on is an int |
| 69 | + try: |
| 70 | + expires_on = int(expires_on) |
| 71 | + except ValueError: |
| 72 | + # probably an App Service MSI response, convert it to epoch seconds |
| 73 | + try: |
| 74 | + t = self._parse_app_service_expires_on(expires_on) |
| 75 | + expires_on = calendar.timegm(t) |
| 76 | + except ValueError: |
| 77 | + # have a token but don't know when it expires -> treat it as single-use |
| 78 | + expires_on = request_time |
| 79 | + |
| 80 | + # now we have an int expires_on, ensure the cache entry gets it |
| 81 | + payload["expires_on"] = expires_on |
| 82 | + |
65 | 83 | self._cache.add({"response": payload, "scope": scopes})
|
66 | 84 |
|
67 |
| - # AccessToken contains the token's expires_on time. There are four cases for setting it: |
68 |
| - # 1. response has expires_on -> AccessToken uses it |
69 |
| - # 2. response has expires_on and expires_in -> AccessToken uses expires_on |
70 |
| - # 3. response has only expires_in -> AccessToken uses expires_in + time of request |
71 |
| - # 4. response has neither expires_on or expires_in -> AccessToken sets expires_on = 0 |
72 |
| - # (not expecting this case; if it occurs, the token is effectively single-use) |
73 |
| - expires_on = payload.get("expires_on", 0) |
74 |
| - return AccessToken(token, expires_on or expires_in + request_time) |
| 85 | + return AccessToken(token, expires_on) |
75 | 86 | except KeyError:
|
76 | 87 | if "access_token" in payload:
|
77 | 88 | payload["access_token"] = "****"
|
78 |
| - raise AuthenticationError("Unexpected authentication response: {}".format(payload)) |
| 89 | + raise AuthenticationError("Unexpected response: {}".format(payload)) |
79 | 90 | except Exception as ex:
|
80 | 91 | raise AuthenticationError("Authentication failed: {}".format(str(ex)))
|
81 | 92 |
|
| 93 | + @staticmethod |
| 94 | + def _parse_app_service_expires_on(expires_on): |
| 95 | + # type: (str) -> struct_time |
| 96 | + """ |
| 97 | + Parse expires_on from an App Service MSI response (e.g. "06/19/2019 23:42:01 +00:00") to struct_time. |
| 98 | + Expects the time is given in UTC (i.e. has offset +00:00). |
| 99 | + """ |
| 100 | + if not expires_on.endswith(" +00:00"): |
| 101 | + raise ValueError("'{}' doesn't match expected format".format(expires_on)) |
| 102 | + |
| 103 | + # parse the string minus the timezone offset |
| 104 | + return time.strptime(expires_on[: -len(" +00:00")], "%m/%d/%Y %H:%M:%S") |
| 105 | + |
82 | 106 | # TODO: public, factor out of request_token
|
83 | 107 | def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
|
84 | 108 | # type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
|
|
0 commit comments