Skip to content

Commit 08a2262

Browse files
authored
Handle App Service MSI response expires_on values (Azure#5972)
1 parent ff6fd86 commit 08a2262

File tree

2 files changed

+76
-35
lines changed

2 files changed

+76
-35
lines changed

sdk/identity/azure-identity/azure/identity/_authn_client.py

+43-19
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6+
import calendar
67
import time
78

89
from azure.core import Configuration, HttpRequest
910
from azure.core.credentials import AccessToken
10-
from azure.core.pipeline import Pipeline, PipelineRequest
11+
from azure.core.pipeline import Pipeline
1112
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy
1213
from azure.core.pipeline.transport import HttpTransport, RequestsTransport
1314
from msal import TokenCache
@@ -19,6 +20,7 @@
1920
except ImportError:
2021
TYPE_CHECKING = False
2122
if TYPE_CHECKING:
23+
from time import struct_time
2224
from typing import Any, Dict, Iterable, Mapping, Optional
2325
from azure.core.pipeline import PipelineResponse
2426
from azure.core.pipeline.policies import HTTPPolicy
@@ -46,39 +48,61 @@ def get_cached_token(self, scopes):
4648
return None
4749

4850
def _deserialize_and_cache_token(self, response, scopes, request_time):
49-
# type: (PipelineResponse, Iterable[str], int) -> str
51+
# type: (PipelineResponse, Iterable[str], int) -> AccessToken
5052
try:
51-
if "deserialized_data" in response.context:
52-
payload = response.context["deserialized_data"]
53-
else:
54-
payload = response.http_response.text()
53+
# ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response
54+
payload = response.context["deserialized_data"]
5555
token = payload["access_token"]
5656

57-
# these values are strings in IMDS responses but msal.TokenCache requires they be integers
57+
# these values are strings in some responses but msal.TokenCache requires they be integers
5858
# https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
59-
expires_in = int(payload.get("expires_in", 0))
60-
if expires_in != 0:
61-
payload["expires_in"] = expires_in
59+
if "expires_in" in payload:
60+
payload["expires_in"] = int(payload["expires_in"])
6261
if "ext_expires_in" in payload:
6362
payload["ext_expires_in"] = int(payload["ext_expires_in"])
6463

64+
# this will raise if the payload has neither expires_on nor expires_in
65+
# (which is fine because that's unexpected, especially considering the payload contains access_token)
66+
expires_on = payload.get("expires_on") or payload["expires_in"] + request_time
67+
68+
# ensure expires_on is an int
69+
try:
70+
expires_on = int(expires_on)
71+
except ValueError:
72+
# probably an App Service MSI response, convert it to epoch seconds
73+
try:
74+
t = self._parse_app_service_expires_on(expires_on)
75+
expires_on = calendar.timegm(t)
76+
except ValueError:
77+
# have a token but don't know when it expires -> treat it as single-use
78+
expires_on = request_time
79+
80+
# now we have an int expires_on, ensure the cache entry gets it
81+
payload["expires_on"] = expires_on
82+
6583
self._cache.add({"response": payload, "scope": scopes})
6684

67-
# AccessToken contains the token's expires_on time. There are four cases for setting it:
68-
# 1. response has expires_on -> AccessToken uses it
69-
# 2. response has expires_on and expires_in -> AccessToken uses expires_on
70-
# 3. response has only expires_in -> AccessToken uses expires_in + time of request
71-
# 4. response has neither expires_on or expires_in -> AccessToken sets expires_on = 0
72-
# (not expecting this case; if it occurs, the token is effectively single-use)
73-
expires_on = payload.get("expires_on", 0)
74-
return AccessToken(token, expires_on or expires_in + request_time)
85+
return AccessToken(token, expires_on)
7586
except KeyError:
7687
if "access_token" in payload:
7788
payload["access_token"] = "****"
78-
raise AuthenticationError("Unexpected authentication response: {}".format(payload))
89+
raise AuthenticationError("Unexpected response: {}".format(payload))
7990
except Exception as ex:
8091
raise AuthenticationError("Authentication failed: {}".format(str(ex)))
8192

93+
@staticmethod
94+
def _parse_app_service_expires_on(expires_on):
95+
# type: (str) -> struct_time
96+
"""
97+
Parse expires_on from an App Service MSI response (e.g. "06/19/2019 23:42:01 +00:00") to struct_time.
98+
Expects the time is given in UTC (i.e. has offset +00:00).
99+
"""
100+
if not expires_on.endswith(" +00:00"):
101+
raise ValueError("'{}' doesn't match expected format".format(expires_on))
102+
103+
# parse the string minus the timezone offset
104+
return time.strptime(expires_on[: -len(" +00:00")], "%m/%d/%Y %H:%M:%S")
105+
82106
# TODO: public, factor out of request_token
83107
def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
84108
# type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest

sdk/identity/azure-identity/tests/test_authn_client.py

+33-16
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ def test_authn_client_deserialization():
1919
# using the synchronous AuthnClient to drive this test but the functionality tested is
2020
# in the sans I/O AuthnClientBase, shared with AsyncAuthnClient
2121
now = 6
22-
expires_in = 1800
22+
expires_in = 59 - now
2323
expires_on = now + expires_in
24-
expected_token = "token"
24+
access_token = "***"
25+
expected_access_token = AccessToken(access_token, expires_on)
26+
scope = "scope"
2527

2628
mock_response = Mock(
2729
headers={"content-type": "application/json"}, status_code=200, content_type=["application/json"]
@@ -30,29 +32,44 @@ def test_authn_client_deserialization():
3032

3133
# response with expires_on only
3234
mock_response.text = lambda: json.dumps(
33-
{"access_token": expected_token, "expires_on": expires_on, "token_type": "Bearer"}
35+
{"access_token": access_token, "expires_on": expires_on, "token_type": "Bearer", "resource": scope}
3436
)
35-
token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token("scope")
36-
assert token.token == expected_token
37-
assert token.expires_on == expires_on
37+
token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope)
38+
assert token == expected_access_token
3839

39-
# response with expires_in and expires_on
40+
# response with expires_on only and it's a datetime string (App Service MSI)
4041
mock_response.text = lambda: json.dumps(
41-
{"access_token": expected_token, "expires_in": expires_in, "expires_on": expires_on, "token_type": "Bearer"}
42+
{
43+
"access_token": access_token,
44+
"expires_on": "01/01/1970 00:00:{} +00:00".format(now + expires_in),
45+
"token_type": "Bearer",
46+
"resource": scope,
47+
}
4248
)
43-
token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token("scope")
44-
assert token.token == expected_token
45-
assert token.expires_on == expires_on
49+
token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope)
50+
assert token == expected_access_token
4651

47-
# response with expires_in only
52+
# response with string expires_in and expires_on (IMDS, Cloud Shell)
4853
mock_response.text = lambda: json.dumps(
49-
{"access_token": expected_token, "expires_in": expires_in, "token_type": "Bearer"}
54+
{
55+
"access_token": access_token,
56+
"expires_in": str(expires_in),
57+
"expires_on": str(expires_on),
58+
"token_type": "Bearer",
59+
"resource": scope,
60+
}
61+
)
62+
token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope)
63+
assert token == expected_access_token
64+
65+
# response with int expires_in (AAD)
66+
mock_response.text = lambda: json.dumps(
67+
{"access_token": access_token, "expires_in": expires_in, "token_type": "Bearer", "ext_expires_in": expires_in}
5068
)
5169
with patch("azure.identity._authn_client.time.time") as mock_time:
5270
mock_time.return_value = now
53-
token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token("scope")
54-
assert token.token == expected_token
55-
assert token.expires_on == expires_on
71+
token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope)
72+
assert token == expected_access_token
5673

5774

5875
def test_caching_when_only_expires_in_set():

0 commit comments

Comments
 (0)