|
12 | 12 |
|
13 | 13 | if TYPE_CHECKING:
|
14 | 14 | # pylint:disable=unused-import
|
15 |
| - from typing import Any, Dict, Optional |
| 15 | + from typing import Any, Mapping, Optional, Type |
16 | 16 | from azure.core.credentials import AccessToken
|
17 | 17 |
|
18 | 18 | from azure.core import Configuration
|
| 19 | +from azure.core.exceptions import HttpResponseError |
19 | 20 | from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, RetryPolicy
|
20 | 21 |
|
21 | 22 | from ._authn_client import AuthnClient
|
22 |
| -from .constants import Endpoints, MSI_ENDPOINT, MSI_SECRET |
| 23 | +from .constants import Endpoints, EnvironmentVariables |
23 | 24 | from .exceptions import AuthenticationError
|
24 | 25 |
|
25 | 26 |
|
26 |
| -class ImdsCredential: |
27 |
| - """Authenticates with a managed identity via the IMDS endpoint""" |
28 |
| - |
29 |
| - def __init__(self, config=None, **kwargs): |
30 |
| - # type: (Optional[Configuration], Dict[str, Any]) -> None |
| 27 | +class _ManagedIdentityBase(object): |
| 28 | + def __init__(self, endpoint, client_cls, config=None, client_id=None, **kwargs): |
| 29 | + # type: (str, Type, Optional[Configuration], Optional[str], Any) -> None |
| 30 | + self._client_id = client_id |
31 | 31 | config = config or self.create_config(**kwargs)
|
32 |
| - policies = [config.header_policy, ContentDecodePolicy(), config.logging_policy, config.retry_policy] |
33 |
| - self._client = AuthnClient(Endpoints.IMDS, config, policies, **kwargs) |
| 32 | + policies = [ContentDecodePolicy(), config.headers_policy, config.retry_policy, config.logging_policy] |
| 33 | + self._client = client_cls(endpoint, config, policies, **kwargs) |
34 | 34 |
|
35 | 35 | @staticmethod
|
36 | 36 | def create_config(**kwargs):
|
37 |
| - # type: (Dict[str, str]) -> Configuration |
| 37 | + # type: (Mapping[str, Any]) -> Configuration |
38 | 38 | timeout = kwargs.pop("connection_timeout", 2)
|
39 | 39 | config = Configuration(connection_timeout=timeout, **kwargs)
|
40 |
| - config.header_policy = HeadersPolicy(base_headers={"Metadata": "true"}, **kwargs) |
| 40 | + |
| 41 | + # retry is the only IO policy, so its class is a kwarg to increase async code sharing |
| 42 | + retry_policy = kwargs.pop("retry_policy", RetryPolicy) # type: ignore |
| 43 | + args = kwargs.copy() # combine kwargs and default retry settings in a Python 2-compatible way |
| 44 | + args.update(_ManagedIdentityBase._retry_settings) # type: ignore |
| 45 | + config.retry_policy = retry_policy(**args) # type: ignore |
| 46 | + |
| 47 | + # Metadata header is required by IMDS and in Cloud Shell; App Service ignores it |
| 48 | + config.headers_policy = HeadersPolicy(base_headers={"Metadata": "true"}, **kwargs) |
41 | 49 | config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
|
42 |
| - retries = kwargs.pop("retry_total", 5) |
43 |
| - config.retry_policy = RetryPolicy( |
44 |
| - retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs |
45 |
| - ) |
| 50 | + |
46 | 51 | return config
|
47 | 52 |
|
| 53 | + # given RetryPolicy's implementation, these settings most closely match the documented guidance for IMDS |
| 54 | + # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance |
| 55 | + _retry_settings = { |
| 56 | + "retry_total": 5, |
| 57 | + "retry_status": 5, |
| 58 | + "retry_backoff_factor": 4, |
| 59 | + "retry_backoff_max": 60, |
| 60 | + "retry_on_status_codes": [404, 429] + list(range(500, 600)), |
| 61 | + } |
| 62 | + |
| 63 | + |
| 64 | +class ImdsCredential(_ManagedIdentityBase): |
| 65 | + """Authenticates with a managed identity via the IMDS endpoint""" |
| 66 | + |
| 67 | + def __init__(self, config=None, **kwargs): |
| 68 | + # type: (Optional[Configuration], Any) -> None |
| 69 | + super(ImdsCredential, self).__init__(endpoint=Endpoints.IMDS, client_cls=AuthnClient, config=config, **kwargs) |
| 70 | + self._endpoint_available = None # type: Optional[bool] |
| 71 | + |
48 | 72 | def get_token(self, *scopes):
|
49 | 73 | # type: (*str) -> AccessToken
|
| 74 | + if self._endpoint_available is None: |
| 75 | + # Lacking another way to determine whether the IMDS endpoint is listening, |
| 76 | + # we send a request it would immediately reject (missing a required header), |
| 77 | + # setting a short timeout. |
| 78 | + try: |
| 79 | + self._client.request_token(scopes, method="GET", connection_timeout=0.3) |
| 80 | + self._endpoint_available = True |
| 81 | + except (AuthenticationError, HttpResponseError): |
| 82 | + # received a response a pipeline policy choked on (HttpResponseError) |
| 83 | + # or that couldn't be deserialized by AuthnClient (AuthenticationError) |
| 84 | + self._endpoint_available = True |
| 85 | + except Exception: # pylint:disable=broad-except |
| 86 | + # if anything else was raised, assume the endpoint is unavailable |
| 87 | + self._endpoint_available = False |
| 88 | + |
| 89 | + if not self._endpoint_available: |
| 90 | + raise AuthenticationError("IMDS endpoint unavailable") |
| 91 | + |
50 | 92 | if len(scopes) != 1:
|
51 | 93 | raise ValueError("this credential supports one scope per request")
|
| 94 | + |
52 | 95 | token = self._client.get_cached_token(scopes)
|
53 | 96 | if not token:
|
54 | 97 | resource = scopes[0]
|
55 | 98 | if resource.endswith("/.default"):
|
56 | 99 | resource = resource[: -len("/.default")]
|
57 |
| - token = self._client.request_token( |
58 |
| - scopes, method="GET", params={"api-version": "2018-02-01", "resource": resource} |
59 |
| - ) |
| 100 | + params = {"api-version": "2018-02-01", "resource": resource} |
| 101 | + if self._client_id: |
| 102 | + params["client_id"] = self._client_id |
| 103 | + token = self._client.request_token(scopes, method="GET", params=params) |
60 | 104 | return token
|
61 | 105 |
|
62 | 106 |
|
63 |
| -class MsiCredential: |
64 |
| - """Authenticates via the MSI endpoint""" |
| 107 | +class MsiCredential(_ManagedIdentityBase): |
| 108 | + """Authenticates via the MSI endpoint in App Service or Cloud Shell""" |
65 | 109 |
|
66 | 110 | def __init__(self, config=None, **kwargs):
|
67 |
| - # type: (Optional[Configuration], Dict[str, Any]) -> None |
68 |
| - config = config or self.create_config(**kwargs) |
69 |
| - policies = [ContentDecodePolicy(), config.retry_policy, config.logging_policy] |
70 |
| - endpoint = os.environ.get(MSI_ENDPOINT) |
71 |
| - if not endpoint: |
72 |
| - raise ValueError("expected environment variable {} has no value".format(MSI_ENDPOINT)) |
73 |
| - self._client = AuthnClient(endpoint, config, policies, **kwargs) |
74 |
| - |
75 |
| - @staticmethod |
76 |
| - def create_config(**kwargs): |
77 |
| - # type: (Dict[str, str]) -> Configuration |
78 |
| - timeout = kwargs.pop("connection_timeout", 2) |
79 |
| - config = Configuration(connection_timeout=timeout, **kwargs) |
80 |
| - config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) |
81 |
| - retries = kwargs.pop("retry_total", 5) |
82 |
| - config.retry_policy = RetryPolicy( |
83 |
| - retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs |
84 |
| - ) |
85 |
| - return config |
| 111 | + # type: (Optional[Configuration], Mapping[str, Any]) -> None |
| 112 | + endpoint = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) |
| 113 | + self._endpoint_available = endpoint is not None |
| 114 | + if self._endpoint_available: |
| 115 | + super(MsiCredential, self).__init__( # type: ignore |
| 116 | + endpoint=endpoint, client_cls=AuthnClient, config=config, **kwargs |
| 117 | + ) |
86 | 118 |
|
87 | 119 | def get_token(self, *scopes):
|
88 | 120 | # type: (*str) -> AccessToken
|
| 121 | + if not self._endpoint_available: |
| 122 | + raise AuthenticationError("MSI endpoint unavailable") |
| 123 | + |
89 | 124 | if len(scopes) != 1:
|
90 | 125 | raise ValueError("this credential supports only one scope per request")
|
| 126 | + |
91 | 127 | token = self._client.get_cached_token(scopes)
|
92 | 128 | if not token:
|
93 |
| - secret = os.environ.get(MSI_SECRET) |
94 |
| - if not secret: |
95 |
| - raise AuthenticationError("{} environment variable has no value".format(MSI_SECRET)) |
96 | 129 | resource = scopes[0]
|
97 | 130 | if resource.endswith("/.default"):
|
98 | 131 | resource = resource[: -len("/.default")]
|
99 |
| - # TODO: support user-assigned client id |
100 |
| - token = self._client.request_token( |
101 |
| - scopes, |
102 |
| - method="GET", |
103 |
| - headers={"secret": secret}, |
104 |
| - params={"api-version": "2017-09-01", "resource": resource}, |
105 |
| - ) |
| 132 | + secret = os.environ.get(EnvironmentVariables.MSI_SECRET) |
| 133 | + if secret: |
| 134 | + # MSI_ENDPOINT and MSI_SECRET set -> App Service |
| 135 | + token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret) |
| 136 | + else: |
| 137 | + # only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell) |
| 138 | + token = self._request_legacy_token(scopes=scopes, resource=resource) |
106 | 139 | return token
|
| 140 | + |
| 141 | + def _request_app_service_token(self, scopes, resource, secret): |
| 142 | + params = {"api-version": "2017-09-01", "resource": resource} |
| 143 | + if self._client_id: |
| 144 | + params["client_id"] = self._client_id |
| 145 | + return self._client.request_token(scopes, method="GET", headers={"secret": secret}, params=params) |
| 146 | + |
| 147 | + def _request_legacy_token(self, scopes, resource): |
| 148 | + form_data = {"resource": resource} |
| 149 | + if self._client_id: |
| 150 | + form_data["client_id"] = self._client_id |
| 151 | + return self._client.request_token(scopes, method="POST", form_data=form_data) |
0 commit comments