Skip to content

Commit ef3c9d4

Browse files
committed
Dedicate ManagedIdentity API
1 parent 0c57056 commit ef3c9d4

File tree

4 files changed

+89
-42
lines changed

4 files changed

+89
-42
lines changed

msal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@
3333
)
3434
from .oauth2cli.oidc import Prompt
3535
from .token_cache import TokenCache, SerializableTokenCache
36+
from .imds import ManagedIdentity
3637

msal/application.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2001,21 +2001,6 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
20012001
- an error response would contain "error" and usually "error_description".
20022002
"""
20032003
# TBD: force_refresh behavior
2004-
if self.client_credential is None:
2005-
from .imds import _scope_to_resource, _obtain_token
2006-
response = _obtain_token(
2007-
self.http_client,
2008-
" ".join(map(_scope_to_resource, scopes)),
2009-
client_id=self.client_id, # None for system-assigned, GUID for user-assigned
2010-
)
2011-
if "error" not in response:
2012-
self.token_cache.add(dict(
2013-
client_id=self.client_id,
2014-
scope=response["scope"].split() if "scope" in response else scopes,
2015-
token_endpoint=self.authority.token_endpoint,
2016-
response=response.copy(),
2017-
))
2018-
return response
20192004
if self.authority.tenant.lower() in ["common", "organizations"]:
20202005
warnings.warn(
20212006
"Using /common or /organizations authority "

msal/imds.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77
import os
8+
import socket
89
import time
910
try: # Python 2
1011
from urlparse import urlparse
@@ -57,6 +58,9 @@ def _obtain_token_on_azure_vm(http_client, resource, client_id=None):
5758
raise
5859

5960
def _obtain_token_on_app_service(http_client, endpoint, identity_header, resource, client_id=None):
61+
"""Obtains token for
62+
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_
63+
"""
6064
# Prerequisite: Create your app service https://docs.microsoft.com/en-us/azure/app-service/quickstart-python
6165
# Assign it a managed identity https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp
6266
# SSH into your container for testing https://docs.microsoft.com/en-us/azure/app-service/configure-linux-open-ssh-session
@@ -73,7 +77,7 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
7377
headers={
7478
"X-IDENTITY-HEADER": identity_header,
7579
"Metadata": "true", # Unnecessary yet harmless for App Service,
76-
# It will be needed by Azure Automation
80+
# It will be needed by Azure Automation
7781
# https://docs.microsoft.com/en-us/azure/automation/enable-managed-identity-for-automation#get-access-token-for-system-assigned-managed-identity-using-http-get
7882
},
7983
)
@@ -95,3 +99,53 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
9599
logger.debug("IMDS emits unexpected payload: %s", resp.text)
96100
raise
97101

102+
103+
class ManagedIdentity(object):
104+
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
105+
106+
def __init__(self, http_client, client_id=None, token_cache=None):
107+
self._http_client = http_client
108+
self._client_id = client_id
109+
self._token_cache = token_cache
110+
111+
def acquire_token(self, resource):
112+
access_token_from_cache = None
113+
if self._token_cache:
114+
matches = self._token_cache.find(
115+
self._token_cache.CredentialType.ACCESS_TOKEN,
116+
target=[resource],
117+
query=dict(
118+
client_id=self._client_id,
119+
environment=self._instance,
120+
realm=self._tenant,
121+
home_account_id=None,
122+
),
123+
)
124+
now = time.time()
125+
for entry in matches:
126+
expires_in = int(entry["expires_on"]) - now
127+
if expires_in < 5*60: # Then consider it expired
128+
continue # Removal is not necessary, it will be overwritten
129+
logger.debug("Cache hit an AT")
130+
access_token_from_cache = { # Mimic a real response
131+
"access_token": entry["secret"],
132+
"token_type": entry.get("token_type", "Bearer"),
133+
"expires_in": int(expires_in), # OAuth2 specs defines it as int
134+
}
135+
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
136+
break # With a fallback in hand, we break here to go refresh
137+
return access_token_from_cache # It is still good as new
138+
result = _obtain_token(self._http_client, resource, client_id=self._client_id)
139+
if self._token_cache and "access_token" in result:
140+
self._token_cache.add(dict(
141+
client_id=self._client_id,
142+
scope=[resource],
143+
token_endpoint="https://{}/{}".format(self._instance, self._tenant),
144+
response=result,
145+
params={},
146+
data={},
147+
#grant_type="placeholder",
148+
))
149+
return result
150+
return access_token_from_cache
151+

tests/msaltest.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import getpass, logging, pprint, sys, msal
1+
import functools, getpass, logging, pprint, sys, requests, msal
22

33

44
AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
@@ -141,15 +141,16 @@ def remove_account(app):
141141
app.remove_account(account)
142142
print('Account "{}" and/or its token(s) are signed out from MSAL Python'.format(account["username"]))
143143

144-
def acquire_token_for_client(app):
145-
"""acquire_token_for_client() - Only for confidential client"""
146-
pprint.pprint(app.acquire_token_for_client(_input_scopes()))
144+
def acquire_token_for_managed_identity(app):
145+
"""acquire_token() - Only for managed identity"""
146+
resource = "https://management.azure.com/" # TODO: Are there other resources?
147+
pprint.pprint(app.acquire_token(resource))
147148

148149
def exit(app):
149150
"""Exit"""
150151
bug_link = (
151152
"https://identitydivision.visualstudio.com/Engineering/_queries/query/79b3a352-a775-406f-87cd-a487c382a8ed/"
152-
if app._enable_broker else
153+
if getattr(app, "_enable_broker", None) else
153154
"https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/new/choose"
154155
)
155156
print("Bye. If you found a bug, please report it here: {}".format(bug_link))
@@ -161,12 +162,19 @@ def main():
161162
{"client_id": AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
162163
{"client_id": VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
163164
{"client_id": "95de633a-083e-42f5-b444-a4295d8e9314", "name": "Whiteboard Services (Non MSA-PT app. Accepts AAD & MSA accounts.)"},
164-
{"client_id": None, "client_secret": None, "name": "System-assigned Managed Identity (Only works when running inside a supported environment, such as Azure VM, Azure App Service, Azure Automation)"},
165+
{"managed_identity_client_id": None, "name": "System-assigned Managed Identity (Only works when running inside a supported environment, such as Azure VM, Azure App Service, Azure Automation)"},
165166
],
166167
option_renderer=lambda a: a["name"],
167168
header="Impersonate this app (or you can type in the client_id of your own app)",
168169
accept_nonempty_string=True)
169-
authority = _select_options([
170+
if isinstance(chosen_app, dict) and "managed_identity_client_id" in chosen_app:
171+
app = msal.ManagedIdentity(
172+
requests.Session(),
173+
client_id=chosen_app["managed_identity_client_id"],
174+
token_cache=msal.TokenCache(),
175+
)
176+
else:
177+
authority = _select_options([
170178
"https://login.microsoftonline.com/common",
171179
"https://login.microsoftonline.com/organizations",
172180
"https://login.microsoftonline.com/microsoft.onmicrosoft.com",
@@ -175,33 +183,32 @@ def main():
175183
],
176184
header="Input authority (Note that MSA-PT apps would NOT use the /common authority)",
177185
accept_nonempty_string=True,
178-
)
179-
if isinstance(chosen_app, dict) and "client_secret" in chosen_app:
180-
app = msal.ConfidentialClientApplication(
181-
chosen_app["client_id"],
182-
client_credential=chosen_app["client_secret"],
183-
authority=authority,
184-
)
185-
else:
186+
)
186187
app = msal.PublicClientApplication(
187188
chosen_app["client_id"] if isinstance(chosen_app, dict) else chosen_app,
188189
authority=authority,
189190
allow_broker=_input_boolean("Allow broker? (Azure CLI currently only supports @microsoft.com accounts when enabling broker)"),
190191
)
191192
if _input_boolean("Enable MSAL Python's DEBUG log?"):
192193
logging.basicConfig(level=logging.DEBUG)
194+
methods_to_be_tested = functools.reduce(lambda x, y: x + y, [
195+
methods for app_type, methods in {
196+
msal.PublicClientApplication: [
197+
acquire_token_interactive,
198+
acquire_ssh_cert_silently,
199+
acquire_ssh_cert_interactive,
200+
],
201+
msal.ClientApplication: [
202+
acquire_token_silent,
203+
acquire_token_by_username_password,
204+
remove_account,
205+
],
206+
msal.ManagedIdentity: [acquire_token_for_managed_identity],
207+
}.items() if isinstance(app, app_type)])
193208
while True:
194-
func = _select_options(list(filter(None, [
195-
acquire_token_silent,
196-
acquire_token_interactive,
197-
acquire_token_by_username_password,
198-
acquire_ssh_cert_silently,
199-
acquire_ssh_cert_interactive,
200-
remove_account,
201-
acquire_token_for_client if isinstance(
202-
app, msal.ConfidentialClientApplication) else None,
203-
exit,
204-
])), option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:")
209+
func = _select_options(
210+
methods_to_be_tested + [exit],
211+
option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:")
205212
try:
206213
func(app)
207214
except ValueError as e:

0 commit comments

Comments
 (0)