Skip to content

Commit cb88462

Browse files
lochiiconnectivityDavid Freedmanrayluo
authored
Add support for acquiring a token with a pre-signed JWT (#271)
* Add support for acquiring a token with a client provided, pre-signed JWT. Useful for where the signing takes place externally for example using Azure Key Vault (AKV). AKV sample included. * Changes to parameter name for #271 * Address comment in #271 "No need to repeat this statement twice in both if and else" * merge rayluo / microsoft-authentication-library-for-python:patch1 * Update msal/application.py Co-authored-by: Ray Luo <[email protected]> * Update tests/test_e2e.py Co-authored-by: Ray Luo <[email protected]> * Resolve merge conflict Co-authored-by: David Freedman <[email protected]> Co-authored-by: Ray Luo <[email protected]>
1 parent 088aa54 commit cb88462

File tree

4 files changed

+186
-23
lines changed

4 files changed

+186
-23
lines changed

msal/application.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ def __init__(
131131
"The provided signature value did not match the expected signature value",
132132
you may try use only the leaf cert (in PEM/str format) instead.
133133
134+
*Added in version 1.13.0*:
135+
It can also be a completly pre-signed assertion that you've assembled yourself.
136+
Simply pass a container containing only the key "client_assertion", like this::
137+
138+
{
139+
"client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
140+
}
141+
134142
:param dict client_claims:
135143
*Added in version 0.5.0*:
136144
It is a dictionary of extra claims that would be signed by
@@ -391,28 +399,32 @@ def _build_client(self, client_credential, authority):
391399
default_headers['x-app-ver'] = self.app_version
392400
default_body = {"client_info": 1}
393401
if isinstance(client_credential, dict):
394-
assert ("private_key" in client_credential
395-
and "thumbprint" in client_credential)
396-
headers = {}
397-
if 'public_certificate' in client_credential:
398-
headers["x5c"] = extract_certs(client_credential['public_certificate'])
399-
if not client_credential.get("passphrase"):
400-
unencrypted_private_key = client_credential['private_key']
401-
else:
402-
from cryptography.hazmat.primitives import serialization
403-
from cryptography.hazmat.backends import default_backend
404-
unencrypted_private_key = serialization.load_pem_private_key(
405-
_str2bytes(client_credential["private_key"]),
406-
_str2bytes(client_credential["passphrase"]),
407-
backend=default_backend(), # It was a required param until 2020
408-
)
409-
assertion = JwtAssertionCreator(
410-
unencrypted_private_key, algorithm="RS256",
411-
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
412-
client_assertion = assertion.create_regenerative_assertion(
413-
audience=authority.token_endpoint, issuer=self.client_id,
414-
additional_claims=self.client_claims or {})
402+
assert (("private_key" in client_credential
403+
and "thumbprint" in client_credential) or
404+
"client_assertion" in client_credential)
415405
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
406+
if 'client_assertion' in client_credential:
407+
client_assertion = client_credential['client_assertion']
408+
else:
409+
headers = {}
410+
if 'public_certificate' in client_credential:
411+
headers["x5c"] = extract_certs(client_credential['public_certificate'])
412+
if not client_credential.get("passphrase"):
413+
unencrypted_private_key = client_credential['private_key']
414+
else:
415+
from cryptography.hazmat.primitives import serialization
416+
from cryptography.hazmat.backends import default_backend
417+
unencrypted_private_key = serialization.load_pem_private_key(
418+
_str2bytes(client_credential["private_key"]),
419+
_str2bytes(client_credential["passphrase"]),
420+
backend=default_backend(), # It was a required param until 2020
421+
)
422+
assertion = JwtAssertionCreator(
423+
unencrypted_private_key, algorithm="RS256",
424+
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
425+
client_assertion = assertion.create_regenerative_assertion(
426+
audience=authority.token_endpoint, issuer=self.client_id,
427+
additional_claims=self.client_claims or {})
416428
else:
417429
default_body['client_secret'] = client_credential
418430
central_configuration = {

sample/vault_jwt_sample.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""
2+
The configuration file would look like this (sans those // comments):
3+
{
4+
"tenant": "your_tenant_name",
5+
// Your target tenant, DNS name
6+
"client_id": "your_client_id",
7+
// Target app ID in Azure AD
8+
"scope": ["https://graph.microsoft.com/.default"],
9+
// Specific to Client Credentials Grant i.e. acquire_token_for_client(),
10+
// you don't specify, in the code, the individual scopes you want to access.
11+
// Instead, you statically declared them when registering your application.
12+
// Therefore the only possible scope is "resource/.default"
13+
// (here "https://graph.microsoft.com/.default")
14+
// which means "the static permissions defined in the application".
15+
"vault_tenant": "your_vault_tenant_name",
16+
// Your Vault tenant may be different to your target tenant
17+
// If that's not the case, you can set this to the same
18+
// as "tenant"
19+
"vault_clientid": "your_vault_client_id",
20+
// Client ID of your vault app in your vault tenant
21+
"vault_clientsecret": "your_vault_client_secret",
22+
// Secret for your vault app
23+
"vault_url": "your_vault_url",
24+
// URL of your vault app
25+
"cert": "your_cert_name",
26+
// Name of your certificate in your vault
27+
"cert_thumb": "your_cert_thumbprint",
28+
// Thumbprint of your certificate
29+
"endpoint": "https://graph.microsoft.com/v1.0/users"
30+
// For this resource to work, you need to visit Application Permissions
31+
// page in portal, declare scope User.Read.All, which needs admin consent
32+
// https://github.com/Azure-Samples/ms-identity-python-daemon/blob/master/2-Call-MsGraph-WithCertificate/README.md
33+
}
34+
You can then run this sample with a JSON configuration file:
35+
python sample.py parameters.json
36+
"""
37+
38+
import base64
39+
import json
40+
import logging
41+
import requests
42+
import sys
43+
import time
44+
import uuid
45+
import msal
46+
47+
# Optional logging
48+
# logging.basicConfig(level=logging.DEBUG) # Enable DEBUG log for entire script
49+
# logging.getLogger("msal").setLevel(logging.INFO) # Optionally disable MSAL DEBUG logs
50+
51+
from azure.keyvault import KeyVaultClient, KeyVaultAuthentication
52+
from azure.common.credentials import ServicePrincipalCredentials
53+
from cryptography.hazmat.backends import default_backend
54+
from cryptography.hazmat.primitives import hashes
55+
56+
config = json.load(open(sys.argv[1]))
57+
58+
def auth_vault_callback(server, resource, scope):
59+
credentials = ServicePrincipalCredentials(
60+
client_id=config['vault_clientid'],
61+
secret=config['vault_clientsecret'],
62+
tenant=config['vault_tenant'],
63+
resource='https://vault.azure.net'
64+
)
65+
token = credentials.token
66+
return token['token_type'], token['access_token']
67+
68+
69+
def make_vault_jwt():
70+
71+
header = {
72+
'alg': 'RS256',
73+
'typ': 'JWT',
74+
'x5t': base64.b64encode(
75+
config['cert_thumb'].decode('hex'))
76+
}
77+
header_b64 = base64.b64encode(json.dumps(header).encode('utf-8'))
78+
79+
body = {
80+
'aud': "https://login.microsoftonline.com/%s/oauth2/token" %
81+
config['tenant'],
82+
'exp': (int(time.time()) + 600),
83+
'iss': config['client_id'],
84+
'jti': str(uuid.uuid4()),
85+
'nbf': int(time.time()),
86+
'sub': config['client_id']
87+
}
88+
body_b64 = base64.b64encode(json.dumps(body).encode('utf-8'))
89+
90+
full_b64 = b'.'.join([header_b64, body_b64])
91+
92+
client = KeyVaultClient(KeyVaultAuthentication(auth_vault_callback))
93+
chosen_hash = hashes.SHA256()
94+
hasher = hashes.Hash(chosen_hash, default_backend())
95+
hasher.update(full_b64)
96+
digest = hasher.finalize()
97+
signed_digest = client.sign(config['vault_url'],
98+
config['cert'], '', 'RS256',
99+
digest).result
100+
101+
full_token = b'.'.join([full_b64, base64.b64encode(signed_digest)])
102+
103+
return full_token
104+
105+
106+
authority = "https://login.microsoftonline.com/%s" % config['tenant']
107+
108+
app = msal.ConfidentialClientApplication(
109+
config['client_id'], authority=authority, client_credential={"client_assertion": make_vault_jwt()}
110+
)
111+
112+
# The pattern to acquire a token looks like this.
113+
result = None
114+
115+
# Firstly, looks up a token from cache
116+
# Since we are looking for token for the current app, NOT for an end user,
117+
# notice we give account parameter as None.
118+
result = app.acquire_token_silent(config["scope"], account=None)
119+
120+
if not result:
121+
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
122+
result = app.acquire_token_for_client(scopes=config["scope"])
123+
124+
if "access_token" in result:
125+
# Calling graph using the access token
126+
graph_data = requests.get( # Use token to call downstream service
127+
config["endpoint"],
128+
headers={'Authorization': 'Bearer ' + result['access_token']},).json()
129+
print("Graph API call result: %s" % json.dumps(graph_data, indent=2))
130+
else:
131+
print(result.get("error"))
132+
print(result.get("error_description"))
133+
print(result.get("correlation_id")) # You may need this when reporting a bug
134+

tests/test_client.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,15 @@ class TestClient(Oauth2TestCase):
8585
@classmethod
8686
def setUpClass(cls):
8787
http_client = MinimalHttpClient()
88-
if "client_certificate" in CONFIG:
88+
if "client_assertion" in CONFIG:
89+
cls.client = Client(
90+
CONFIG["openid_configuration"],
91+
CONFIG['client_id'],
92+
http_client=http_client,
93+
client_assertion=CONFIG["client_assertion"],
94+
client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT,
95+
)
96+
elif "client_certificate" in CONFIG:
8997
private_key_path = CONFIG["client_certificate"]["private_key_path"]
9098
with open(os.path.join(THIS_FOLDER, private_key_path)) as f:
9199
private_key = f.read() # Expecting PEM format

tests/test_e2e.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,16 @@ def test_subject_name_issuer_authentication(self):
344344
self.assertIn('access_token', result)
345345
self.assertCacheWorksForApp(result, scope)
346346

347+
def test_client_assertion(self):
348+
self.skipUnlessWithConfig(["client_id", "client_assertion"])
349+
self.app = msal.ConfidentialClientApplication(
350+
self.config['client_id'], authority=self.config["authority"],
351+
client_credential={"client_assertion": self.config["client_assertion"]},
352+
http_client=MinimalHttpClient())
353+
scope = self.config.get("scope", [])
354+
result = self.app.acquire_token_for_client(scope)
355+
self.assertIn('access_token', result)
356+
self.assertCacheWorksForApp(result, scope)
347357

348358
@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
349359
class DeviceFlowTestCase(E2eTestCase): # A leaf class so it will be run only once
@@ -882,4 +892,3 @@ def test_acquire_token_silent_with_an_empty_cache_should_return_none(self):
882892

883893
if __name__ == "__main__":
884894
unittest.main()
885-

0 commit comments

Comments
 (0)