21
21
import msal .telemetry
22
22
from .region import _detect_region
23
23
from .throttled_http_client import ThrottledHttpClient
24
+ from .cloudshell import _is_running_in_cloud_shell
24
25
25
26
26
27
# The __init__.py will import this. Not the other way around.
27
- __version__ = "1.17.0 " # When releasing, also check and bump our dependencies's versions if needed
28
+ __version__ = "1.18.0b1 " # When releasing, also check and bump our dependencies's versions if needed
28
29
29
30
logger = logging .getLogger (__name__ )
30
-
31
+ _AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL"
31
32
32
33
def extract_certs (public_cert_content ):
33
34
# Parses raw public certificate file contents and returns a list of strings
@@ -636,6 +637,7 @@ def initiate_auth_code_flow(
636
637
domain_hint = None , # type: Optional[str]
637
638
claims_challenge = None ,
638
639
max_age = None ,
640
+ response_mode = None , # type: Optional[str]
639
641
):
640
642
"""Initiate an auth code flow.
641
643
@@ -677,6 +679,20 @@ def initiate_auth_code_flow(
677
679
678
680
New in version 1.15.
679
681
682
+ :param str response_mode:
683
+ OPTIONAL. Specifies the method with which response parameters should be returned.
684
+ The default value is equivalent to ``query``, which is still secure enough in MSAL Python
685
+ (because MSAL Python does not transfer tokens via query parameter in the first place).
686
+ For even better security, we recommend using the value ``form_post``.
687
+ In "form_post" mode, response parameters
688
+ will be encoded as HTML form values that are transmitted via the HTTP POST method and
689
+ encoded in the body using the application/x-www-form-urlencoded format.
690
+ Valid values can be either "form_post" for HTTP POST to callback URI or
691
+ "query" (the default) for HTTP GET with parameters encoded in query string.
692
+ More information on possible values
693
+ `here <https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#ResponseModes>`
694
+ and `here <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html#FormPostResponseMode>`
695
+
680
696
:return:
681
697
The auth code flow. It is a dict in this form::
682
698
@@ -707,6 +723,7 @@ def initiate_auth_code_flow(
707
723
claims = _merge_claims_challenge_and_capabilities (
708
724
self ._client_capabilities , claims_challenge ),
709
725
max_age = max_age ,
726
+ response_mode = response_mode ,
710
727
)
711
728
flow ["claims_challenge" ] = claims_challenge
712
729
return flow
@@ -970,6 +987,10 @@ def get_accounts(self, username=None):
970
987
return accounts
971
988
972
989
def _find_msal_accounts (self , environment ):
990
+ interested_authority_types = [
991
+ TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS ]
992
+ if _is_running_in_cloud_shell ():
993
+ interested_authority_types .append (_AUTHORITY_TYPE_CLOUDSHELL )
973
994
grouped_accounts = {
974
995
a .get ("home_account_id" ): # Grouped by home tenant's id
975
996
{ # These are minimal amount of non-tenant-specific account info
@@ -985,8 +1006,7 @@ def _find_msal_accounts(self, environment):
985
1006
for a in self .token_cache .find (
986
1007
TokenCache .CredentialType .ACCOUNT ,
987
1008
query = {"environment" : environment })
988
- if a ["authority_type" ] in (
989
- TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS )
1009
+ if a ["authority_type" ] in interested_authority_types
990
1010
}
991
1011
return list (grouped_accounts .values ())
992
1012
@@ -1046,6 +1066,21 @@ def _forget_me(self, home_account):
1046
1066
TokenCache .CredentialType .ACCOUNT , query = owned_by_home_account ):
1047
1067
self .token_cache .remove_account (a )
1048
1068
1069
+ def _acquire_token_by_cloud_shell (self , scopes , data = None ):
1070
+ from .cloudshell import _obtain_token
1071
+ response = _obtain_token (
1072
+ self .http_client , scopes , client_id = self .client_id , data = data )
1073
+ if "error" not in response :
1074
+ self .token_cache .add (dict (
1075
+ client_id = self .client_id ,
1076
+ scope = response ["scope" ].split () if "scope" in response else scopes ,
1077
+ token_endpoint = self .authority .token_endpoint ,
1078
+ response = response .copy (),
1079
+ data = data or {},
1080
+ authority_type = _AUTHORITY_TYPE_CLOUDSHELL ,
1081
+ ))
1082
+ return response
1083
+
1049
1084
def acquire_token_silent (
1050
1085
self ,
1051
1086
scopes , # type: List[str]
@@ -1179,6 +1214,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
1179
1214
authority , # This can be different than self.authority
1180
1215
force_refresh = False , # type: Optional[boolean]
1181
1216
claims_challenge = None ,
1217
+ correlation_id = None ,
1182
1218
** kwargs ):
1183
1219
access_token_from_cache = None
1184
1220
if not (force_refresh or claims_challenge ): # Bypass AT when desired or using claims
@@ -1217,9 +1253,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
1217
1253
refresh_reason = msal .telemetry .FORCE_REFRESH # TODO: It could also mean claims_challenge
1218
1254
assert refresh_reason , "It should have been established at this point"
1219
1255
try :
1256
+ if account and account .get ("authority_type" ) == _AUTHORITY_TYPE_CLOUDSHELL :
1257
+ return self ._acquire_token_by_cloud_shell (
1258
+ scopes , data = kwargs .get ("data" ))
1220
1259
result = _clean_up (self ._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
1221
1260
authority , self ._decorate_scope (scopes ), account ,
1222
1261
refresh_reason = refresh_reason , claims_challenge = claims_challenge ,
1262
+ correlation_id = correlation_id ,
1223
1263
** kwargs ))
1224
1264
if (result and "error" not in result ) or (not access_token_from_cache ):
1225
1265
return result
@@ -1558,6 +1598,9 @@ def acquire_token_interactive(
1558
1598
- A dict containing an "error" key, when token refresh failed.
1559
1599
"""
1560
1600
self ._validate_ssh_cert_input_data (kwargs .get ("data" , {}))
1601
+ if _is_running_in_cloud_shell () and prompt == "none" :
1602
+ return self ._acquire_token_by_cloud_shell (
1603
+ scopes , data = kwargs .pop ("data" , {}))
1561
1604
claims = _merge_claims_challenge_and_capabilities (
1562
1605
self ._client_capabilities , claims_challenge )
1563
1606
telemetry_context = self ._build_telemetry_context (
@@ -1659,6 +1702,11 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
1659
1702
- an error response would contain "error" and usually "error_description".
1660
1703
"""
1661
1704
# TBD: force_refresh behavior
1705
+ if self .authority .tenant .lower () in ["common" , "organizations" ]:
1706
+ warnings .warn (
1707
+ "Using /common or /organizations authority "
1708
+ "in acquire_token_for_client() is unreliable. "
1709
+ "Please use a specific tenant instead." , DeprecationWarning )
1662
1710
self ._validate_ssh_cert_input_data (kwargs .get ("data" , {}))
1663
1711
telemetry_context = self ._build_telemetry_context (
1664
1712
self .ACQUIRE_TOKEN_FOR_CLIENT_ID )
0 commit comments