Skip to content

Implementing CCS Routing info #395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import requests

from .oauth2cli import Client, JwtAssertionCreator
from .oauth2cli.oidc import decode_part
from .authority import Authority
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
Expand Down Expand Up @@ -111,6 +112,34 @@ def _preferred_browser():
return None


class _ClientWithCcsRoutingInfo(Client):

def initiate_auth_code_flow(self, **kwargs):
return super(_ClientWithCcsRoutingInfo, self).initiate_auth_code_flow(
client_info=1, # To be used as CSS Routing info
**kwargs)

def obtain_token_by_auth_code_flow(
self, auth_code_flow, auth_response, **kwargs):
# Note: the obtain_token_by_browser() is also covered by this
assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict)
headers = kwargs.pop("headers", {})
client_info = json.loads(
decode_part(auth_response["client_info"])
) if auth_response.get("client_info") else {}
if "uid" in client_info and "utid" in client_info:
# Note: The value of X-AnchorMailbox is also case-insensitive
headers["X-AnchorMailbox"] = "Oid:{uid}@{utid}".format(**client_info)
return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_auth_code_flow(
auth_code_flow, auth_response, headers=headers, **kwargs)

def obtain_token_by_username_password(self, username, password, **kwargs):
headers = kwargs.pop("headers", {})
headers["X-AnchorMailbox"] = "upn:{}".format(username)
return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_username_password(
username, password, headers=headers, **kwargs)


class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
Expand Down Expand Up @@ -481,7 +510,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
authority.device_authorization_endpoint or
urljoin(authority.token_endpoint, "devicecode"),
}
central_client = Client(
central_client = _ClientWithCcsRoutingInfo(
central_configuration,
self.client_id,
http_client=self.http_client,
Expand All @@ -506,7 +535,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
regional_authority.device_authorization_endpoint or
urljoin(regional_authority.token_endpoint, "devicecode"),
}
regional_client = Client(
regional_client = _ClientWithCcsRoutingInfo(
regional_configuration,
self.client_id,
http_client=self.http_client,
Expand Down Expand Up @@ -577,7 +606,7 @@ def initiate_auth_code_flow(
3. and then relay this dict and subsequent auth response to
:func:`~acquire_token_by_auth_code_flow()`.
"""
client = Client(
client = _ClientWithCcsRoutingInfo(
{"authorization_endpoint": self.authority.authorization_endpoint},
self.client_id,
http_client=self.http_client)
Expand Down Expand Up @@ -654,7 +683,7 @@ def get_authorization_request_url(
self.http_client
) if authority else self.authority

client = Client(
client = _ClientWithCcsRoutingInfo(
{"authorization_endpoint": the_authority.authorization_endpoint},
self.client_id,
http_client=self.http_client)
Expand Down Expand Up @@ -1178,6 +1207,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
key=lambda e: int(e.get("last_modification_time", "0")),
reverse=True):
logger.debug("Cache attempts an RT")
headers = telemetry_context.generate_headers()
if "home_account_id" in query: # Then use it as CCS Routing info
headers["X-AnchorMailbox"] = "Oid:{}".format( # case-insensitive value
query["home_account_id"].replace(".", "@"))
response = client.obtain_token_by_refresh_token(
entry, rt_getter=lambda token_item: token_item["secret"],
on_removing_rt=lambda rt_item: None, # Disable RT removal,
Expand All @@ -1189,7 +1222,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
skip_account_creation=True, # To honor a concurrent remove_account()
)),
scope=scopes,
headers=telemetry_context.generate_headers(),
headers=headers,
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
Expand Down
73 changes: 73 additions & 0 deletions tests/test_ccs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import unittest
try:
from unittest.mock import patch, ANY
except:
from mock import patch, ANY

from tests.http_client import MinimalResponse
from tests.test_token_cache import build_response

import msal


class TestCcsRoutingInfoTestCase(unittest.TestCase):

def test_acquire_token_by_auth_code_flow(self):
app = msal.ClientApplication("client_id")
state = "foo"
flow = app.initiate_auth_code_flow(
["some", "scope"], login_hint="[email protected]", state=state)
with patch.object(app.http_client, "post", return_value=MinimalResponse(
status_code=400, text='{"error": "mock"}')) as mocked_method:
app.acquire_token_by_auth_code_flow(flow, {
"state": state,
"code": "bar",
"client_info": # MSAL asks for client_info, so it would be available
"eyJ1aWQiOiJhYTkwNTk0OS1hMmI4LTRlMGEtOGFlYS1iMzJlNTNjY2RiNDEiLCJ1dGlkIjoiNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3In0",
})
self.assertEqual(
"Oid:aa905949-a2b8-4e0a-8aea-b32e53ccdb41@72f988bf-86f1-41af-91ab-2d7cd011db47",
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
"CSS routing info should be derived from client_info")

# I've manually tested acquire_token_interactive. No need to automate it,
# because it and acquire_token_by_auth_code_flow() share same code path.

def test_acquire_token_silent(self):
uid = "foo"
utid = "bar"
client_id = "my_client_id"
scopes = ["some", "scope"]
authority_url = "https://login.microsoftonline.com/common"
token_cache = msal.TokenCache()
token_cache.add({ # Pre-populate the cache
"client_id": client_id,
"scope": scopes,
"token_endpoint": "{}/oauth2/v2.0/token".format(authority_url),
"response": build_response(
access_token="an expired AT to trigger refresh", expires_in=-99,
uid=uid, utid=utid, refresh_token="this is a RT"),
}) # The add(...) helper populates correct home_account_id for future searching
app = msal.ClientApplication(
client_id, authority=authority_url, token_cache=token_cache)
with patch.object(app.http_client, "post", return_value=MinimalResponse(
status_code=400, text='{"error": "mock"}')) as mocked_method:
account = {"home_account_id": "{}.{}".format(uid, utid)}
app.acquire_token_silent(["scope"], account)
self.assertEqual(
"Oid:{}@{}".format( # Server accepts case-insensitive value
uid, utid), # It would look like "Oid:foo@bar"
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
"CSS routing info should be derived from home_account_id")

def test_acquire_token_by_username_password(self):
app = msal.ClientApplication("client_id")
username = "[email protected]"
with patch.object(app.http_client, "post", return_value=MinimalResponse(
status_code=400, text='{"error": "mock"}')) as mocked_method:
app.acquire_token_by_username_password(username, "password", ["scope"])
self.assertEqual(
"upn:" + username,
mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'),
"CSS routing info should be derived from client_info")

2 changes: 1 addition & 1 deletion tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,8 @@ def _test_acquire_token_by_auth_code_flow(
client_id, authority=authority, http_client=MinimalHttpClient())
with AuthCodeReceiver(port=port) as receiver:
flow = self.app.initiate_auth_code_flow(
scope,
redirect_uri="http://localhost:%d" % receiver.get_port(),
scopes=scope,
)
auth_response = receiver.get_auth_response(
auth_uri=flow["auth_uri"], state=flow["state"], timeout=60,
Expand Down