Skip to content

Expose token_source for observability #610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ class ClientApplication(object):
REMOVE_ACCOUNT_ID = "903"

ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect"
_TOKEN_SOURCE = "token_source"
_TOKEN_SOURCE_IDP = "identity_provider"
_TOKEN_SOURCE_CACHE = "cache"
_TOKEN_SOURCE_BROKER = "broker"

def __init__(
self, client_id,
Expand Down Expand Up @@ -998,6 +1002,8 @@ def authorize(): # A controller in a web app
self._client_capabilities,
auth_code_flow.pop("claims_challenge", None))),
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1070,6 +1076,8 @@ def acquire_token_by_authorization_code(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1218,6 +1226,8 @@ def _acquire_token_by_cloud_shell(self, scopes, data=None):
data=data or {},
authority_type=_AUTHORITY_TYPE_CLOUDSHELL,
))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
return response

def acquire_token_silent(
Expand Down Expand Up @@ -1395,6 +1405,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"access_token": entry["secret"],
"token_type": entry.get("token_type", "Bearer"),
"expires_in": int(expires_in), # OAuth2 specs defines it as int
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
Expand Down Expand Up @@ -1437,6 +1448,8 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
result = self._acquire_token_for_client(
scopes, refresh_reason, claims_challenge=claims_challenge,
**kwargs)
if result and "access_token" in result:
result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
if (result and "error" not in result) or (not access_token_from_cache):
return result
except http_exceptions:
Expand All @@ -1455,6 +1468,7 @@ def _process_broker_response(self, response, scopes, data):
data=data,
_account_id=response["_account_id"],
))
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
return _clean_up(response)

def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
Expand Down Expand Up @@ -1611,6 +1625,8 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
on_updating_rt=False,
on_removing_rt=lambda rt_item: None, # No OP
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1658,6 +1674,7 @@ def acquire_token_by_username_password(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)
headers = telemetry_context.generate_headers()
data = dict(kwargs.pop("data", {}), claims=claims)
response = None
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[msal.telemetry.CLIENT_REQUEST_ID])
Expand All @@ -1666,13 +1683,14 @@ def acquire_token_by_username_password(
user_realm_result, username, password, scopes=scopes,
data=data,
headers=headers, **kwargs))
telemetry_context.update_telemetry(response)
return response
response = _clean_up(self.client.obtain_token_by_username_password(
if response is None: # Either ADFS or not federated
response = _clean_up(self.client.obtain_token_by_username_password(
username, password, scope=scopes,
headers=headers,
data=data,
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1859,7 +1877,7 @@ def acquire_token_interactive(
logger.warning(
"Ignoring parameter extra_scopes_to_consent, "
"which is not supported by broker")
return self._acquire_token_interactive_via_broker(
response = self._acquire_token_interactive_via_broker(
scopes,
parent_window_handle,
enable_msa_passthrough,
Expand All @@ -1870,6 +1888,7 @@ def acquire_token_interactive(
login_hint=login_hint,
max_age=max_age,
)
return self._process_broker_response(response, scopes, data)

on_before_launching_ui(ui="browser")
telemetry_context = self._build_telemetry_context(
Expand All @@ -1892,6 +1911,8 @@ def acquire_token_interactive(
headers=telemetry_context.generate_headers(),
browser_name=_preferred_browser(),
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1928,7 +1949,7 @@ def _acquire_token_interactive_via_broker(
claims=claims,
**data)
if response and "error" not in response:
return self._process_broker_response(response, scopes, data)
return response
# login_hint undecisive or not exists
if prompt == "none" or not prompt: # Must/Can attempt _signin_silently()
logger.debug("Calling broker._signin_silently()")
Expand All @@ -1949,9 +1970,7 @@ def _acquire_token_interactive_via_broker(
if is_wrong_account:
logger.debug(wrong_account_error_message)
if prompt == "none":
return self._process_broker_response( # It is either token or error
response, scopes, data
) if not is_wrong_account else {
return response if not is_wrong_account else {
"error": "broker_error",
"error_description": wrong_account_error_message,
}
Expand All @@ -1966,11 +1985,11 @@ def _acquire_token_interactive_via_broker(
"_broker_status") in recoverable_errors:
pass # It will fall back to the _signin_interactively()
else:
return self._process_broker_response(response, scopes, data)
return response

logger.debug("Falls back to broker._signin_interactively()")
on_before_launching_ui(ui="broker")
response = _signin_interactively(
return _signin_interactively(
authority, self.client_id, scopes,
None if parent_window_handle is self.CONSOLE_WINDOW_HANDLE
else parent_window_handle,
Expand All @@ -1981,7 +2000,6 @@ def _acquire_token_interactive_via_broker(
max_age=max_age,
enable_msa_pt=enable_msa_passthrough,
**data)
return self._process_broker_response(response, scopes, data)

def initiate_device_flow(self, scopes=None, **kwargs):
"""Initiate a Device Flow instance,
Expand Down Expand Up @@ -2036,6 +2054,8 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
),
headers=telemetry_context.generate_headers(),
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -2145,5 +2165,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
headers=telemetry_context.generate_headers(),
# TBD: Expose a login_hint (or ccs_routing_hint) param for web app
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response
1 change: 1 addition & 0 deletions sample/confidential_client_certificate_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def acquire_and_use_token():
result = global_app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/confidential_client_secret_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def acquire_and_use_token():
result = global_app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/device_flow_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def acquire_and_use_token():
# and then keep calling acquire_token_by_device_flow(flow) in your own customized loop.

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/interactive_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def acquire_and_use_token():
)

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_response = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/username_password_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def acquire_and_use_token():
config["username"], config["password"], scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/vault_jwt_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def acquire_and_use_token():
result = global_app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
17 changes: 17 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def tester(url, **kwargs):
self.scopes, self.account, post=tester)
self.assertEqual("", result.get("classification"))


class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -263,6 +264,7 @@ def test_get_accounts_should_find_accounts_under_different_alias(self):
def test_acquire_token_silent_should_find_at_under_different_alias(self):
result = self.app.acquire_token_silent(self.scopes, self.account)
self.assertNotEqual(None, result)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(self.access_token, result.get('access_token'))

def test_acquire_token_silent_should_find_rt_under_different_alias(self):
Expand Down Expand Up @@ -360,6 +362,7 @@ def test_fresh_token_should_be_returned_from_cache(self):
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
self.fail("I/O shouldn't happen in cache hit AT scenario")
)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

Expand All @@ -374,6 +377,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"refresh_in": 123,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

Expand All @@ -385,6 +389,7 @@ def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(old_at, result.get("access_token"))

def test_expired_token_and_unavailable_aad_should_return_error(self):
Expand All @@ -409,6 +414,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"refresh_in": 123,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

Expand Down Expand Up @@ -444,6 +450,7 @@ def test_maintaining_offline_state_and_sending_them(self):
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
self.fail("I/O shouldn't happen in cache hit AT scenario")
)
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE)
self.assertEqual(cached_access_token, result.get("access_token"))

error1 = "error_1"
Expand Down Expand Up @@ -477,6 +484,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"The previous error should result in same success counter plus latest error info")
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post)
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))

def mock_post(url, headers=None, *args, **kwargs):
Expand All @@ -485,6 +493,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"The previous success should reset all offline telemetry counters")
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post)
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand All @@ -503,6 +512,7 @@ def mock_post(url, headers=None, *args, **kwargs):
result = self.app.acquire_token_by_auth_code_flow(
{"state": state, "code_verifier": "bar"}, {"state": state, "code": "012"},
post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))

def test_acquire_token_by_refresh_token(self):
Expand All @@ -511,6 +521,7 @@ def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|85,1|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_by_refresh_token("rt", ["s"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand All @@ -529,6 +540,7 @@ def mock_post(url, headers=None, *args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_by_device_flow(
{"device_code": "123"}, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))

def test_acquire_token_by_username_password(self):
Expand All @@ -538,6 +550,7 @@ def mock_post(url, headers=None, *args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_by_username_password(
"username", "password", ["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand All @@ -556,6 +569,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"expires_in": 0,
}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual("AT 1", result.get("access_token"), "Shall get a new token")

def mock_post(url, headers=None, *args, **kwargs):
Expand All @@ -566,13 +580,15 @@ def mock_post(url, headers=None, *args, **kwargs):
"refresh_in": -100, # A hack to make sure it will attempt refresh
}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual("AT 2", result.get("access_token"), "Shall get a new token")

def mock_post(url, headers=None, *args, **kwargs):
# 1/0 # TODO: Make sure this was called
self.assertEqual("4|730,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual("AT 2", result.get("access_token"), "Shall get aging token")

def test_acquire_token_on_behalf_of(self):
Expand All @@ -581,6 +597,7 @@ def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|523,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_on_behalf_of("assertion", ["s"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand Down