Skip to content

Commit b0bd437

Browse files
authored
Synchronous device code credential (Azure#6464)
1 parent a08c25a commit b0bd437

File tree

9 files changed

+182
-19
lines changed

9 files changed

+182
-19
lines changed

sdk/identity/azure-identity/azure/identity/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
CertificateCredential,
88
ChainedTokenCredential,
99
ClientSecretCredential,
10+
DeviceCodeCredential,
1011
EnvironmentCredential,
1112
ManagedIdentityCredential,
1213
UsernamePasswordCredential,
@@ -35,6 +36,7 @@ def __init__(self, **kwargs):
3536
"ChainedTokenCredential",
3637
"ClientSecretCredential",
3738
"DefaultAzureCredential",
39+
"DeviceCodeCredential",
3840
"EnvironmentCredential",
3941
"InteractiveBrowserCredential",
4042
"ManagedIdentityCredential",

sdk/identity/azure-identity/azure/identity/_internal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
from .auth_code_redirect_handler import AuthCodeRedirectServer
6+
from .exception_wrapper import wrap_exceptions
67
from .msal_credentials import ConfidentialClientCredential, PublicClientCredential
78
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import functools
6+
7+
from six import raise_from
8+
9+
from azure.core.exceptions import ClientAuthenticationError
10+
11+
12+
def wrap_exceptions(fn):
13+
"""Prevents leaking exceptions defined outside azure-core by raising ClientAuthenticationError from them."""
14+
15+
@functools.wraps(fn)
16+
def wrapper(*args, **kwargs):
17+
try:
18+
return fn(*args, **kwargs)
19+
except ClientAuthenticationError:
20+
raise
21+
except Exception as ex:
22+
auth_error = ClientAuthenticationError(message="Authentication failed: {}".format(ex))
23+
raise_from(auth_error, ex)
24+
25+
return wrapper

sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from azure.core.credentials import AccessToken
1313
from azure.core.exceptions import ClientAuthenticationError
1414

15+
from .exception_wrapper import wrap_exceptions
1516
from .msal_transport_adapter import MsalTransportAdapter
1617

1718
try:
@@ -75,6 +76,7 @@ def _create_app(self, cls):
7576
class ConfidentialClientCredential(MsalCredential):
7677
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API"""
7778

79+
@wrap_exceptions
7880
def get_token(self, *scopes):
7981
# type: (str) -> AccessToken
8082

sdk/identity/azure-identity/azure/identity/browser_auth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from azure.core.credentials import AccessToken
1919
from azure.core.exceptions import ClientAuthenticationError
2020

21-
from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential
21+
from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential, wrap_exceptions
2222

2323

2424
class InteractiveBrowserCredential(ConfidentialClientCredential):
@@ -48,6 +48,7 @@ def __init__(self, client_id, client_secret, **kwargs):
4848
client_id=client_id, client_credential=client_secret, authority=authority, **kwargs
4949
)
5050

51+
@wrap_exceptions
5152
def get_token(self, *scopes):
5253
# type: (str) -> AccessToken
5354
"""

sdk/identity/azure-identity/azure/identity/credentials.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from ._authn_client import AuthnClient
1717
from ._base import ClientSecretCredentialBase, CertificateCredentialBase
18-
from ._internal import PublicClientCredential
18+
from ._internal import PublicClientCredential, wrap_exceptions
1919
from ._managed_identity import ImdsCredential, MsiCredential
2020
from .constants import Endpoints, EnvironmentVariables
2121

@@ -26,8 +26,9 @@
2626

2727
if TYPE_CHECKING:
2828
# pylint:disable=unused-import
29-
from typing import Any, Dict, Mapping, Optional, Union
29+
from typing import Any, Callable, Dict, Mapping, Optional, Union
3030
from azure.core.credentials import TokenCredential
31+
3132
EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"]
3233

3334
# pylint:disable=too-few-public-methods
@@ -249,6 +250,86 @@ def _get_error_message(history):
249250
return "No valid token received. {}".format(". ".join(attempts))
250251

251252

253+
class DeviceCodeCredential(PublicClientCredential):
254+
"""
255+
Authenticates users through the device code flow. When ``get_token`` is called, this credential acquires a
256+
verification URL and code from Azure Active Directory. A user must browse to the URL, enter the code, and
257+
authenticate with Directory. If the user authenticates successfully, the credential receives an access token.
258+
259+
This credential doesn't cache tokens--each ``get_token`` call begins a new authentication flow.
260+
261+
For more information about the device code flow, see Azure Active Directory documentation:
262+
https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code
263+
264+
:param str client_id: the application's ID
265+
:param prompt_callback: (optional) A callback enabling control of how authentication instructions are presented.
266+
If not provided, the credential will print instructions to stdout.
267+
:type prompt_callback: A callable accepting arguments (``verification_uri``, ``user_code``, ``expires_in``):
268+
- ``verification_uri`` (str) the URL the user must visit
269+
- ``user_code`` (str) the code the user must enter there
270+
- ``expires_in`` (int) the number of seconds the code will be valid
271+
272+
**Keyword arguments:**
273+
274+
- *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the
275+
'organizations' tenant, which supports only Azure Active Directory work or school accounts.
276+
277+
- *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device code
278+
as set by Azure Active Directory, which also prevails when ``timeout`` is longer.
279+
280+
"""
281+
282+
def __init__(self, client_id, prompt_callback=None, **kwargs):
283+
# type: (str, Optional[Callable[[str, str], None]], Any) -> None
284+
self._timeout = kwargs.pop("timeout", None) # type: Optional[int]
285+
self._prompt_callback = prompt_callback
286+
super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs)
287+
288+
@wrap_exceptions
289+
def get_token(self, *scopes):
290+
# type (*str) -> AccessToken
291+
"""
292+
Request an access token for `scopes`. This credential won't cache the token. Each call begins a new
293+
authentication flow.
294+
295+
:param str scopes: desired scopes for the token
296+
:rtype: :class:`azure.core.credentials.AccessToken`
297+
:raises: :class:`azure.core.exceptions.ClientAuthenticationError`
298+
"""
299+
300+
# MSAL requires scopes be a list
301+
scopes = list(scopes) # type: ignore
302+
now = int(time.time())
303+
304+
app = self._get_app()
305+
flow = app.initiate_device_flow(scopes)
306+
if "error" in flow:
307+
raise ClientAuthenticationError(
308+
message="Couldn't begin authentication: {}".format(flow.get("error_description") or flow.get("error"))
309+
)
310+
311+
if self._prompt_callback:
312+
self._prompt_callback(flow["verification_uri"], flow["user_code"], flow["expires_in"])
313+
else:
314+
print(flow["message"])
315+
316+
if self._timeout is not None and self._timeout < flow["expires_in"]:
317+
deadline = now + self._timeout
318+
result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline)
319+
else:
320+
result = app.acquire_token_by_device_flow(flow)
321+
322+
if "access_token" not in result:
323+
if result.get("error") == "authorization_pending":
324+
message = "Timed out waiting for user to authenticate"
325+
else:
326+
message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
327+
raise ClientAuthenticationError(message=message)
328+
329+
token = AccessToken(result["access_token"], now + int(result["expires_in"]))
330+
return token
331+
332+
252333
class UsernamePasswordCredential(PublicClientCredential):
253334
"""
254335
Authenticates a user with a username and password. In general, Microsoft doesn't recommend this kind of
@@ -267,8 +348,9 @@ class UsernamePasswordCredential(PublicClientCredential):
267348
268349
**Keyword arguments:**
269350
270-
*tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the
271-
'organizations' tenant.
351+
- **tenant (str)** - a tenant ID or a domain associated with a tenant. If not provided, defaults to the
352+
'organizations' tenant.
353+
272354
"""
273355

274356
def __init__(self, client_id, username, password, **kwargs):
@@ -277,6 +359,7 @@ def __init__(self, client_id, username, password, **kwargs):
277359
self._username = username
278360
self._password = password
279361

362+
@wrap_exceptions
280363
def get_token(self, *scopes):
281364
# type (*str) -> AccessToken
282365
"""

sdk/identity/azure-identity/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,6 @@
6969
"azure",
7070
]
7171
),
72-
install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1"],
72+
install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1", "six>=1.6"],
7373
extras_require={":python_version<'3.0'": ["azure-nspkg"], ":python_version<'3.5'": ["typing"]},
7474
)

sdk/identity/azure-identity/tests/test_identity.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,21 @@
1313
except ImportError: # python < 3.3
1414
from mock import Mock, patch
1515

16-
import pytest
1716
from azure.core.credentials import AccessToken
1817
from azure.core.exceptions import ClientAuthenticationError
1918
from azure.identity import (
19+
ChainedTokenCredential,
2020
ClientSecretCredential,
2121
DefaultAzureCredential,
22+
DeviceCodeCredential,
2223
EnvironmentCredential,
2324
ManagedIdentityCredential,
24-
ChainedTokenCredential,
2525
InteractiveBrowserCredential,
2626
UsernamePasswordCredential,
2727
)
2828
from azure.identity._managed_identity import ImdsCredential
2929
from azure.identity.constants import EnvironmentVariables
30+
import pytest
3031

3132
from helpers import mock_response, Request, validating_transport
3233

@@ -123,11 +124,6 @@ def test_client_secret_environment_credential(monkeypatch):
123124
assert token.token == access_token
124125

125126

126-
def test_environment_credential_error():
127-
with pytest.raises(ClientAuthenticationError):
128-
EnvironmentCredential().get_token("scope")
129-
130-
131127
def test_credential_chain_error_message():
132128
def raise_authn_error(message):
133129
raise ClientAuthenticationError(message)
@@ -244,6 +240,65 @@ def test_default_credential():
244240
DefaultAzureCredential()
245241

246242

243+
def test_device_code_credential():
244+
expected_token = "access-token"
245+
user_code = "user-code"
246+
verification_uri = "verification-uri"
247+
expires_in = 42
248+
249+
transport = validating_transport(
250+
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
251+
responses=[
252+
# expected requests: discover tenant, start device code flow, poll for completion
253+
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
254+
mock_response(
255+
json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri, "expires_in": expires_in}
256+
),
257+
mock_response(
258+
json_payload={
259+
"access_token": expected_token,
260+
"expires_in": expires_in,
261+
"scope": "scope",
262+
"token_type": "Bearer",
263+
"refresh_token": "_",
264+
}
265+
),
266+
],
267+
)
268+
269+
callback = Mock()
270+
credential = DeviceCodeCredential(
271+
client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False
272+
)
273+
274+
token = credential.get_token("scope")
275+
assert token.token == expected_token
276+
277+
# prompt_callback should have been called as documented
278+
assert callback.call_count == 1
279+
assert callback.call_args[0] == (verification_uri, user_code, expires_in)
280+
281+
282+
def test_device_code_credential_timeout():
283+
transport = validating_transport(
284+
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
285+
responses=[
286+
# expected requests: discover tenant, start device code flow, poll for completion
287+
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
288+
mock_response(json_payload={"device_code": "_", "user_code": "_", "verification_uri": "_"}),
289+
mock_response(json_payload={"error": "authorization_pending"}),
290+
],
291+
)
292+
293+
credential = DeviceCodeCredential(
294+
client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.1, instance_discovery=False
295+
)
296+
297+
with pytest.raises(ClientAuthenticationError) as ex:
298+
credential.get_token("scope")
299+
assert "timed out" in ex.value.message.lower()
300+
301+
247302
@patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser
248303
def test_interactive_credential():
249304
oauth_state = "state"

sdk/identity/azure-identity/tests/test_identity_async.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,6 @@ async def test_client_secret_environment_credential(monkeypatch):
121121
assert token.token == access_token
122122

123123

124-
@pytest.mark.asyncio
125-
async def test_environment_credential_error():
126-
with pytest.raises(ClientAuthenticationError):
127-
await EnvironmentCredential().get_token("scope")
128-
129-
130124
@pytest.mark.asyncio
131125
async def test_credential_chain_error_message():
132126
def raise_authn_error(message):

0 commit comments

Comments
 (0)