Skip to content

Commit a185274

Browse files
committed
ClientCredentialRequest logic are now in its own class
1 parent fc9cd84 commit a185274

File tree

3 files changed

+72
-56
lines changed

3 files changed

+72
-56
lines changed

msal/application.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import request
1+
from .client_credential import ClientCredentialRequest
22

33

44
class ClientApplication(object):
@@ -35,7 +35,7 @@ def __init__(self, client_id, client_credential, user_token_cache, **kwargs):
3535
self.app_token_cache = None # TODO
3636

3737
def acquire_token_for_client(self, scope, policy=''):
38-
return request.ClientCredentialRequest(
38+
return ClientCredentialRequest(
3939
client_id=self.client_id, client_credential=self.client_credential,
4040
scope=scope, policy=policy, authority=self.authority).run()
4141

msal/client_credential.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import time
2+
import binascii
3+
import base64
4+
import uuid
5+
6+
import jwt
7+
8+
from .oauth2 import ClientCredentialGrant
9+
from .request import BaseRequest
10+
11+
12+
class ClientCredentialRequest(BaseRequest):
13+
def __init__(self, **kwargs):
14+
super(ClientCredentialRequest, self).__init__(**kwargs)
15+
self.grant = ClientCredentialGrant(
16+
self.client_id, token_endpoint=self.token_endpoint)
17+
18+
def get_token(self):
19+
if isinstance(self.client_credential, dict):
20+
return self.get_token_by_certificate(
21+
self.client_credential['certificate'],
22+
self.client_credential['thumbprint'])
23+
else:
24+
return self.get_token_by_secret(self.client_credential)
25+
26+
def get_token_by_secret(self, secret):
27+
return self.grant.get_token(scope=self.scope, client_secret=secret)
28+
29+
def get_token_by_certificate(self, pem, thumbprint):
30+
JWT_BEARER = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer'
31+
assertion = create_jwt_assertion(
32+
pem, thumbprint, self.grant.token_endpoint, self.client_id)
33+
return self.grant.get_token(
34+
client_assertion_type=JWT_BEARER, client_assertion=assertion,
35+
scope=self.scope)
36+
37+
38+
def create_jwt_assertion(
39+
private_pem, thumbprint, audience, issuer,
40+
subject=None, # If None is specified, the value of issuer will be used
41+
not_valid_before=None, # If None, the current time will be used
42+
jwt_id=None): # If None is specified, a UUID will be generated
43+
assert '-----BEGIN PRIVATE KEY-----' in private_pem, "Need a standard PEM"
44+
nbf = time.time() if not_valid_before is None else not_valid_before
45+
payload = { # key names are all from JWT standard names
46+
'aud': audience,
47+
'iss': issuer,
48+
'sub': subject or issuer,
49+
'nbf': nbf,
50+
'exp': nbf + 10*60, # 10 minutes
51+
'jti': str(uuid.uuid4()) if jwt_id is None else jwt_id,
52+
}
53+
# Per http://self-issued.info/docs/draft-jones-json-web-token-01.html
54+
h = {'x5t': base64.urlsafe_b64encode(binascii.a2b_hex(thumbprint)).decode()}
55+
return jwt.encode(payload, private_pem, algorithm='RS256', headers=h)
56+

msal/request.py

Lines changed: 14 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
import time
22

3-
from . import oauth2
43
from .exceptions import MsalServiceError
54

65

76
class BaseRequest(object):
8-
TOKEN_ENDPOINT_PATH = 'oauth2/v2.0/token'
97

108
def __init__(
11-
self, authority=None, token_cache=None, scope=None, policy="",
9+
self, authority=None, token_cache=None,
10+
scope=None, policy="", # TBD: If scope and policy are paramters
11+
# of both high level ClientApplication.acquire_token()
12+
# and low level oauth2.*Grant.get_token(),
13+
# shouldn't they be the parameters for run()?
1214
client_id=None, client_credential=None, authenticator=None,
1315
support_adfs=False, restrict_to_single_user=False):
1416
if not scope:
1517
raise ValueError("scope cannot be empty")
1618
self.__dict__.update(locals())
1719

20+
# TODO: Temporary solution here
21+
self.token_endpoint = authority
22+
if authority.startswith('https://login.microsoftonline.com/common/'):
23+
self.token_endpoint += 'oauth2/v2.0/token'
24+
elif authority.startswith('https://login.windows.net/'): # AAD?
25+
self.token_endpoint += 'oauth2/token'
26+
if policy:
27+
self.token_endpoint += '?policy={}'.format(policy)
28+
1829
def run(self):
1930
"""Returns a dictionary, which typically contains following keys:
2031
@@ -55,54 +66,3 @@ def __timestamp(self, seconds_from_now=None): # Returns timestamp IN SECOND
5566
def get_token(self):
5667
raise NotImplemented("Use proper sub-class instead")
5768

58-
59-
class ClientCredentialRequest(BaseRequest):
60-
def get_token(self):
61-
token_endpoint="%s%s?policy=%s" % (
62-
self.authority, self.TOKEN_ENDPOINT_PATH, self.policy)
63-
if isinstance(self.client_credential, dict): # certification logic
64-
return ClientCredentialCertificateGrant(
65-
self.client_id, token_endpoint=token_endpoint
66-
).get_token(
67-
self.client_credential['certificate'],
68-
self.client_credential['thumbprint'],
69-
scope=self.scope)
70-
else:
71-
return oauth2.ClientCredentialGrant(
72-
self.client_id, token_endpoint=token_endpoint
73-
).get_token(
74-
scope=self.scope, client_secret=self.client_credential)
75-
76-
77-
import binascii
78-
import base64
79-
import uuid
80-
81-
import jwt
82-
83-
84-
def create(private_pem, thumbprint, audience, issuer, subject=None):
85-
assert private_pem.startswith('-----BEGIN PRIVATE KEY-----'), "Wrong format"
86-
payload = { # key names are all from JWT standard names
87-
'aud': audience,
88-
'iss': issuer,
89-
'sub': subject or issuer,
90-
'nbf': time.time(),
91-
'exp': time.time() + 10*60, # 10 minutes
92-
'jti': str(uuid.uuid4()),
93-
}
94-
# http://self-issued.info/docs/draft-jones-json-web-token-01.html
95-
h = {'x5t': base64.urlsafe_b64encode(binascii.a2b_hex(thumbprint)).decode()}
96-
return jwt.encode(payload, private_pem, algorithm='RS256', headers=h) # .decode() # TODO: Is the decode() really necessary?
97-
98-
99-
class ClientCredentialCertificateGrant(oauth2.ClientCredentialGrant):
100-
def get_token(self, pem, thumbprint, scope=None):
101-
JWT_BEARER = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer'
102-
assertion = create(pem, thumbprint, self.token_endpoint, self.client_id)
103-
import logging
104-
logging.warning('assertion: %s', assertion)
105-
return super(ClientCredentialCertificateGrant, self).get_token(
106-
client_assertion_type=JWT_BEARER, client_assertion=assertion,
107-
scope=scope)
108-

0 commit comments

Comments
 (0)