Skip to content

Commit b520b3e

Browse files
committed
Implementing CCS Routing info
X-AnchorMailbox's value is case-insensitive Both auth code flow and interactive flow switch to client_info Add upn:username for ROPC per recent discussion
1 parent b446a5e commit b520b3e

File tree

3 files changed

+112
-6
lines changed

3 files changed

+112
-6
lines changed

msal/application.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import requests
1515

1616
from .oauth2cli import Client, JwtAssertionCreator
17+
from .oauth2cli.oidc import decode_part
1718
from .authority import Authority
1819
from .mex import send_request as mex_send_request
1920
from .wstrust_request import send_request as wst_send_request
@@ -111,6 +112,34 @@ def _preferred_browser():
111112
return None
112113

113114

115+
class _ClientWithCcsRoutingInfo(Client):
116+
117+
def initiate_auth_code_flow(self, **kwargs):
118+
return super(_ClientWithCcsRoutingInfo, self).initiate_auth_code_flow(
119+
client_info=1, # To be used as CSS Routing info
120+
**kwargs)
121+
122+
def obtain_token_by_auth_code_flow(
123+
self, auth_code_flow, auth_response, **kwargs):
124+
# Note: the obtain_token_by_browser() is also covered by this
125+
assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict)
126+
headers = kwargs.pop("headers", {})
127+
client_info = json.loads(
128+
decode_part(auth_response["client_info"])
129+
) if auth_response.get("client_info") else {}
130+
if "uid" in client_info and "utid" in client_info:
131+
# Note: The value of X-AnchorMailbox is also case-insensitive
132+
headers["X-AnchorMailbox"] = "Oid:{uid}@{utid}".format(**client_info)
133+
return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_auth_code_flow(
134+
auth_code_flow, auth_response, headers=headers, **kwargs)
135+
136+
def obtain_token_by_username_password(self, username, password, **kwargs):
137+
headers = kwargs.pop("headers", {})
138+
headers["X-AnchorMailbox"] = "upn:{}".format(username)
139+
return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_username_password(
140+
username, password, headers=headers, **kwargs)
141+
142+
114143
class ClientApplication(object):
115144

116145
ACQUIRE_TOKEN_SILENT_ID = "84"
@@ -481,7 +510,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
481510
authority.device_authorization_endpoint or
482511
urljoin(authority.token_endpoint, "devicecode"),
483512
}
484-
central_client = Client(
513+
central_client = _ClientWithCcsRoutingInfo(
485514
central_configuration,
486515
self.client_id,
487516
http_client=self.http_client,
@@ -506,7 +535,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
506535
regional_authority.device_authorization_endpoint or
507536
urljoin(regional_authority.token_endpoint, "devicecode"),
508537
}
509-
regional_client = Client(
538+
regional_client = _ClientWithCcsRoutingInfo(
510539
regional_configuration,
511540
self.client_id,
512541
http_client=self.http_client,
@@ -577,7 +606,7 @@ def initiate_auth_code_flow(
577606
3. and then relay this dict and subsequent auth response to
578607
:func:`~acquire_token_by_auth_code_flow()`.
579608
"""
580-
client = Client(
609+
client = _ClientWithCcsRoutingInfo(
581610
{"authorization_endpoint": self.authority.authorization_endpoint},
582611
self.client_id,
583612
http_client=self.http_client)
@@ -654,7 +683,7 @@ def get_authorization_request_url(
654683
self.http_client
655684
) if authority else self.authority
656685

657-
client = Client(
686+
client = _ClientWithCcsRoutingInfo(
658687
{"authorization_endpoint": the_authority.authorization_endpoint},
659688
self.client_id,
660689
http_client=self.http_client)
@@ -1178,6 +1207,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
11781207
key=lambda e: int(e.get("last_modification_time", "0")),
11791208
reverse=True):
11801209
logger.debug("Cache attempts an RT")
1210+
headers = telemetry_context.generate_headers()
1211+
if "home_account_id" in query: # Then use it as CCS Routing info
1212+
headers["X-AnchorMailbox"] = "Oid:{}".format( # case-insensitive value
1213+
query["home_account_id"].replace(".", "@"))
11811214
response = client.obtain_token_by_refresh_token(
11821215
entry, rt_getter=lambda token_item: token_item["secret"],
11831216
on_removing_rt=lambda rt_item: None, # Disable RT removal,
@@ -1189,7 +1222,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
11891222
skip_account_creation=True, # To honor a concurrent remove_account()
11901223
)),
11911224
scope=scopes,
1192-
headers=telemetry_context.generate_headers(),
1225+
headers=headers,
11931226
data=dict(
11941227
kwargs.pop("data", {}),
11951228
claims=_merge_claims_challenge_and_capabilities(

tests/test_ccs.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import unittest
2+
try:
3+
from unittest.mock import patch, ANY
4+
except:
5+
from mock import patch, ANY
6+
7+
from tests.http_client import MinimalResponse
8+
from tests.test_token_cache import build_response
9+
10+
import msal
11+
12+
13+
class TestCcsRoutingInfoTestCase(unittest.TestCase):
14+
15+
def test_acquire_token_by_auth_code_flow(self):
16+
app = msal.ClientApplication("client_id")
17+
state = "foo"
18+
flow = app.initiate_auth_code_flow(
19+
["some", "scope"], login_hint="[email protected]", state=state)
20+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
21+
status_code=400, text='{"error": "mock"}')) as mocked_method:
22+
app.acquire_token_by_auth_code_flow(flow, {
23+
"state": state,
24+
"code": "bar",
25+
"client_info": # MSAL asks for client_info, so it would be available
26+
"eyJ1aWQiOiJhYTkwNTk0OS1hMmI4LTRlMGEtOGFlYS1iMzJlNTNjY2RiNDEiLCJ1dGlkIjoiNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3In0",
27+
})
28+
self.assertEqual(
29+
"Oid:aa905949-a2b8-4e0a-8aea-b32e53ccdb41@72f988bf-86f1-41af-91ab-2d7cd011db47",
30+
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
31+
"CSS routing info should be derived from client_info")
32+
33+
# I've manually tested acquire_token_interactive. No need to automate it,
34+
# because it and acquire_token_by_auth_code_flow() share same code path.
35+
36+
def test_acquire_token_silent(self):
37+
uid = "foo"
38+
utid = "bar"
39+
client_id = "my_client_id"
40+
scopes = ["some", "scope"]
41+
authority_url = "https://login.microsoftonline.com/common"
42+
token_cache = msal.TokenCache()
43+
token_cache.add({ # Pre-populate the cache
44+
"client_id": client_id,
45+
"scope": scopes,
46+
"token_endpoint": "{}/oauth2/v2.0/token".format(authority_url),
47+
"response": build_response(
48+
access_token="an expired AT to trigger refresh", expires_in=-99,
49+
uid=uid, utid=utid, refresh_token="this is a RT"),
50+
}) # The add(...) helper populates correct home_account_id for future searching
51+
app = msal.ClientApplication(
52+
client_id, authority=authority_url, token_cache=token_cache)
53+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
54+
status_code=400, text='{"error": "mock"}')) as mocked_method:
55+
account = {"home_account_id": "{}.{}".format(uid, utid)}
56+
app.acquire_token_silent(["scope"], account)
57+
self.assertEqual(
58+
"Oid:{}@{}".format( # Server accepts case-insensitive value
59+
uid, utid), # It would look like "Oid:foo@bar"
60+
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
61+
"CSS routing info should be derived from home_account_id")
62+
63+
def test_acquire_token_by_username_password(self):
64+
app = msal.ClientApplication("client_id")
65+
username = "[email protected]"
66+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
67+
status_code=400, text='{"error": "mock"}')) as mocked_method:
68+
app.acquire_token_by_username_password(username, "password", ["scope"])
69+
self.assertEqual(
70+
"upn:" + username,
71+
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
72+
"CSS routing info should be derived from client_info")
73+

tests/test_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,8 +516,8 @@ def _test_acquire_token_by_auth_code_flow(
516516
client_id, authority=authority, http_client=MinimalHttpClient())
517517
with AuthCodeReceiver(port=port) as receiver:
518518
flow = self.app.initiate_auth_code_flow(
519+
scope,
519520
redirect_uri="http://localhost:%d" % receiver.get_port(),
520-
scopes=scope,
521521
)
522522
auth_response = receiver.get_auth_response(
523523
auth_uri=flow["auth_uri"], state=flow["state"], timeout=60,

0 commit comments

Comments
 (0)