Skip to content

Commit a3f3c91

Browse files
committed
Implementing CCS Routing info
1 parent e969e64 commit a3f3c91

File tree

3 files changed

+72
-4
lines changed

3 files changed

+72
-4
lines changed

msal/application.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ def initiate_auth_code_flow(
590590
self._client_capabilities, claims_challenge),
591591
)
592592
flow["claims_challenge"] = claims_challenge
593+
if login_hint:
594+
flow["login_hint"] = login_hint # To be relayed to token endpoint
593595
return flow
594596

595597
def get_authorization_request_url(
@@ -726,11 +728,15 @@ def authorize(): # A controller in a web app
726728
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
727729
telemetry_context = self._build_telemetry_context(
728730
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID)
731+
headers = telemetry_context.generate_headers()
732+
if "login_hint" in auth_code_flow: # Then use it as the CCS Routing info
733+
headers["X-AnchorMailbox"] = "UPN:{}".format(
734+
auth_code_flow.pop("login_hint"))
729735
response =_clean_up(self.client.obtain_token_by_auth_code_flow(
730736
auth_code_flow,
731737
auth_response,
732738
scope=self._decorate_scope(scopes) if scopes else None,
733-
headers=telemetry_context.generate_headers(),
739+
headers=headers,
734740
data=dict(
735741
kwargs.pop("data", {}),
736742
claims=_merge_claims_challenge_and_capabilities(
@@ -1178,6 +1184,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
11781184
key=lambda e: int(e.get("last_modification_time", "0")),
11791185
reverse=True):
11801186
logger.debug("Cache attempts an RT")
1187+
headers = telemetry_context.generate_headers()
1188+
if "home_account_id" in query: # Then use it as CCS Routing info
1189+
headers["X-AnchorMailbox"] = "Oid:{}".format(
1190+
query["home_account_id"].replace(".", "@"))
11811191
response = client.obtain_token_by_refresh_token(
11821192
entry, rt_getter=lambda token_item: token_item["secret"],
11831193
on_removing_rt=lambda rt_item: None, # Disable RT removal,
@@ -1189,7 +1199,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
11891199
skip_account_creation=True, # To honor a concurrent remove_account()
11901200
)),
11911201
scope=scopes,
1192-
headers=telemetry_context.generate_headers(),
1202+
headers=headers,
11931203
data=dict(
11941204
kwargs.pop("data", {}),
11951205
claims=_merge_claims_challenge_and_capabilities(
@@ -1284,6 +1294,8 @@ def acquire_token_by_username_password(
12841294
telemetry_context = self._build_telemetry_context(
12851295
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)
12861296
headers = telemetry_context.generate_headers()
1297+
# No need to add CCS Routing info,
1298+
# because username param will be recognized as CCS Routing info.
12871299
data = dict(
12881300
kwargs.pop("data", {}),
12891301
claims=_merge_claims_challenge_and_capabilities(
@@ -1425,6 +1437,9 @@ def acquire_token_interactive(
14251437
self._client_capabilities, claims_challenge)
14261438
telemetry_context = self._build_telemetry_context(
14271439
self.ACQUIRE_TOKEN_INTERACTIVE)
1440+
headers = telemetry_context.generate_headers()
1441+
if login_hint: # Then use it as the CCS Routing info
1442+
headers["X-AnchorMailbox"] = "UPN:{}".format(login_hint)
14281443
response = _clean_up(self.client.obtain_token_by_browser(
14291444
scope=self._decorate_scope(scopes) if scopes else None,
14301445
extra_scope_to_consent=extra_scopes_to_consent,
@@ -1439,7 +1454,7 @@ def acquire_token_interactive(
14391454
"domain_hint": domain_hint,
14401455
},
14411456
data=dict(kwargs.pop("data", {}), claims=claims),
1442-
headers=telemetry_context.generate_headers(),
1457+
headers=headers,
14431458
browser_name=_preferred_browser(),
14441459
**kwargs))
14451460
telemetry_context.update_telemetry(response)

tests/test_ccs.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import unittest
2+
try:
3+
from unittest.mock import patch, ANY
4+
except:
5+
from mock import patch, ANY
6+
7+
from tests.http_client import MinimalResponse
8+
from tests.test_token_cache import build_response
9+
10+
import msal
11+
12+
13+
class TestCcsRoutingInfoTestCase(unittest.TestCase):
14+
15+
def test_acquire_token_by_auth_code_flow(self):
16+
app = msal.ClientApplication("client_id")
17+
state = "foo"
18+
flow = app.initiate_auth_code_flow(
19+
["some", "scope"], login_hint="[email protected]", state=state)
20+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
21+
status_code=400, text='{"error": "mock"}')) as mocked_method:
22+
app.acquire_token_by_auth_code_flow(flow, {"state": state, "code": "bar"})
23+
self.assertEqual(
24+
25+
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
26+
"CSS routing info should be derived from login_hint")
27+
28+
def test_acquire_token_silent(self):
29+
uid = "foo"
30+
utid = "bar"
31+
client_id = "my_client_id"
32+
scopes = ["some", "scope"]
33+
authority_url = "https://login.microsoftonline.com/common"
34+
token_cache = msal.TokenCache()
35+
token_cache.add({ # Pre-populate the cache
36+
"client_id": client_id,
37+
"scope": scopes,
38+
"token_endpoint": "{}/oauth2/v2.0/token".format(authority_url),
39+
"response": build_response(
40+
access_token="an expired AT to trigger refresh", expires_in=-99,
41+
uid=uid, utid=utid, refresh_token="this is a RT"),
42+
}) # The add(...) helper populates correct home_account_id for future searching
43+
app = msal.ClientApplication(
44+
client_id, authority=authority_url, token_cache=token_cache)
45+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
46+
status_code=400, text='{"error": "mock"}')) as mocked_method:
47+
account = {"home_account_id": "{}.{}".format(uid, utid)}
48+
app.acquire_token_silent(["scope"], account)
49+
self.assertEqual(
50+
"Oid:{}@{}".format(uid, utid), # It would look like "Oid:foo@bar"
51+
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
52+
"CSS routing info should be derived from home_account_id")
53+

tests/test_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,8 +516,8 @@ def _test_acquire_token_by_auth_code_flow(
516516
client_id, authority=authority, http_client=MinimalHttpClient())
517517
with AuthCodeReceiver(port=port) as receiver:
518518
flow = self.app.initiate_auth_code_flow(
519+
scope,
519520
redirect_uri="http://localhost:%d" % receiver.get_port(),
520-
scopes=scope,
521521
)
522522
auth_response = receiver.get_auth_response(
523523
auth_uri=flow["auth_uri"], state=flow["state"], timeout=60,

0 commit comments

Comments
 (0)