Skip to content

Commit 78d83e0

Browse files
authored
[Identity] Propagate additional token type values (#37579)
Signed-off-by: Paul Van Eck <[email protected]>
1 parent 45f482f commit 78d83e0

File tree

8 files changed

+48
-10
lines changed

8 files changed

+48
-10
lines changed

sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[Acces
136136
if result and "access_token" in result and "expires_in" in result:
137137
refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
138138
return AccessTokenInfo(
139-
result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on
139+
result["access_token"],
140+
now + int(result["expires_in"]),
141+
token_type=result.get("token_type", "Bearer"),
142+
refresh_on=refresh_on,
140143
)
141144

142145
return None
@@ -157,4 +160,9 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
157160
pass # non-fatal; we'll use the assertion again next time instead of a refresh token
158161

159162
refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
160-
return AccessTokenInfo(result["access_token"], request_time + int(result["expires_in"]), refresh_on=refresh_on)
163+
return AccessTokenInfo(
164+
result["access_token"],
165+
request_time + int(result["expires_in"]),
166+
token_type=result.get("token_type", "Bearer"),
167+
refresh_on=refresh_on,
168+
)

sdk/identity/azure-identity/azure/identity/_credentials/silent.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,15 @@ def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
186186
result = client_application.acquire_token_silent_with_error(
187187
list(scopes), account=account, claims_challenge=kwargs.get("claims")
188188
)
189+
189190
if result and "access_token" in result and "expires_in" in result:
190-
return AccessTokenInfo(result["access_token"], now + int(result["expires_in"]))
191+
refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
192+
return AccessTokenInfo(
193+
result["access_token"],
194+
now + int(result["expires_in"]),
195+
token_type=result.get("token_type", "Bearer"),
196+
refresh_on=refresh_on,
197+
)
191198

192199
# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
193200
if result:

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optio
9595
expires_on = int(token["expires_on"])
9696
if expires_on > int(time.time()):
9797
refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None
98-
return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on)
98+
return AccessTokenInfo(
99+
token["secret"], expires_on, token_type=token.get("token_type", "Bearer"), refresh_on=refresh_on
100+
)
99101
return None
100102

101103
def get_cached_refresh_tokens(self, scopes: Iterable[str], **kwargs) -> List[Dict]:
@@ -178,7 +180,9 @@ def _process_response(self, response: PipelineResponse, request_time: int, **kwa
178180
content["refresh_in"] = expires_in // 2
179181

180182
refresh_on = request_time + int(content["refresh_in"]) if "refresh_in" in content else None
181-
token = AccessTokenInfo(content["access_token"], expires_on, refresh_on=refresh_on)
183+
token = AccessTokenInfo(
184+
content["access_token"], expires_on, token_type=content.get("token_type", "Bearer"), refresh_on=refresh_on
185+
)
182186

183187
# caching is the final step because 'add' mutates 'content'
184188
cache.add(

sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[Acces
3434
return AccessTokenInfo(
3535
result["access_token"],
3636
request_time + int(result["expires_in"]),
37+
token_type=result.get("token_type", "Bearer"),
3738
refresh_on=refresh_on,
3839
)
3940
return None
@@ -51,5 +52,6 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
5152
return AccessTokenInfo(
5253
result["access_token"],
5354
request_time + int(result["expires_in"]),
55+
token_type=result.get("token_type", "Bearer"),
5456
refresh_on=refresh_on,
5557
)

sdk/identity/azure-identity/azure/identity/_internal/interactive.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,10 @@ def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
286286
if result and "access_token" in result and "expires_in" in result:
287287
refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
288288
return AccessTokenInfo(
289-
result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on
289+
result["access_token"],
290+
now + int(result["expires_in"]),
291+
token_type=result.get("token_type", "Bearer"),
292+
refresh_on=refresh_on,
290293
)
291294

292295
# if we get this far, result is either None or the content of a Microsoft Entra ID error response

sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,12 @@ def _process_response(self, response: PipelineResponse, request_time: int) -> Ac
7676
content["refresh_in"] = expires_in // 2
7777

7878
refresh_on = request_time + int(content["refresh_in"]) if "refresh_in" in content else None
79-
token = AccessTokenInfo(content["access_token"], content["expires_on"], refresh_on=refresh_on)
79+
token = AccessTokenInfo(
80+
content["access_token"],
81+
content["expires_on"],
82+
token_type=content.get("token_type", "Bearer"),
83+
refresh_on=refresh_on,
84+
)
8085

8186
# caching is the final step because TokenCache.add mutates its "event"
8287
self._cache.add(
@@ -93,7 +98,9 @@ def get_cached_token(self, *scopes: str) -> Optional[AccessTokenInfo]:
9398
expires_on = int(token["expires_on"])
9499
refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None
95100
if expires_on > now and (not refresh_on or refresh_on > now):
96-
return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on)
101+
return AccessTokenInfo(
102+
token["secret"], expires_on, token_type=token.get("token_type", "Bearer"), refresh_on=refresh_on
103+
)
97104

98105
return None
99106

sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # pyl
5353
now = int(time.time())
5454
if result and "access_token" in result and "expires_in" in result:
5555
refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
56-
return AccessTokenInfo(result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on)
56+
return AccessTokenInfo(
57+
result["access_token"],
58+
now + int(result["expires_in"]),
59+
token_type=result.get("token_type", "Bearer"),
60+
refresh_on=refresh_on,
61+
)
5762
if result and "error" in result:
5863
error_desc = cast(str, result["error"])
5964
error_message = self.get_unavailable_message(error_desc)

sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def _get_cached_access_token(
243243
expires_on = int(token["expires_on"])
244244
refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None
245245
if expires_on - 300 > int(time.time()):
246-
return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on)
246+
return AccessTokenInfo(
247+
token["secret"], expires_on, token_type=token.get("token_type", "Bearer"), refresh_on=refresh_on
248+
)
247249
except Exception as ex: # pylint:disable=broad-except
248250
message = "Error accessing cached data: {}".format(ex)
249251
raise CredentialUnavailableError(message=message) from ex

0 commit comments

Comments
 (0)