Skip to content

Commit 9261e58

Browse files
authored
Improve managed identity support (Azure#5908)
* managed identity in App Service, Cloud Shell, IMDS * support user assigned identities
1 parent 047e226 commit 9261e58

11 files changed

+421
-153
lines changed

sdk/identity/azure-identity/azure/identity/_authn_client.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def _deserialize_and_cache_token(self, response, scopes, request_time):
7979
except Exception as ex:
8080
raise AuthenticationError("Authentication failed: {}".format(str(ex)))
8181

82+
# TODO: public, factor out of request_token
8283
def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
8384
# type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
8485
request = HttpRequest(method, self._auth_url, headers=headers)
@@ -102,11 +103,11 @@ def __init__(self, auth_url, config=None, policies=None, transport=None, **kwarg
102103
self._pipeline = Pipeline(transport=transport, policies=policies)
103104
super(AuthnClient, self).__init__(auth_url, **kwargs)
104105

105-
def request_token(self, scopes, method="POST", headers=None, form_data=None, params=None):
106-
# type: (Iterable[str], Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> AccessToken
106+
def request_token(self, scopes, method="POST", headers=None, form_data=None, params=None, **kwargs):
107+
# type: (Iterable[str], Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]], Any) -> AccessToken
107108
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
108109
request_time = int(time.time())
109-
response = self._pipeline.run(request, stream=False)
110+
response = self._pipeline.run(request, stream=False, **kwargs)
110111
token = self._deserialize_and_cache_token(response, scopes, request_time)
111112
return token
112113

@@ -115,5 +116,5 @@ def create_config(**kwargs):
115116
# type: (Mapping[str, Any]) -> Configuration
116117
config = Configuration(**kwargs)
117118
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
118-
config.retry_policy = RetryPolicy(retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs)
119+
config.retry_policy = RetryPolicy(**kwargs)
119120
return config

sdk/identity/azure-identity/azure/identity/_internal.py

+94-49
Original file line numberDiff line numberDiff line change
@@ -12,95 +12,140 @@
1212

1313
if TYPE_CHECKING:
1414
# pylint:disable=unused-import
15-
from typing import Any, Dict, Optional
15+
from typing import Any, Mapping, Optional, Type
1616
from azure.core.credentials import AccessToken
1717

1818
from azure.core import Configuration
19+
from azure.core.exceptions import HttpResponseError
1920
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, RetryPolicy
2021

2122
from ._authn_client import AuthnClient
22-
from .constants import Endpoints, MSI_ENDPOINT, MSI_SECRET
23+
from .constants import Endpoints, EnvironmentVariables
2324
from .exceptions import AuthenticationError
2425

2526

26-
class ImdsCredential:
27-
"""Authenticates with a managed identity via the IMDS endpoint"""
28-
29-
def __init__(self, config=None, **kwargs):
30-
# type: (Optional[Configuration], Dict[str, Any]) -> None
27+
class _ManagedIdentityBase(object):
28+
def __init__(self, endpoint, client_cls, config=None, client_id=None, **kwargs):
29+
# type: (str, Type, Optional[Configuration], Optional[str], Any) -> None
30+
self._client_id = client_id
3131
config = config or self.create_config(**kwargs)
32-
policies = [config.header_policy, ContentDecodePolicy(), config.logging_policy, config.retry_policy]
33-
self._client = AuthnClient(Endpoints.IMDS, config, policies, **kwargs)
32+
policies = [ContentDecodePolicy(), config.headers_policy, config.retry_policy, config.logging_policy]
33+
self._client = client_cls(endpoint, config, policies, **kwargs)
3434

3535
@staticmethod
3636
def create_config(**kwargs):
37-
# type: (Dict[str, str]) -> Configuration
37+
# type: (Mapping[str, Any]) -> Configuration
3838
timeout = kwargs.pop("connection_timeout", 2)
3939
config = Configuration(connection_timeout=timeout, **kwargs)
40-
config.header_policy = HeadersPolicy(base_headers={"Metadata": "true"}, **kwargs)
40+
41+
# retry is the only IO policy, so its class is a kwarg to increase async code sharing
42+
retry_policy = kwargs.pop("retry_policy", RetryPolicy) # type: ignore
43+
args = kwargs.copy() # combine kwargs and default retry settings in a Python 2-compatible way
44+
args.update(_ManagedIdentityBase._retry_settings) # type: ignore
45+
config.retry_policy = retry_policy(**args) # type: ignore
46+
47+
# Metadata header is required by IMDS and in Cloud Shell; App Service ignores it
48+
config.headers_policy = HeadersPolicy(base_headers={"Metadata": "true"}, **kwargs)
4149
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
42-
retries = kwargs.pop("retry_total", 5)
43-
config.retry_policy = RetryPolicy(
44-
retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
45-
)
50+
4651
return config
4752

53+
# given RetryPolicy's implementation, these settings most closely match the documented guidance for IMDS
54+
# https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance
55+
_retry_settings = {
56+
"retry_total": 5,
57+
"retry_status": 5,
58+
"retry_backoff_factor": 4,
59+
"retry_backoff_max": 60,
60+
"retry_on_status_codes": [404, 429] + list(range(500, 600)),
61+
}
62+
63+
64+
class ImdsCredential(_ManagedIdentityBase):
65+
"""Authenticates with a managed identity via the IMDS endpoint"""
66+
67+
def __init__(self, config=None, **kwargs):
68+
# type: (Optional[Configuration], Any) -> None
69+
super(ImdsCredential, self).__init__(endpoint=Endpoints.IMDS, client_cls=AuthnClient, config=config, **kwargs)
70+
self._endpoint_available = None # type: Optional[bool]
71+
4872
def get_token(self, *scopes):
4973
# type: (*str) -> AccessToken
74+
if self._endpoint_available is None:
75+
# Lacking another way to determine whether the IMDS endpoint is listening,
76+
# we send a request it would immediately reject (missing a required header),
77+
# setting a short timeout.
78+
try:
79+
self._client.request_token(scopes, method="GET", connection_timeout=0.3)
80+
self._endpoint_available = True
81+
except (AuthenticationError, HttpResponseError):
82+
# received a response a pipeline policy choked on (HttpResponseError)
83+
# or that couldn't be deserialized by AuthnClient (AuthenticationError)
84+
self._endpoint_available = True
85+
except Exception: # pylint:disable=broad-except
86+
# if anything else was raised, assume the endpoint is unavailable
87+
self._endpoint_available = False
88+
89+
if not self._endpoint_available:
90+
raise AuthenticationError("IMDS endpoint unavailable")
91+
5092
if len(scopes) != 1:
5193
raise ValueError("this credential supports one scope per request")
94+
5295
token = self._client.get_cached_token(scopes)
5396
if not token:
5497
resource = scopes[0]
5598
if resource.endswith("/.default"):
5699
resource = resource[: -len("/.default")]
57-
token = self._client.request_token(
58-
scopes, method="GET", params={"api-version": "2018-02-01", "resource": resource}
59-
)
100+
params = {"api-version": "2018-02-01", "resource": resource}
101+
if self._client_id:
102+
params["client_id"] = self._client_id
103+
token = self._client.request_token(scopes, method="GET", params=params)
60104
return token
61105

62106

63-
class MsiCredential:
64-
"""Authenticates via the MSI endpoint"""
107+
class MsiCredential(_ManagedIdentityBase):
108+
"""Authenticates via the MSI endpoint in App Service or Cloud Shell"""
65109

66110
def __init__(self, config=None, **kwargs):
67-
# type: (Optional[Configuration], Dict[str, Any]) -> None
68-
config = config or self.create_config(**kwargs)
69-
policies = [ContentDecodePolicy(), config.retry_policy, config.logging_policy]
70-
endpoint = os.environ.get(MSI_ENDPOINT)
71-
if not endpoint:
72-
raise ValueError("expected environment variable {} has no value".format(MSI_ENDPOINT))
73-
self._client = AuthnClient(endpoint, config, policies, **kwargs)
74-
75-
@staticmethod
76-
def create_config(**kwargs):
77-
# type: (Dict[str, str]) -> Configuration
78-
timeout = kwargs.pop("connection_timeout", 2)
79-
config = Configuration(connection_timeout=timeout, **kwargs)
80-
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
81-
retries = kwargs.pop("retry_total", 5)
82-
config.retry_policy = RetryPolicy(
83-
retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
84-
)
85-
return config
111+
# type: (Optional[Configuration], Mapping[str, Any]) -> None
112+
endpoint = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
113+
self._endpoint_available = endpoint is not None
114+
if self._endpoint_available:
115+
super(MsiCredential, self).__init__( # type: ignore
116+
endpoint=endpoint, client_cls=AuthnClient, config=config, **kwargs
117+
)
86118

87119
def get_token(self, *scopes):
88120
# type: (*str) -> AccessToken
121+
if not self._endpoint_available:
122+
raise AuthenticationError("MSI endpoint unavailable")
123+
89124
if len(scopes) != 1:
90125
raise ValueError("this credential supports only one scope per request")
126+
91127
token = self._client.get_cached_token(scopes)
92128
if not token:
93-
secret = os.environ.get(MSI_SECRET)
94-
if not secret:
95-
raise AuthenticationError("{} environment variable has no value".format(MSI_SECRET))
96129
resource = scopes[0]
97130
if resource.endswith("/.default"):
98131
resource = resource[: -len("/.default")]
99-
# TODO: support user-assigned client id
100-
token = self._client.request_token(
101-
scopes,
102-
method="GET",
103-
headers={"secret": secret},
104-
params={"api-version": "2017-09-01", "resource": resource},
105-
)
132+
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
133+
if secret:
134+
# MSI_ENDPOINT and MSI_SECRET set -> App Service
135+
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
136+
else:
137+
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
138+
token = self._request_legacy_token(scopes=scopes, resource=resource)
106139
return token
140+
141+
def _request_app_service_token(self, scopes, resource, secret):
142+
params = {"api-version": "2017-09-01", "resource": resource}
143+
if self._client_id:
144+
params["client_id"] = self._client_id
145+
return self._client.request_token(scopes, method="GET", headers={"secret": secret}, params=params)
146+
147+
def _request_legacy_token(self, scopes, resource):
148+
form_data = {"resource": resource}
149+
if self._client_id:
150+
form_data["client_id"] = self._client_id
151+
return self._client.request_token(scopes, method="POST", form_data=form_data)

sdk/identity/azure-identity/azure/identity/aio/_authn_client.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
5-
# --------------------------------------------------------------------------
5+
# -------------------------------------------------------------------------
66
import time
77
from typing import Any, Dict, Iterable, Mapping, Optional
88

@@ -41,18 +41,17 @@ async def request_token(
4141
headers: Optional[Mapping[str, str]] = None,
4242
form_data: Optional[Mapping[str, str]] = None,
4343
params: Optional[Dict[str, str]] = None,
44+
**kwargs: Any
4445
) -> AccessToken:
4546
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
4647
request_time = int(time.time())
47-
response = await self._pipeline.run(request, stream=False)
48+
response = await self._pipeline.run(request, stream=False, **kwargs)
4849
token = self._deserialize_and_cache_token(response, scopes, request_time)
4950
return token
5051

5152
@staticmethod
5253
def create_config(**kwargs: Mapping[str, Any]) -> Configuration:
5354
config = Configuration(**kwargs)
5455
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
55-
config.retry_policy = AsyncRetryPolicy(
56-
retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
57-
)
56+
config.retry_policy = AsyncRetryPolicy(**kwargs)
5857
return config

0 commit comments

Comments
 (0)