Skip to content

Commit 6973a99

Browse files
committed
Only invoke broker for selected flows (grants)
ROPC also bypass broker, for now Unit tests
1 parent 88c4bf8 commit 6973a99

File tree

5 files changed

+154
-16
lines changed

5 files changed

+154
-16
lines changed

msal/__main__.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,20 @@
55
66
Usage 1: Run it on the fly.
77
python -m msal
8+
Note: We choose to not define a console script to avoid name conflict.
89
910
Usage 2: Build an all-in-one executable file for bug bash.
1011
shiv -e msal.__main__._main -o msaltest-on-os-name.pyz .
11-
Note: We choose to not define a console script to avoid name conflict.
1212
"""
13-
import base64, getpass, json, logging, sys, msal
13+
import base64, getpass, json, logging, sys, os, atexit, msal
14+
15+
_token_cache_filename = "msal_cache.bin"
16+
global_cache = msal.SerializableTokenCache()
17+
atexit.register(lambda:
18+
open(_token_cache_filename, "w").write(global_cache.serialize())
19+
# Hint: The following optional line persists only when state changed
20+
if global_cache.has_state_changed else None
21+
)
1422

1523
_AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
1624
_VISUAL_STUDIO = "04f0c124-f2bc-4f59-8241-bf6df9866bbd"
@@ -66,7 +74,7 @@ def _select_account(app):
6674
if accounts:
6775
return _select_options(
6876
accounts,
69-
option_renderer=lambda a: a["username"],
77+
option_renderer=lambda a: "{}, came from {}".format(a["username"], a["account_source"]),
7078
header="Account(s) already signed in inside MSAL Python:",
7179
)
7280
else:
@@ -76,7 +84,7 @@ def _acquire_token_silent(app):
7684
"""acquire_token_silent() - with an account already signed into MSAL Python."""
7785
account = _select_account(app)
7886
if account:
79-
print_json(app.acquire_token_silent(
87+
print_json(app.acquire_token_silent_with_error(
8088
_input_scopes(),
8189
account=account,
8290
force_refresh=_input_boolean("Bypass MSAL Python's token cache?"),
@@ -122,6 +130,15 @@ def _acquire_token_by_username_password(app):
122130
print_json(app.acquire_token_by_username_password(
123131
_input("username: "), getpass.getpass("password: "), scopes=_input_scopes()))
124132

133+
def _acquire_token_by_device_flow(app):
134+
"""acquire_token_by_device_flow() - Note that this one does not go through broker"""
135+
flow = app.initiate_device_flow(scopes=_input_scopes())
136+
print(flow["message"])
137+
sys.stdout.flush() # Some terminal needs this to ensure the message is shown
138+
input("After you completed the step above, press ENTER in this console to continue...")
139+
result = app.acquire_token_by_device_flow(flow) # By default it will block
140+
print_json(result)
141+
125142
_JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
126143
_SSH_CERT_DATA = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1}
127144
_SSH_CERT_SCOPE = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]
@@ -182,6 +199,27 @@ def _exit(app):
182199

183200
def _main():
184201
print("Welcome to the Msal Python {} Tester (Experimental)\n".format(msal.__version__))
202+
cache_choice = _select_options([
203+
{
204+
"choice": "empty",
205+
"desc": "Start with an empty token cache. Suitable for one-off tests.",
206+
},
207+
{
208+
"choice": "reuse",
209+
"desc": "Reuse the previous token cache {} (if any) "
210+
"which was created during last test app exit. "
211+
"Useful for testing acquire_token_silent() repeatedly".format(
212+
_token_cache_filename),
213+
},
214+
],
215+
option_renderer=lambda o: o["desc"],
216+
header="What token cache state do you want to begin with?",
217+
accept_nonempty_string=False)
218+
if cache_choice["choice"] == "reuse" and os.path.exists(_token_cache_filename):
219+
try:
220+
global_cache.deserialize(open(_token_cache_filename, "r").read())
221+
except IOError:
222+
pass # Use empty token cache
185223
chosen_app = _select_options([
186224
{"client_id": _AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
187225
{"client_id": _VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
@@ -207,6 +245,7 @@ def _main():
207245
),
208246
enable_broker_on_windows=enable_broker,
209247
enable_pii_log=enable_pii_log,
248+
token_cache=global_cache,
210249
)
211250
if enable_debug_log:
212251
logging.basicConfig(level=logging.DEBUG)
@@ -215,6 +254,7 @@ def _main():
215254
_acquire_token_silent,
216255
_acquire_token_interactive,
217256
_acquire_token_by_username_password,
257+
_acquire_token_by_device_flow,
218258
_acquire_ssh_cert_silently,
219259
_acquire_ssh_cert_interactive,
220260
_acquire_pop_token_interactive,

msal/application.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .mex import send_request as mex_send_request
1818
from .wstrust_request import send_request as wst_send_request
1919
from .wstrust_response import *
20-
from .token_cache import TokenCache, _get_username
20+
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
2121
import msal.telemetry
2222
from .region import _detect_region
2323
from .throttled_http_client import ThrottledHttpClient
@@ -1104,6 +1104,7 @@ def _find_msal_accounts(self, environment):
11041104
"home_account_id": a.get("home_account_id"),
11051105
"environment": a.get("environment"),
11061106
"username": a.get("username"),
1107+
"account_source": a.get("account_source"),
11071108

11081109
# The following fields for backward compatibility, for now
11091110
"authority_type": a.get("authority_type"),
@@ -1398,7 +1399,10 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
13981399
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
13991400
return self._acquire_token_by_cloud_shell(scopes, data=data)
14001401

1401-
if self._enable_broker and account is not None:
1402+
if self._enable_broker and account and account.get("account_source") in (
1403+
_GRANT_TYPE_BROKER, # Broker successfully established this account previously.
1404+
None, # Unknown data from older MSAL. Broker might still work.
1405+
):
14021406
from .broker import _acquire_token_silently
14031407
response = _acquire_token_silently(
14041408
"https://{}/{}".format(self.authority.instance, self.authority.tenant),
@@ -1409,8 +1413,12 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
14091413
self._client_capabilities, claims_challenge),
14101414
correlation_id=correlation_id,
14111415
**data)
1412-
if response: # The broker provided a decisive outcome, so we use it
1413-
return self._process_broker_response(response, scopes, data)
1416+
if response: # Broker provides a decisive outcome
1417+
account_was_established_by_broker = account.get(
1418+
"account_source") == _GRANT_TYPE_BROKER
1419+
broker_attempt_succeeded_just_now = "error" not in response
1420+
if account_was_established_by_broker or broker_attempt_succeeded_just_now:
1421+
return self._process_broker_response(response, scopes, data)
14141422

14151423
if account:
14161424
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
@@ -1441,6 +1449,8 @@ def _process_broker_response(self, response, scopes, data):
14411449
response=response,
14421450
data=data,
14431451
_account_id=response["_account_id"],
1452+
environment=self.authority.instance, # Be consistent with non-broker flows
1453+
grant_type=_GRANT_TYPE_BROKER, # A pseudo grant type for TokenCache to mark account_source as broker
14441454
))
14451455
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
14461456
return _clean_up(response)
@@ -1628,7 +1638,7 @@ def acquire_token_by_username_password(
16281638
"""
16291639
claims = _merge_claims_challenge_and_capabilities(
16301640
self._client_capabilities, claims_challenge)
1631-
if self._enable_broker:
1641+
if False: # Disabled, for now. It was if self._enable_broker:
16321642
from .broker import _signin_silently
16331643
response = _signin_silently(
16341644
"https://{}/{}".format(self.authority.instance, self.authority.tenant),

msal/broker.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,22 @@ def _convert_error(error, client_id):
7070

7171

7272
def _read_account_by_id(account_id, correlation_id):
73-
"""Return an instance of MSALRuntimeError or MSALRuntimeAccount, or None"""
73+
"""Return an instance of MSALRuntimeAccount, or log error and return None"""
7474
callback_data = _CallbackData()
7575
pymsalruntime.read_account_by_id(
7676
account_id,
7777
correlation_id,
7878
lambda result, callback_data=callback_data: callback_data.complete(result)
7979
)
8080
callback_data.signal.wait()
81-
return (callback_data.result.get_error() or callback_data.result.get_account()
82-
or None) # None happens when the account was not created by broker
81+
error = callback_data.result.get_error()
82+
if error:
83+
logger.debug("read_account_by_id() error: %s", _convert_error(error, None))
84+
return None
85+
account = callback_data.result.get_account()
86+
if account:
87+
return account
88+
return None # None happens when the account was not created by broker
8389

8490

8591
def _convert_result(result, client_id, expected_token_type=None): # Mimic an on-the-wire response from AAD
@@ -196,8 +202,6 @@ def _acquire_token_silently(
196202
# acquireTokenSilently is expected to fail. - Sam Wilson
197203
correlation_id = correlation_id or _get_new_correlation_id()
198204
account = _read_account_by_id(account_id, correlation_id)
199-
if isinstance(account, pymsalruntime.MSALRuntimeError):
200-
return _convert_error(account, client_id)
201205
if account is None:
202206
return
203207
params = pymsalruntime.MSALRuntimeAuthParameters(client_id, authority)
@@ -221,8 +225,6 @@ def _acquire_token_silently(
221225
def _signout_silently(client_id, account_id, correlation_id=None):
222226
correlation_id = correlation_id or _get_new_correlation_id()
223227
account = _read_account_by_id(account_id, correlation_id)
224-
if isinstance(account, pymsalruntime.MSALRuntimeError):
225-
return _convert_error(account, client_id)
226228
if account is None:
227229
return
228230
callback_data = _CallbackData()

msal/token_cache.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
from .authority import canonicalize
77
from .oauth2cli.oidc import decode_part, decode_id_token
8+
from .oauth2cli.oauth2 import Client
89

910

1011
logger = logging.getLogger(__name__)
12+
_GRANT_TYPE_BROKER = "broker"
1113

1214
def is_subdict_of(small, big):
1315
return dict(big, **small) == big
@@ -210,6 +212,11 @@ def __add(self, event, now=None):
210212
else self.AuthorityType.MSSTS),
211213
# "client_info": response.get("client_info"), # Optional
212214
}
215+
grant_types_that_establish_an_account = (
216+
_GRANT_TYPE_BROKER, "authorization_code", "password",
217+
Client.DEVICE_FLOW["GRANT_TYPE"])
218+
if event.get("grant_type") in grant_types_that_establish_an_account:
219+
account["account_source"] = event["grant_type"]
213220
self.modify(self.CredentialType.ACCOUNT, account, account)
214221

215222
if id_token:

tests/test_account_source.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import json
2+
try:
3+
from unittest.mock import patch
4+
except:
5+
from mock import patch
6+
try:
7+
import pymsalruntime
8+
broker_available = True
9+
except ImportError:
10+
broker_available = False
11+
import msal
12+
from tests import unittest
13+
from tests.test_token_cache import build_response
14+
from tests.http_client import MinimalResponse
15+
16+
17+
SCOPE = "scope_foo"
18+
TOKEN_RESPONSE = build_response(
19+
access_token="at",
20+
uid="uid", utid="utid", # So that it will create an account
21+
scope=SCOPE, refresh_token="rt", # So that non-broker's acquire_token_silent() would work
22+
)
23+
24+
def _mock_post(url, headers=None, *args, **kwargs):
25+
return MinimalResponse(status_code=200, text=json.dumps(TOKEN_RESPONSE))
26+
27+
@unittest.skipUnless(broker_available, "These test cases need pip install msal[broker]")
28+
@patch("msal.broker._acquire_token_silently", return_value=dict(
29+
TOKEN_RESPONSE, _account_id="placeholder"))
30+
@patch.object(msal.authority, "tenant_discovery", return_value={
31+
"authorization_endpoint": "https://contoso.com/placeholder",
32+
"token_endpoint": "https://contoso.com/placeholder",
33+
}) # Otherwise it would fail on OIDC discovery
34+
class TestAccountSourceBehavior(unittest.TestCase):
35+
36+
def test_device_flow_and_its_silent_call_should_bypass_broker(self, _, mocked_broker_ats):
37+
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
38+
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=_mock_post)
39+
self.assertEqual(result["token_source"], "identity_provider")
40+
41+
account = app.get_accounts()[0]
42+
self.assertEqual(account["account_source"], "urn:ietf:params:oauth:grant-type:device_code")
43+
44+
result = app.acquire_token_silent_with_error(
45+
[SCOPE], account, force_refresh=True, post=_mock_post)
46+
mocked_broker_ats.assert_not_called()
47+
self.assertEqual(result["token_source"], "identity_provider")
48+
49+
def test_ropc_flow_and_its_silent_call_should_bypass_broker(self, _, mocked_broker_ats):
50+
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
51+
with patch.object(app.authority, "user_realm_discovery", return_value={}):
52+
result = app.acquire_token_by_username_password(
53+
"username", "placeholder", [SCOPE], post=_mock_post)
54+
self.assertEqual(result["token_source"], "identity_provider")
55+
56+
account = app.get_accounts()[0]
57+
self.assertEqual(account["account_source"], "password")
58+
59+
result = app.acquire_token_silent_with_error(
60+
[SCOPE], account, force_refresh=True, post=_mock_post)
61+
mocked_broker_ats.assert_not_called()
62+
self.assertEqual(result["token_source"], "identity_provider")
63+
64+
def test_interactive_flow_and_its_silent_call_should_invoke_broker(self, _, mocked_broker_ats):
65+
app = msal.PublicClientApplication("client_id", enable_broker_on_windows=True)
66+
with patch.object(app, "_acquire_token_interactive_via_broker", return_value=dict(
67+
TOKEN_RESPONSE, _account_id="placeholder")):
68+
result = app.acquire_token_interactive(
69+
[SCOPE], parent_window_handle=app.CONSOLE_WINDOW_HANDLE)
70+
self.assertEqual(result["token_source"], "broker")
71+
72+
account = app.get_accounts()[0]
73+
self.assertEqual(account["account_source"], "broker")
74+
75+
result = app.acquire_token_silent_with_error(
76+
[SCOPE], account, force_refresh=True, post=_mock_post)
77+
mocked_broker_ats.assert_called_once()
78+
self.assertEqual(result["token_source"], "broker")
79+

0 commit comments

Comments
 (0)