Skip to content

Commit 093efca

Browse files
authored
Corehttp auth flows (#40084)
* Added `auth_flows` support in `BearerTokenCredentialPolicy` * update * updates * updates * disable mypy "typeddict-unknown-key" * update * update * Update changelog * update * update * updates * updates * update type of auth flows
1 parent a1987ab commit 093efca

File tree

5 files changed

+79
-15
lines changed

5 files changed

+79
-15
lines changed

sdk/core/corehttp/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
1111
- These policies now also check the `refresh_on` attribute when determining if a new token request should be made.
1212
- Added `model` attribute to `HttpResponseError` to allow accessing error attributes based on a known model. [#39636](https://github.com/Azure/azure-sdk-for-python/pull/39636)
13+
- Added `auth_flows` support in `BearerTokenCredentialPolicy`. [#40084](https://github.com/Azure/azure-sdk-for-python/pull/40084)
1314

1415
### Breaking Changes
1516

sdk/core/corehttp/corehttp/runtime/policies/_authentication.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# -------------------------------------------------------------------------
66
from __future__ import annotations
77
import time
8-
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any
8+
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union
99

1010
from ...credentials import TokenRequestOptions
1111
from ...rest import HttpResponse, HttpRequest
@@ -32,15 +32,23 @@ class _BearerTokenCredentialPolicyBase:
3232
:param credential: The credential.
3333
:type credential: ~corehttp.credentials.TokenCredential
3434
:param str scopes: Lets you specify the type of access needed.
35+
:keyword auth_flows: A list of authentication flows to use for the credential.
36+
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
3537
"""
3638

39+
# pylint: disable=unused-argument
3740
def __init__(
38-
self, credential: "TokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
41+
self,
42+
credential: "TokenCredential",
43+
*scopes: str,
44+
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
45+
**kwargs: Any,
3946
) -> None:
4047
super(_BearerTokenCredentialPolicyBase, self).__init__()
4148
self._scopes = scopes
4249
self._credential = credential
4350
self._token: Optional["AccessTokenInfo"] = None
51+
self._auth_flows = auth_flows
4452

4553
@staticmethod
4654
def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
@@ -83,20 +91,30 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
8391
:param credential: The credential.
8492
:type credential: ~corehttp.TokenCredential
8593
:param str scopes: Lets you specify the type of access needed.
94+
:keyword auth_flows: A list of authentication flows to use for the credential.
95+
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
8696
:raises: :class:`~corehttp.exceptions.ServiceRequestError`
8797
"""
8898

89-
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
99+
def on_request(
100+
self,
101+
request: PipelineRequest[HTTPRequestType],
102+
*,
103+
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
104+
) -> None:
90105
"""Called before the policy sends a request.
91106
92107
The base implementation authorizes the request with a bearer token.
93108
94109
:param ~corehttp.runtime.pipeline.PipelineRequest request: the request
110+
:keyword auth_flows: A list of authentication flows to use for the credential.
111+
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
95112
"""
96113
self._enforce_https(request)
97114

98115
if self._token is None or self._need_new_token:
99-
self._token = self._credential.get_token_info(*self._scopes)
116+
options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
117+
self._token = self._credential.get_token_info(*self._scopes, options=options)
100118
self._update_headers(request.http_request.headers, self._token.token)
101119

102120
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
@@ -124,7 +142,7 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT
124142
:return: The pipeline response object
125143
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
126144
"""
127-
self.on_request(request)
145+
self.on_request(request, auth_flows=self._auth_flows)
128146
try:
129147
response = self.next.send(request)
130148
except Exception:

sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# -------------------------------------------------------------------------
66
from __future__ import annotations
77
import time
8-
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar
8+
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar, Union
99

1010
from ...credentials import AccessTokenInfo, TokenRequestOptions
1111
from ..pipeline import PipelineRequest, PipelineResponse
@@ -29,28 +29,43 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
2929
:param credential: The credential.
3030
:type credential: ~corehttp.credentials.TokenCredential
3131
:param str scopes: Lets you specify the type of access needed.
32+
:keyword auth_flows: A list of authentication flows to use for the credential.
33+
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
3234
"""
3335

36+
# pylint: disable=unused-argument
3437
def __init__(
35-
self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
38+
self,
39+
credential: "AsyncTokenCredential",
40+
*scopes: str,
41+
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
42+
**kwargs: Any,
3643
) -> None:
3744
super().__init__()
3845
self._credential = credential
3946
self._lock_instance = None
4047
self._scopes = scopes
4148
self._token: Optional[AccessTokenInfo] = None
49+
self._auth_flows = auth_flows
4250

4351
@property
4452
def _lock(self):
4553
if self._lock_instance is None:
4654
self._lock_instance = get_running_async_lock()
4755
return self._lock_instance
4856

49-
async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
57+
async def on_request(
58+
self,
59+
request: PipelineRequest[HTTPRequestType],
60+
*,
61+
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
62+
) -> None:
5063
"""Adds a bearer token Authorization header to request and sends request to next policy.
5164
5265
:param request: The pipeline request object to be modified.
5366
:type request: ~corehttp.runtime.pipeline.PipelineRequest
67+
:keyword auth_flows: A list of authentication flows to use for the credential.
68+
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
5469
:raises: :class:`~corehttp.exceptions.ServiceRequestError`
5570
"""
5671
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
@@ -59,7 +74,8 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
5974
async with self._lock:
6075
# double check because another coroutine may have acquired a token while we waited to acquire the lock
6176
if self._token is None or self._need_new_token:
62-
self._token = await await_result(self._credential.get_token_info, *self._scopes)
77+
options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
78+
self._token = await await_result(self._credential.get_token_info, *self._scopes, options=options)
6379
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token
6480

6581
async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
@@ -91,7 +107,7 @@ async def send(
91107
:return: The pipeline response object
92108
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
93109
"""
94-
await await_result(self.on_request, request)
110+
await await_result(self.on_request, request, auth_flows=self._auth_flows)
95111
try:
96112
response = await self.next.send(request)
97113
except Exception:

sdk/core/corehttp/tests/async_tests/test_authentication_async.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ async def send(self, request):
212212
pipeline = AsyncPipeline(transport=transport, policies=[policy])
213213
await pipeline.run(HttpRequest("GET", "https://localhost"))
214214

215-
policy.on_request.assert_called_once_with(policy.request)
215+
policy.on_request.assert_called_once_with(policy.request, auth_flows=None)
216216
policy.on_response.assert_called_once_with(policy.request, policy.response)
217217

218218
# the policy should call on_exception when next.send() raises
@@ -275,7 +275,7 @@ def get_completed_future(result=None):
275275
@pytest.mark.asyncio
276276
async def test_async_token_credential_inheritance():
277277
class TestTokenCredential(AsyncTokenCredential):
278-
async def get_token_info(self, *scopes, options=None):
278+
async def get_token_info(self, *scopes, options={}):
279279
return "TOKEN"
280280

281281
cred = TestTokenCredential()
@@ -319,3 +319,18 @@ async def test_need_new_token():
319319
# Token is not close to expiring, but refresh_on is in the past.
320320
policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1)
321321
assert policy._need_new_token
322+
323+
324+
@pytest.mark.asyncio
325+
async def test_send_with_auth_flows():
326+
auth_flows = [{"type": "flow1"}, {"type": "flow2"}]
327+
credential = Mock(
328+
spec_set=["get_token_info"],
329+
get_token_info=Mock(return_value=get_completed_future(AccessTokenInfo("***", int(time.time()) + 3600))),
330+
)
331+
policy = AsyncBearerTokenCredentialPolicy(credential, "scope", auth_flows=auth_flows)
332+
transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200))))
333+
334+
pipeline = AsyncPipeline(transport=transport, policies=[policy])
335+
await pipeline.run(HttpRequest("GET", "https://localhost"))
336+
policy._credential.get_token_info.assert_called_with("scope", options={"auth_flows": auth_flows})

sdk/core/corehttp/tests/test_authentication.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_bearer_policy_default_context():
139139

140140
pipeline.run(HttpRequest("GET", "https://localhost"))
141141

142-
credential.get_token_info.assert_called_once_with(expected_scope)
142+
credential.get_token_info.assert_called_once_with(expected_scope, options={})
143143

144144

145145
def test_bearer_policy_context_unmodified_by_default():
@@ -194,7 +194,7 @@ def test_bearer_policy_cannot_complete_challenge():
194194

195195
assert response.http_response is expected_response
196196
assert transport.send.call_count == 1
197-
credential.get_token_info.assert_called_once_with(expected_scope)
197+
credential.get_token_info.assert_called_once_with(expected_scope, options={})
198198

199199

200200
def test_bearer_policy_calls_sansio_methods():
@@ -221,7 +221,7 @@ def send(self, request):
221221
pipeline = Pipeline(transport=transport, policies=[policy])
222222
pipeline.run(HttpRequest("GET", "https://localhost"))
223223

224-
policy.on_request.assert_called_once_with(policy.request)
224+
policy.on_request.assert_called_once_with(policy.request, auth_flows=None)
225225
policy.on_response.assert_called_once_with(policy.request, policy.response)
226226

227227
# the policy should call on_exception when next.send() raises
@@ -415,3 +415,17 @@ def test_need_new_token():
415415
# Token is not close to expiring, but refresh_on is in the past.
416416
policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1)
417417
assert policy._need_new_token
418+
419+
420+
def test_send_with_auth_flows():
421+
auth_flows = [{"type": "flow1"}, {"type": "flow2"}]
422+
credential = Mock(
423+
spec_set=["get_token_info"],
424+
get_token_info=Mock(return_value=AccessTokenInfo("***", int(time.time()) + 3600)),
425+
)
426+
policy = BearerTokenCredentialPolicy(credential, "scope", auth_flows=auth_flows)
427+
transport = Mock(send=Mock(return_value=Mock(status_code=200)))
428+
429+
pipeline = Pipeline(transport=transport, policies=[policy])
430+
pipeline.run(HttpRequest("GET", "https://localhost"))
431+
policy._credential.get_token_info.assert_called_with("scope", options={"auth_flows": auth_flows})

0 commit comments

Comments
 (0)