Skip to content

Commit 9b98575

Browse files
authored
[Identity] Allow use of client assertion in OBO cred (#35812)
The new kwarg `client_assertion_func` was added to allow passing in client assertion callbacks to OBO credential. Signed-off-by: Paul Van Eck <[email protected]>
1 parent 0b99ee1 commit 9b98575

File tree

9 files changed

+229
-36
lines changed

9 files changed

+229
-36
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44

55
### Features Added
66

7-
### Breaking Changes
8-
9-
### Bugs Fixed
10-
11-
### Other Changes
7+
- `OnBehalfOfCredential` now supports client assertion callbacks through the `client_assertion_func` keyword argument. This enables authenticating with client assertions such as federated credentials. ([#35812](https://github.com/Azure/azure-sdk-for-python/pull/35812))
128

139
## 1.17.0b1 (2024-05-13)
1410

sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import time
6-
from typing import Any, Optional
6+
from typing import Any, Optional, Callable, Union, Dict
77

88
import msal
99

@@ -28,14 +28,18 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
2828
description of the on-behalf-of flow.
2929
3030
:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
31-
:param str client_id: The service principal's client ID
31+
:param str client_id: The service principal's client ID.
3232
:keyword str client_secret: Optional. A client secret to authenticate the service principal.
33-
Either **client_secret** or **client_certificate** must be provided.
33+
One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided.
3434
:keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including
35-
the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must
36-
be provided.
35+
the private key to authenticate the service principal. One of **client_secret**, **client_certificate**,
36+
or **client_assertion_func** must be provided.
37+
:keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the
38+
application to Microsoft Entra ID. This function is called each time the credential requests a token. It must
39+
return a valid assertion for the target resource.
40+
:paramtype client_assertion_func: Callable[[], str]
3741
:keyword str user_assertion: Required. The access token the credential will use as the user assertion when
38-
requesting on-behalf-of tokens
42+
requesting on-behalf-of tokens.
3943
4044
:keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com",
4145
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
@@ -65,14 +69,31 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
6569
:caption: Create an OnBehalfOfCredential.
6670
"""
6771

68-
def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None:
69-
self._assertion = kwargs.pop("user_assertion", None)
72+
def __init__(
73+
self,
74+
tenant_id: str,
75+
client_id: str,
76+
*,
77+
client_certificate: Optional[bytes] = None,
78+
client_secret: Optional[str] = None,
79+
client_assertion_func: Optional[Callable[[], str]] = None,
80+
user_assertion: str,
81+
**kwargs: Any
82+
) -> None:
83+
self._assertion = user_assertion
7084
if not self._assertion:
71-
raise TypeError('"user_assertion" is required.')
72-
client_certificate = kwargs.pop("client_certificate", None)
73-
client_secret = kwargs.pop("client_secret", None)
85+
raise TypeError('"user_assertion" must not be empty.')
7486

75-
if client_certificate:
87+
if client_assertion_func:
88+
if client_certificate or client_secret:
89+
raise ValueError(
90+
"It is invalid to specify more than one of the following: "
91+
'"client_assertion_func", "client_certificate" or "client_secret".'
92+
)
93+
credential: Union[str, Dict[str, Any]] = {
94+
"client_assertion": client_assertion_func,
95+
}
96+
elif client_certificate:
7697
if client_secret:
7798
raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.')
7899
try:
@@ -86,7 +107,7 @@ def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None:
86107
elif client_secret:
87108
credential = client_secret
88109
else:
89-
raise TypeError('Either "client_certificate" or "client_secret" must be provided')
110+
raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided')
90111

91112
super(OnBehalfOfCredential, self).__init__(
92113
client_id=client_id, client_credential=credential, tenant_id=tenant_id, **kwargs

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def _get_client_secret_request(self, scopes: Iterable[str], secret: str, **kwarg
267267
def _get_on_behalf_of_request(
268268
self,
269269
scopes: Iterable[str],
270-
client_credential: Union[str, AadClientCertificate],
270+
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
271271
user_assertion: str,
272272
**kwargs: Any
273273
) -> HttpRequest:
@@ -288,6 +288,10 @@ def _get_on_behalf_of_request(
288288
if isinstance(client_credential, AadClientCertificate):
289289
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
290290
data["client_assertion_type"] = JWT_BEARER_ASSERTION
291+
elif isinstance(client_credential, dict):
292+
func = client_credential["client_assertion"]
293+
data["client_assertion"] = func()
294+
data["client_assertion_type"] = JWT_BEARER_ASSERTION
291295
else:
292296
data["client_secret"] = client_credential
293297

@@ -318,7 +322,7 @@ def _get_refresh_token_request(self, scopes: Iterable[str], refresh_token: str,
318322
def _get_refresh_token_on_behalf_of_request(
319323
self,
320324
scopes: Iterable[str],
321-
client_credential: Union[str, AadClientCertificate],
325+
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
322326
refresh_token: str,
323327
**kwargs: Any
324328
) -> HttpRequest:
@@ -338,6 +342,10 @@ def _get_refresh_token_on_behalf_of_request(
338342
if isinstance(client_credential, AadClientCertificate):
339343
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
340344
data["client_assertion_type"] = JWT_BEARER_ASSERTION
345+
elif isinstance(client_credential, dict):
346+
func = client_credential["client_assertion"]
347+
data["client_assertion"] = func()
348+
data["client_assertion_type"] = JWT_BEARER_ASSERTION
341349
else:
342350
data["client_secret"] = client_credential
343351
request = self._post(data, **kwargs)

sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class MsalCredential: # pylint: disable=too-many-instance-attributes
2525
def __init__(
2626
self,
2727
client_id: str,
28-
client_credential: Optional[Union[str, Dict[str, str]]] = None,
28+
client_credential: Optional[Union[str, Dict[str, Any]]] = None,
2929
*,
3030
additionally_allowed_tenants: Optional[List[str]] = None,
3131
authority: Optional[str] = None,

sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import logging
6-
from typing import Optional, Union, Any
6+
from typing import Optional, Union, Any, Dict, Callable
77

88
from azure.core.exceptions import ClientAuthenticationError
99
from azure.core.credentials import AccessToken
@@ -25,14 +25,18 @@ class OnBehalfOfCredential(AsyncContextManager, GetTokenMixin):
2525
description of the on-behalf-of flow.
2626
2727
:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
28-
:param str client_id: The service principal's client ID
28+
:param str client_id: The service principal's client ID.
2929
:keyword str client_secret: Optional. A client secret to authenticate the service principal.
30-
Either **client_secret** or **client_certificate** must be provided.
30+
One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided.
3131
:keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including
32-
the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must
33-
be provided.
32+
the private key to authenticate the service principal. One of **client_secret**, **client_certificate**,
33+
or **client_assertion_func** must be provided.
34+
:keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the
35+
application to Microsoft Entra ID. This function is called each time the credential requests a token. It must
36+
return a valid assertion for the target resource.
37+
:paramtype client_assertion_func: Callable[[], str]
3438
:keyword str user_assertion: Required. The access token the credential will use as the user assertion when
35-
requesting on-behalf-of tokens
39+
requesting on-behalf-of tokens.
3640
3741
:keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com",
3842
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
@@ -62,29 +66,39 @@ def __init__(
6266
*,
6367
client_certificate: Optional[bytes] = None,
6468
client_secret: Optional[str] = None,
69+
client_assertion_func: Optional[Callable[[], str]] = None,
6570
user_assertion: str,
6671
**kwargs: Any
6772
) -> None:
6873
super().__init__()
6974
validate_tenant_id(tenant_id)
7075

7176
self._assertion = user_assertion
72-
73-
if client_certificate:
77+
if not self._assertion:
78+
raise TypeError('"user_assertion" must not be empty.')
79+
80+
if client_assertion_func:
81+
if client_certificate or client_secret:
82+
raise ValueError(
83+
"It is invalid to specify more than one of the following: "
84+
'"client_assertion_func", "client_certificate" or "client_secret".'
85+
)
86+
self._client_credential: Union[str, AadClientCertificate, Dict[str, Any]] = {
87+
"client_assertion": client_assertion_func,
88+
}
89+
elif client_certificate:
7490
if client_secret:
7591
raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.')
7692
try:
7793
cert = get_client_credential(None, kwargs.pop("password", None), client_certificate)
7894
except ValueError as ex:
7995
message = '"client_certificate" is not a valid certificate in PEM or PKCS12 format'
8096
raise ValueError(message) from ex
81-
self._client_credential: Union[str, AadClientCertificate] = AadClientCertificate(
82-
cert["private_key"], password=cert.get("passphrase")
83-
)
97+
self._client_credential = AadClientCertificate(cert["private_key"], password=cert.get("passphrase"))
8498
elif client_secret:
8599
self._client_credential = client_secret
86100
else:
87-
raise TypeError('Either "client_certificate" or "client_secret" must be provided')
101+
raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided')
88102

89103
# note AadClient handles "authority" and any pipeline kwargs
90104
self._client = AadClient(tenant_id, client_id, **kwargs)

sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import time
6-
from typing import Iterable, Optional, Union
6+
from typing import Iterable, Optional, Union, Dict, Any
77

88
from azure.core.credentials import AccessToken
99
from azure.core.pipeline import AsyncPipeline
@@ -57,15 +57,23 @@ async def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_tok
5757
return await self._run_pipeline(request, **kwargs)
5858

5959
async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-too-long
60-
self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], refresh_token: str, **kwargs
60+
self,
61+
scopes: Iterable[str],
62+
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
63+
refresh_token: str,
64+
**kwargs
6165
) -> AccessToken:
6266
request = self._get_refresh_token_on_behalf_of_request(
6367
scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs
6468
)
6569
return await self._run_pipeline(request, **kwargs)
6670

6771
async def obtain_token_on_behalf_of(
68-
self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], user_assertion: str, **kwargs
72+
self,
73+
scopes: Iterable[str],
74+
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
75+
user_assertion: str,
76+
**kwargs
6977
) -> AccessToken:
7078
request = self._get_on_behalf_of_request(
7179
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
"""
6+
FILE: on_behalf_of_client_assertion.py
7+
DESCRIPTION:
8+
This sample demonstrates the use of OnBehalfOfCredential to authenticate the Key Vault SecretClient using a managed
9+
identity as the client assertion. More information about the On-Behalf-Of flow can be found here:
10+
https://learn.microsoft.com/entra/identity-platform/v2-oauth2-on-behalf-of-flow.
11+
USAGE:
12+
python on_behalf_of_client_assertion.py
13+
14+
**Note** - This sample requires the `azure-keyvault-secrets` package.
15+
"""
16+
# [START obo_client_assertion]
17+
from azure.identity import OnBehalfOfCredential, ManagedIdentityCredential
18+
from azure.keyvault.secrets import SecretClient
19+
20+
21+
# Replace the following variables with your own values.
22+
tenant_id = "<tenant_id>"
23+
client_id = "<client_id>"
24+
user_assertion = "<user_assertion>"
25+
26+
managed_identity_credential = ManagedIdentityCredential()
27+
28+
29+
def get_managed_identity_token() -> str:
30+
# This function should return an access token obtained from a managed identity.
31+
access_token = managed_identity_credential.get_token("api://AzureADTokenExchange")
32+
return access_token.token
33+
34+
35+
credential = OnBehalfOfCredential(
36+
tenant_id=tenant_id,
37+
client_id=client_id,
38+
user_assertion=user_assertion,
39+
client_assertion_func=get_managed_identity_token,
40+
)
41+
42+
client = SecretClient(vault_url="https://<your-key-vault-name>.vault.azure.net/", credential=credential)
43+
# [END obo_client_assertion]

sdk/identity/azure-identity/tests/test_obo.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
1313
from azure.identity import OnBehalfOfCredential, UsernamePasswordCredential
1414
from azure.identity._constants import EnvironmentVariables
15+
from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION
1516
from azure.identity._internal.user_agent import USER_AGENT
1617
import pytest
1718
from urllib.parse import urlparse
@@ -228,3 +229,53 @@ def test_no_client_credential():
228229
"""The credential should raise ValueError when ctoring with no client_secret or client_certificate"""
229230
with pytest.raises(TypeError):
230231
credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion")
232+
233+
234+
def test_client_assertion_func():
235+
"""The credential should accept a client_assertion_func"""
236+
expected_client_assertion = "client-assertion"
237+
expected_user_assertion = "user-assertion"
238+
expected_token = "***"
239+
func_call_count = 0
240+
241+
def client_assertion_func():
242+
nonlocal func_call_count
243+
func_call_count += 1
244+
return expected_client_assertion
245+
246+
def send(request, **kwargs):
247+
parsed = urlparse(request.url)
248+
tenant = parsed.path.split("/")[1]
249+
if "/oauth2/v2.0/token" not in parsed.path:
250+
return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant))
251+
252+
assert request.data.get("client_assertion") == expected_client_assertion
253+
assert request.data.get("client_assertion_type") == JWT_BEARER_ASSERTION
254+
assert request.data.get("assertion") == expected_user_assertion
255+
return mock_response(json_payload=build_aad_response(access_token=expected_token))
256+
257+
transport = Mock(send=Mock(wraps=send))
258+
credential = OnBehalfOfCredential(
259+
"tenant-id",
260+
"client-id",
261+
client_assertion_func=client_assertion_func,
262+
user_assertion=expected_user_assertion,
263+
transport=transport,
264+
)
265+
266+
access_token = credential.get_token("scope")
267+
assert access_token.token == expected_token
268+
assert func_call_count == 1
269+
270+
271+
def test_client_assertion_func_with_client_certificate():
272+
"""The credential should raise ValueError when ctoring with both client_assertion_func and client_certificate"""
273+
with pytest.raises(ValueError) as ex:
274+
credential = OnBehalfOfCredential(
275+
"tenant-id",
276+
"client-id",
277+
client_assertion_func=lambda: "client-assertion",
278+
client_certificate=b"certificate",
279+
user_assertion="assertion",
280+
)
281+
assert "It is invalid to specify more than one of the following" in str(ex.value)

0 commit comments

Comments
 (0)