Skip to content

Commit 5f55e7d

Browse files
committed
CDT with bearer app token
1 parent 10a7c37 commit 5f55e7d

File tree

7 files changed

+231
-10
lines changed

7 files changed

+231
-10
lines changed

msal/application.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from __future__ import annotations
2+
import base64
3+
import datetime
24
import functools
35
import json
46
import time
@@ -166,6 +168,17 @@ def _preferred_browser():
166168
return None
167169

168170

171+
def _build_req_cnf(jwk:dict, remove_padding:bool = False) -> str:
172+
"""req_cnf usually requires base64url encoding.
173+
174+
https://datatracker.ietf.org/doc/html/draft-ietf-oauth-pop-key-distribution-07#section-4.2.1
175+
https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e967ebeb-9e9f-443e-857a-5208802943c2
176+
"""
177+
raw = json.dumps(jwk)
178+
encoded = base64.urlsafe_b64encode(raw.encode('utf-8')).decode('utf-8')
179+
return encoded.rstrip('=') if remove_padding else encoded
180+
181+
169182
class _ClientWithCcsRoutingInfo(Client):
170183

171184
def initiate_auth_code_flow(self, **kwargs):
@@ -232,6 +245,7 @@ class ClientApplication(object):
232245
_TOKEN_SOURCE_IDP = "identity_provider"
233246
_TOKEN_SOURCE_CACHE = "cache"
234247
_TOKEN_SOURCE_BROKER = "broker"
248+
_XMS_DS_NONCE = "xms_ds_nonce"
235249

236250
_enable_broker = False
237251
_AUTH_SCHEME_UNSUPPORTED = (
@@ -241,8 +255,17 @@ class ClientApplication(object):
241255

242256
_TOKEN_CACHE_DATA: dict[str, str] = { # field_in_data: field_in_cache
243257
"key_id": "key_id", # Some token types (SSH-certs, POP) are bound to a key
258+
"req_ds_cnf": "req_ds_cnf", # Used in CDT scenario
244259
}
245260

261+
@functools.lru_cache(maxsize=2)
262+
def __get_rsa_key(self, _bucket): # _bucket is used with lru_cache pattern
263+
from .crypto import _generate_rsa_key
264+
return _generate_rsa_key()
265+
266+
def _get_rsa_key(self, _bucket=None): # Return the same RSA key, cached for a day
267+
return self.__get_rsa_key(_bucket or datetime.date.today())
268+
246269
def __init__(
247270
self, client_id,
248271
client_credential=None, authority=None, validate_authority=True,
@@ -656,7 +679,12 @@ def __init__(
656679

657680
self._decide_broker(allow_broker, enable_pii_log)
658681
self.token_cache = token_cache or TokenCache()
659-
self.token_cache._set(data_to_at=self._TOKEN_CACHE_DATA)
682+
self.token_cache._set(
683+
data_to_at=self._TOKEN_CACHE_DATA,
684+
response_to_at={ # field_in_resp: field_in_cache
685+
"xms_ds_nonce": "xms_ds_nonce",
686+
},
687+
)
660688
self._region_configured = azure_region
661689
self._region_detected = None
662690
self.client, self._regional_client = self._build_client(
@@ -1559,6 +1587,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15591587
"expires_in": int(expires_in), # OAuth2 specs defines it as int
15601588
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
15611589
}
1590+
if self._XMS_DS_NONCE in entry: # CDT needs this
1591+
access_token_from_cache[self._XMS_DS_NONCE] = entry[
1592+
self._XMS_DS_NONCE]
15621593
if "refresh_on" in entry:
15631594
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
15641595
if int(entry["refresh_on"]) < now: # aging
@@ -2347,7 +2378,16 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
23472378
except that ``allow_broker`` parameter shall remain ``None``.
23482379
"""
23492380

2350-
def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
2381+
def acquire_token_for_client(
2382+
self,
2383+
scopes,
2384+
claims_challenge=None,
2385+
*,
2386+
delegation_constraints: Optional[list] = None,
2387+
delegation_confirmation_key=None, # A Cyprtography's RSAPrivateKey-like object
2388+
# TODO: Support ECC key? https://github.com/pyca/cryptography/issues/4093
2389+
**kwargs
2390+
):
23512391
"""Acquires token for the current confidential client, not for an end user.
23522392
23532393
Since MSAL Python 1.23, it will automatically look for token from cache,
@@ -2370,8 +2410,36 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
23702410
raise ValueError( # We choose to disallow force_refresh
23712411
"Historically, this method does not support force_refresh behavior. "
23722412
)
2373-
return _clean_up(self._acquire_token_silent_with_error(
2374-
scopes, None, claims_challenge=claims_challenge, **kwargs))
2413+
if delegation_constraints:
2414+
private_key = delegation_confirmation_key or self._get_rsa_key()
2415+
from .crypto import _convert_rsa_keys
2416+
_, jwk = _convert_rsa_keys(private_key)
2417+
result = _clean_up(self._acquire_token_silent_with_error(
2418+
scopes, None, claims_challenge=claims_challenge, data=dict(
2419+
kwargs.pop("data", {}),
2420+
req_ds_cnf=_build_req_cnf(jwk) # It is part of token cache key
2421+
if delegation_constraints else None,
2422+
),
2423+
**kwargs))
2424+
if delegation_constraints and not result.get("error"):
2425+
if not result.get(self._XMS_DS_NONCE): # Available in cached token, too
2426+
raise ValueError(
2427+
"The resource did not opt in to xms_ds_cnf claim. "
2428+
"After its opt-in, call this function again with "
2429+
"a new app object or a new delegation_confirmation_key"
2430+
# in order to invalidate the token in cache
2431+
)
2432+
import jwt # Lazy loading
2433+
cdt_envelope = jwt.encode({
2434+
"constraints": delegation_constraints,
2435+
self._XMS_DS_NONCE: result[self._XMS_DS_NONCE],
2436+
}, private_key, algorithm="PS256")
2437+
result["access_token"] = jwt.encode({
2438+
"t": result["access_token"],
2439+
"c": cdt_envelope,
2440+
}, None, algorithm=None, headers={"typ": "cdt+jwt"})
2441+
del result[self._XMS_DS_NONCE] # Caller shouldn't need to know that
2442+
return result
23752443

23762444
def _acquire_token_for_client(
23772445
self,

msal/crypto.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from cryptography.hazmat.primitives.asymmetric import rsa
2+
3+
4+
def _urlsafe_b64encode(n:int, bit_size:int) -> str:
5+
from base64 import urlsafe_b64encode
6+
return urlsafe_b64encode(n.to_bytes(
7+
length=int(bit_size/8),
8+
byteorder="big",
9+
)).decode("utf-8").rstrip("=")
10+
11+
12+
def _to_jwk(public_key: rsa.RSAPublicKey) -> dict:
13+
"""Equivalent to:
14+
15+
numbers = public_key.public_numbers()
16+
result = {
17+
"kty": "RSA",
18+
"n": _urlsafe_b64encode(numbers.n, public_key.key_size),
19+
"e": _urlsafe_b64encode(numbers.e, 24),
20+
}
21+
return result
22+
"""
23+
import jwt
24+
return jwt.get_algorithm_by_name( # PyJWT 2.5.0 https://github.com/jpadilla/pyjwt/releases/tag/2.5.0
25+
"RS256"
26+
).to_jwk(
27+
public_key,
28+
as_dict=True, # PyJWT 2.7.0 https://github.com/jpadilla/pyjwt/releases/tag/2.7.0
29+
)
30+
31+
def _convert_rsa_keys(private_key: rsa.RSAPrivateKey):
32+
return "pairs.private_bytes()", _to_jwk(private_key.public_key())
33+
34+
def _generate_rsa_key() -> rsa.RSAPrivateKey:
35+
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/#cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key
36+
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
37+

msal/token_cache.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import hashlib
23
import json
34
import threading
45
import time
@@ -82,6 +83,7 @@ def __init__(self):
8283
realm=None, target=None,
8384
# Note: New field(s) can be added here
8485
#key_id=None,
86+
req_ds_cnf=None,
8587
**ignored_payload_from_a_real_token:
8688
"-".join([ # Note: Could use a hash here to shorten key length
8789
home_account_id or "",
@@ -91,6 +93,13 @@ def __init__(self):
9193
realm or "",
9294
target or "",
9395
#key_id or "", # So ATs of different key_id can coexist
96+
hashlib.sha256(req_ds_cnf.encode()).hexdigest()
97+
# TODO: Could hash the entire key eventually.
98+
# But before that project, we better first
99+
# change the scope to use input scope
100+
# instead of response scope,
101+
# so that a search() can probably have O(1) hit.
102+
if req_ds_cnf else "", # CDT
94103
]).lower(),
95104
self.CredentialType.ID_TOKEN:
96105
lambda home_account_id=None, environment=None, client_id=None,

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ install_requires =
4444

4545
# MSAL does not use jwt.decode(),
4646
# therefore is insusceptible to CVE-2022-29217 so no need to bump to PyJWT 2.4+
47-
PyJWT[crypto]>=1.0.0,<3
47+
PyJWT[crypto]>=2.7.0,<3
4848

4949
# load_key_and_certificates() is available since 2.5
5050
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/serialization/#cryptography.hazmat.primitives.serialization.pkcs12.load_key_and_certificates

tests/test_application.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
22
# so this test_application file contains only unit tests without dependency.
3+
from __future__ import annotations
34
import json
45
import logging
56
import sys
@@ -11,6 +12,7 @@
1112
ClientApplication, PublicClientApplication, ConfidentialClientApplication,
1213
_str2bytes, _merge_claims_challenge_and_capabilities,
1314
)
15+
from msal.oauth2cli.oidc import decode_part
1416
from tests import unittest
1517
from tests.test_token_cache import build_id_token, build_response
1618
from tests.http_client import MinimalHttpClient, MinimalResponse
@@ -856,3 +858,76 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
856858
)
857859
self.assertEqual(result.get("error"), "broker_error")
858860

861+
862+
class CdtTestCase(unittest.TestCase):
863+
864+
def createConstraint(self, typ: str, action: str, targets: list[str]) -> dict:
865+
return {"ver": "1.0", "typ": typ, "a": action, "targets": targets}
866+
867+
def test_constraint_format(self):
868+
self.assertEqual([
869+
self.createConstraint("ns:usr", "create", ["guid1", "guid2"]),
870+
self.createConstraint("ns:app", "update", ["guid3", "guid4"]),
871+
self.createConstraint("ns:subscription", "read", ["guid5", "guid6"]),
872+
], [ # Format defined in https://microsoft-my.sharepoint-df.com/:w:/p/rohitshende/EZgP9niwOvhKn-CUbj1NgG4BTZ6FSD9_16vXvsaXTiUzkg?e=j5DcQu&nav=eyJoIjoiODU5NDAyNjI4In0
873+
{"ver": "1.0", "typ": "ns:usr", "a": "create", "targets": ["guid1", "guid2"]},
874+
{"ver": "1.0", "typ": "ns:app", "a": "update", "targets": ["guid3", "guid4"]},
875+
{"ver": "1.0", "typ": "ns:subscription", "a": "read", "targets": [
876+
"guid5", "guid6",
877+
]},
878+
], "Constraint format is correct") # MSAL actually accepts arbitrary JSON blob
879+
880+
def assertCdt(self, result: dict, constraints: list[dict]) -> None:
881+
self.assertIsNotNone(
882+
result.get("access_token"), "Encountered {}: {}".format(
883+
result.get("error"), result.get("error_description")))
884+
_expectancy = "The return value should look like a Bearer response"
885+
self.assertEqual(result["token_type"], "Bearer", _expectancy)
886+
self.assertNotIn("xms_ds_nonce", result, _expectancy)
887+
headers = json.loads(decode_part(result["access_token"].split(".")[0]))
888+
self.assertEqual(headers.get("typ"), "cdt+jwt", "typ should be cdt+jwt")
889+
payload = json.loads(decode_part(result["access_token"].split(".")[1]))
890+
self.assertIsNotNone(payload.get("t") and payload.get("c"))
891+
cdt_envelope = json.loads(decode_part(payload["c"].split(".")[1]))
892+
self.assertIn("xms_ds_nonce", cdt_envelope)
893+
self.assertEqual(cdt_envelope["constraints"], constraints)
894+
895+
def assertAppObtainsCdt(self, client_app, scopes) -> None:
896+
constraints1 = [self.createConstraint("ns:usr", "create", ["guid1"])]
897+
result = client_app.acquire_token_for_client(
898+
scopes, delegation_constraints=constraints1,
899+
)
900+
self.assertCdt(result, constraints1)
901+
902+
constraints2 = [self.createConstraint("ns:app", "update", ["guid2"])]
903+
result = client_app.acquire_token_for_client(
904+
scopes, delegation_constraints=constraints2,
905+
)
906+
self.assertEqual(result["token_source"], "cache", "App token Should hit cache")
907+
self.assertCdt(result, constraints2)
908+
909+
result = client_app.acquire_token_for_client(
910+
scopes, delegation_constraints=constraints2,
911+
delegation_confirmation_key=client_app._get_rsa_key("new"),
912+
)
913+
self.assertEqual(
914+
result["token_source"], "identity_provider",
915+
"Different key should result in a new app token")
916+
self.assertCdt(result, constraints2)
917+
918+
@patch("msal.authority.tenant_discovery", new=Mock(return_value={
919+
"authorization_endpoint": "https://contoso.com/placeholder",
920+
"token_endpoint": "https://contoso.com/placeholder",
921+
}))
922+
def test_acquire_token_for_client_should_return_a_cdt(self):
923+
app = msal.ConfidentialClientApplication("id", client_credential="secret")
924+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
925+
status_code=200, text=json.dumps({
926+
"token_type": "Bearer",
927+
"access_token": "app token",
928+
"expires_in": 3600,
929+
"xms_ds_nonce": "nonce",
930+
}))) as mocked_post:
931+
self.assertAppObtainsCdt(app, ["scope1", "scope2"])
932+
self.assertEqual(mocked_post.call_count, 2)
933+

tests/test_crypto.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from unittest import TestCase
2+
3+
from msal.crypto import _generate_rsa_key, _convert_rsa_keys
4+
5+
6+
class CryptoTestCase(TestCase):
7+
def test_key_generation(self):
8+
key = _generate_rsa_key()
9+
_, jwk = _convert_rsa_keys(key)
10+
self.assertEqual(jwk.get("kty"), "RSA")
11+
self.assertIsNotNone(jwk.get("n") and jwk.get("e"))
12+

tests/test_e2e.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
LAB_OBO_PUBLIC_CLIENT_ID=...
66
LAB_OBO_CONFIDENTIAL_CLIENT_ID=...
77
"""
8+
from __future__ import annotations
89
try:
910
from dotenv import load_dotenv # Use this only in local dev machine
1011
load_dotenv() # take environment variables from .env.
@@ -27,8 +28,10 @@
2728

2829
import msal
2930
from tests.http_client import MinimalHttpClient, MinimalResponse
31+
from tests.test_application import CdtTestCase
3032
from msal.oauth2cli import AuthCodeReceiver
3133
from msal.oauth2cli.oidc import decode_part
34+
from msal.application import _build_req_cnf
3235

3336
try:
3437
import pymsalruntime
@@ -533,7 +536,7 @@ def tearDownClass(cls):
533536
cls.session.close()
534537

535538
@classmethod
536-
def get_lab_app_object(cls, client_id=None, **query): # https://msidlab.com/swagger/index.html
539+
def get_lab_app_object(cls, client_id=None, **query) -> dict: # https://msidlab.com/swagger/index.html
537540
url = "https://msidlab.com/api/app/{}".format(client_id or "")
538541
resp = cls.session.get(url, params=query)
539542
result = resp.json()[0]
@@ -791,12 +794,12 @@ def test_user_account(self):
791794
self._test_user_account()
792795

793796

794-
def _data_for_pop(key):
795-
raw_req_cnf = json.dumps({"kid": key, "xms_ksl": "sw"})
797+
def _data_for_pop(key_id):
796798
return { # Sampled from Azure CLI's plugin connectedk8s
797799
'token_type': 'pop',
798-
'key_id': key,
799-
"req_cnf": base64.urlsafe_b64encode(raw_req_cnf.encode('utf-8')).decode('utf-8').rstrip('='),
800+
'key_id': key_id,
801+
"req_cnf": _build_req_cnf(
802+
{"kid": key_id, "xms_ksl": "sw"}, remove_padding=True),
800803
# Note: Sending raw_req_cnf without base64 encoding would result in an http 500 error
801804
} # See also https://github.com/Azure/azure-cli-extensions/blob/main/src/connectedk8s/azext_connectedk8s/_clientproxyutils.py#L86-L92
802805

@@ -817,6 +820,23 @@ def test_user_account(self):
817820
self._test_user_account()
818821

819822

823+
class CdtTestCase(LabBasedTestCase, CdtTestCase):
824+
def test_acquire_token_for_client_should_return_a_cdt(self):
825+
resource = self.get_lab_app_object( # This resource has opted in to CDT
826+
publicClient="no", signinAudience="AzureAdMyOrg")
827+
client_app = msal.ConfidentialClientApplication(
828+
# Any CCA can use a CDT, as long as the resource opted in for a CDT
829+
# Here we use the OBO app which is in same tenant as the resource.
830+
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
831+
client_credential=os.getenv("LAB_OBO_CLIENT_SECRET"),
832+
authority="{}{}.onmicrosoft.com".format(
833+
resource["authority"],
834+
resource["labName"].lower().rstrip(".com"),
835+
),
836+
)
837+
self.assertAppObtainsCdt(client_app, [f"{resource['appId']}/.default"])
838+
839+
820840
class WorldWideTestCase(LabBasedTestCase):
821841

822842
def test_aad_managed_user(self): # Pure cloud

0 commit comments

Comments
 (0)