14
14
import requests
15
15
16
16
from .oauth2cli import Client , JwtAssertionCreator
17
+ from .oauth2cli .oidc import decode_part
17
18
from .authority import Authority
18
19
from .mex import send_request as mex_send_request
19
20
from .wstrust_request import send_request as wst_send_request
@@ -111,6 +112,34 @@ def _preferred_browser():
111
112
return None
112
113
113
114
115
+ class _ClientWithCcsRoutingInfo (Client ):
116
+
117
+ def initiate_auth_code_flow (self , ** kwargs ):
118
+ return super (_ClientWithCcsRoutingInfo , self ).initiate_auth_code_flow (
119
+ client_info = 1 , # To be used as CSS Routing info
120
+ ** kwargs )
121
+
122
+ def obtain_token_by_auth_code_flow (
123
+ self , auth_code_flow , auth_response , ** kwargs ):
124
+ # Note: the obtain_token_by_browser() is also covered by this
125
+ assert isinstance (auth_code_flow , dict ) and isinstance (auth_response , dict )
126
+ headers = kwargs .pop ("headers" , {})
127
+ client_info = json .loads (
128
+ decode_part (auth_response ["client_info" ])
129
+ ) if auth_response .get ("client_info" ) else {}
130
+ if "uid" in client_info and "utid" in client_info :
131
+ # Note: The value of X-AnchorMailbox is also case-insensitive
132
+ headers ["X-AnchorMailbox" ] = "Oid:{uid}@{utid}" .format (** client_info )
133
+ return super (_ClientWithCcsRoutingInfo , self ).obtain_token_by_auth_code_flow (
134
+ auth_code_flow , auth_response , headers = headers , ** kwargs )
135
+
136
+ def obtain_token_by_username_password (self , username , password , ** kwargs ):
137
+ headers = kwargs .pop ("headers" , {})
138
+ headers ["X-AnchorMailbox" ] = "upn:{}" .format (username )
139
+ return super (_ClientWithCcsRoutingInfo , self ).obtain_token_by_username_password (
140
+ username , password , headers = headers , ** kwargs )
141
+
142
+
114
143
class ClientApplication (object ):
115
144
116
145
ACQUIRE_TOKEN_SILENT_ID = "84"
@@ -481,7 +510,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
481
510
authority .device_authorization_endpoint or
482
511
urljoin (authority .token_endpoint , "devicecode" ),
483
512
}
484
- central_client = Client (
513
+ central_client = _ClientWithCcsRoutingInfo (
485
514
central_configuration ,
486
515
self .client_id ,
487
516
http_client = self .http_client ,
@@ -506,7 +535,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
506
535
regional_authority .device_authorization_endpoint or
507
536
urljoin (regional_authority .token_endpoint , "devicecode" ),
508
537
}
509
- regional_client = Client (
538
+ regional_client = _ClientWithCcsRoutingInfo (
510
539
regional_configuration ,
511
540
self .client_id ,
512
541
http_client = self .http_client ,
@@ -577,7 +606,7 @@ def initiate_auth_code_flow(
577
606
3. and then relay this dict and subsequent auth response to
578
607
:func:`~acquire_token_by_auth_code_flow()`.
579
608
"""
580
- client = Client (
609
+ client = _ClientWithCcsRoutingInfo (
581
610
{"authorization_endpoint" : self .authority .authorization_endpoint },
582
611
self .client_id ,
583
612
http_client = self .http_client )
@@ -654,7 +683,7 @@ def get_authorization_request_url(
654
683
self .http_client
655
684
) if authority else self .authority
656
685
657
- client = Client (
686
+ client = _ClientWithCcsRoutingInfo (
658
687
{"authorization_endpoint" : the_authority .authorization_endpoint },
659
688
self .client_id ,
660
689
http_client = self .http_client )
@@ -1178,6 +1207,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
1178
1207
key = lambda e : int (e .get ("last_modification_time" , "0" )),
1179
1208
reverse = True ):
1180
1209
logger .debug ("Cache attempts an RT" )
1210
+ headers = telemetry_context .generate_headers ()
1211
+ if "home_account_id" in query : # Then use it as CCS Routing info
1212
+ headers ["X-AnchorMailbox" ] = "Oid:{}" .format ( # case-insensitive value
1213
+ query ["home_account_id" ].replace ("." , "@" ))
1181
1214
response = client .obtain_token_by_refresh_token (
1182
1215
entry , rt_getter = lambda token_item : token_item ["secret" ],
1183
1216
on_removing_rt = lambda rt_item : None , # Disable RT removal,
@@ -1189,7 +1222,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
1189
1222
skip_account_creation = True , # To honor a concurrent remove_account()
1190
1223
)),
1191
1224
scope = scopes ,
1192
- headers = telemetry_context . generate_headers () ,
1225
+ headers = headers ,
1193
1226
data = dict (
1194
1227
kwargs .pop ("data" , {}),
1195
1228
claims = _merge_claims_challenge_and_capabilities (
0 commit comments