Skip to content

Commit a17d42f

Browse files
authored
[Corehttp] Update credential classes (#38346)
Add new `AccessTokenInfo` class and updated `TokenCredential` protocols to use `get_token_info`. Signed-off-by: Paul Van Eck <[email protected]>
1 parent c3fcc48 commit a17d42f

File tree

8 files changed

+224
-98
lines changed

8 files changed

+224
-98
lines changed

sdk/core/corehttp/CHANGELOG.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
# Release History
22

3+
## 1.0.0b6 (Unreleased)
4+
5+
### Features Added
6+
7+
- The `TokenCredential` and `AsyncTokenCredential` protocols have been updated to include a new `get_token_info` method. This method should be used to acquire tokens and return an `AccessTokenInfo` object. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
8+
- Added a new `TokenRequestOptions` class, which is a `TypedDict` with optional parameters, that can be used to define options for token requests through the `get_token_info` method. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
9+
- Added a new `AccessTokenInfo` class, which is returned by `get_token_info` implementations. This class contains the token, its expiration time, and optional additional information like when a token should be refreshed. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
10+
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
11+
- These policies now also check the `refresh_on` attribute when determining if a new token request should be made.
12+
13+
### Breaking Changes
14+
15+
- The `get_token` method has been removed from the `TokenCredential` and `AsyncTokenCredential` protocols. Implementations should now use the new `get_token_info` method to acquire tokens. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
16+
- The `AccessToken` class has been removed and replaced with a new `AccessTokenInfo` class. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
17+
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now rely on credentials having the `get_token_info` method defined. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
18+
19+
### Bugs Fixed
20+
21+
### Other Changes
22+
323
## 1.0.0b5 (2024-02-29)
424

525
### Other Changes

sdk/core/corehttp/corehttp/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
# regenerated.
1010
# --------------------------------------------------------------------------
1111

12-
VERSION = "1.0.0b5"
12+
VERSION = "1.0.0b6"

sdk/core/corehttp/corehttp/credentials.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,71 @@
55
# -------------------------------------------------------------------------
66
from __future__ import annotations
77
from types import TracebackType
8-
from typing import Any, NamedTuple, Optional, AsyncContextManager, Type
8+
from typing import NamedTuple, Optional, AsyncContextManager, Type, TypedDict, ContextManager
99
from typing_extensions import Protocol, runtime_checkable
1010

1111

12-
class AccessToken(NamedTuple):
13-
"""Represents an OAuth access token."""
12+
class AccessTokenInfo:
13+
"""Information about an OAuth access token.
14+
15+
:param str token: The token string.
16+
:param int expires_on: The token's expiration time in Unix time.
17+
:keyword str token_type: The type of access token. Defaults to 'Bearer'.
18+
:keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively
19+
refreshed. Optional.
20+
"""
1421

1522
token: str
23+
"""The token string."""
1624
expires_on: int
25+
"""The token's expiration time in Unix time."""
26+
token_type: str
27+
"""The type of access token."""
28+
refresh_on: Optional[int]
29+
"""Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional."""
30+
31+
def __init__(
32+
self, token: str, expires_on: int, *, token_type: str = "Bearer", refresh_on: Optional[int] = None
33+
) -> None:
34+
self.token = token
35+
self.expires_on = expires_on
36+
self.token_type = token_type
37+
self.refresh_on = refresh_on
1738

39+
def __repr__(self) -> str:
40+
return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format(
41+
self.token, self.expires_on, self.token_type, self.refresh_on
42+
)
1843

19-
AccessToken.token.__doc__ = """The token string."""
20-
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""
2144

45+
class TokenRequestOptions(TypedDict, total=False):
46+
"""Options to use for access token requests. All parameters are optional."""
2247

23-
@runtime_checkable
24-
class TokenCredential(Protocol):
25-
"""Protocol for classes able to provide OAuth tokens."""
48+
claims: str
49+
"""Additional claims required in the token, such as those returned in a resource provider's claims
50+
challenge following an authorization failure."""
51+
tenant_id: str
52+
"""The tenant ID to include in the token request."""
2653

27-
def get_token(self, *scopes: str, claims: Optional[str] = None, **kwargs: Any) -> AccessToken:
28-
"""Request an access token for `scopes`.
2954

30-
:param str scopes: The type of access needed.
55+
class TokenCredential(Protocol, ContextManager["TokenCredential"]):
56+
"""Protocol for classes able to provide OAuth access tokens."""
3157

32-
:keyword str claims: Additional claims required in the token, such as those returned in a resource
33-
provider's claims challenge following an authorization failure.
58+
def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
59+
"""Request an access token for `scopes`.
3460
61+
:param str scopes: The type of access needed.
62+
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
63+
:paramtype options: TokenRequestOptions
3564
36-
:rtype: AccessToken
37-
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
65+
:rtype: AccessTokenInfo
66+
:return: An AccessTokenInfo instance containing information about the token.
3867
"""
3968
...
4069

70+
def close(self) -> None:
71+
pass
72+
4173

4274
class ServiceNamedKey(NamedTuple):
4375
"""Represents a name and key pair."""
@@ -47,10 +79,11 @@ class ServiceNamedKey(NamedTuple):
4779

4880

4981
__all__ = [
50-
"AccessToken",
82+
"AccessTokenInfo",
5183
"ServiceKeyCredential",
5284
"ServiceNamedKeyCredential",
5385
"TokenCredential",
86+
"TokenRequestOptions",
5487
"AsyncTokenCredential",
5588
]
5689

@@ -134,16 +167,15 @@ def update(self, name: str, key: str) -> None:
134167
class AsyncTokenCredential(Protocol, AsyncContextManager["AsyncTokenCredential"]):
135168
"""Protocol for classes able to provide OAuth tokens."""
136169

137-
async def get_token(self, *scopes: str, claims: Optional[str] = None, **kwargs: Any) -> AccessToken:
170+
async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
138171
"""Request an access token for `scopes`.
139172
140173
:param str scopes: The type of access needed.
174+
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
175+
:paramtype options: TokenRequestOptions
141176
142-
:keyword str claims: Additional claims required in the token, such as those returned in a resource
143-
provider's claims challenge following an authorization failure.
144-
145-
:rtype: AccessToken
146-
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
177+
:rtype: AccessTokenInfo
178+
:return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time.
147179
"""
148180
...
149181

sdk/core/corehttp/corehttp/runtime/policies/_authentication.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import time
88
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any
99

10+
from ...credentials import TokenRequestOptions
1011
from ...rest import HttpResponse, HttpRequest
1112
from . import HTTPPolicy, SansIOHTTPPolicy
1213
from ...exceptions import ServiceRequestError
1314

1415
if TYPE_CHECKING:
1516

1617
from ...credentials import (
17-
AccessToken,
18+
AccessTokenInfo,
1819
TokenCredential,
1920
ServiceKeyCredential,
2021
)
@@ -39,7 +40,7 @@ def __init__(
3940
super(_BearerTokenCredentialPolicyBase, self).__init__()
4041
self._scopes = scopes
4142
self._credential = credential
42-
self._token: Optional["AccessToken"] = None
43+
self._token: Optional["AccessTokenInfo"] = None
4344

4445
@staticmethod
4546
def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
@@ -68,7 +69,12 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
6869

6970
@property
7071
def _need_new_token(self) -> bool:
71-
return not self._token or self._token.expires_on - time.time() < 300
72+
now = time.time()
73+
return (
74+
not self._token
75+
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
76+
or (self._token.expires_on - now < 300)
77+
)
7278

7379

7480
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
@@ -90,7 +96,7 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
9096
self._enforce_https(request)
9197

9298
if self._token is None or self._need_new_token:
93-
self._token = self._credential.get_token(*self._scopes)
99+
self._token = self._credential.get_token_info(*self._scopes)
94100
self._update_headers(request.http_request.headers, self._token.token)
95101

96102
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
@@ -102,7 +108,12 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
102108
:param ~corehttp.runtime.pipeline.PipelineRequest request: the request
103109
:param str scopes: required scopes of authentication
104110
"""
105-
self._token = self._credential.get_token(*scopes, **kwargs)
111+
options: TokenRequestOptions = {}
112+
# Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
113+
for key in list(kwargs.keys()):
114+
if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
115+
options[key] = kwargs.pop(key) # type: ignore[literal-required]
116+
self._token = self._credential.get_token_info(*scopes, options=options)
106117
self._update_headers(request.http_request.headers, self._token.token)
107118

108119
def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:

sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar
99

10-
from ...credentials import AccessToken
10+
from ...credentials import AccessTokenInfo, TokenRequestOptions
1111
from ..pipeline import PipelineRequest, PipelineResponse
1212
from ..pipeline._tools_async import await_result
1313
from ._base_async import AsyncHTTPPolicy
@@ -38,7 +38,7 @@ def __init__(
3838
self._credential = credential
3939
self._lock_instance = None
4040
self._scopes = scopes
41-
self._token: Optional["AccessToken"] = None
41+
self._token: Optional[AccessTokenInfo] = None
4242

4343
@property
4444
def _lock(self):
@@ -55,12 +55,12 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
5555
"""
5656
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
5757

58-
if self._token is None or self._need_new_token():
58+
if self._token is None or self._need_new_token:
5959
async with self._lock:
6060
# double check because another coroutine may have acquired a token while we waited to acquire the lock
61-
if self._token is None or self._need_new_token():
62-
self._token = await await_result(self._credential.get_token, *self._scopes)
63-
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
61+
if self._token is None or self._need_new_token:
62+
self._token = await await_result(self._credential.get_token_info, *self._scopes)
63+
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token
6464

6565
async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
6666
"""Acquire a token from the credential and authorize the request with it.
@@ -71,9 +71,15 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
7171
:param ~corehttp.runtime.pipeline.PipelineRequest request: the request
7272
:param str scopes: required scopes of authentication
7373
"""
74+
options: TokenRequestOptions = {}
75+
# Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
76+
for key in list(kwargs.keys()):
77+
if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
78+
options[key] = kwargs.pop(key) # type: ignore[literal-required]
79+
7480
async with self._lock:
75-
self._token = await await_result(self._credential.get_token, *scopes, **kwargs)
76-
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
81+
self._token = await await_result(self._credential.get_token_info, *scopes, options=options)
82+
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token
7783

7884
async def send(
7985
self, request: PipelineRequest[HTTPRequestType]
@@ -149,5 +155,11 @@ def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
149155
# pylint: disable=unused-argument
150156
return
151157

158+
@property
152159
def _need_new_token(self) -> bool:
153-
return not self._token or self._token.expires_on - time.time() < 300
160+
now = time.time()
161+
return (
162+
not self._token
163+
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
164+
or (self._token.expires_on - now < 300)
165+
)

sdk/core/corehttp/samples/sample_async_pipeline_client.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,6 @@
1616
import asyncio
1717
from typing import Iterable, Union
1818

19-
from corehttp.runtime import AsyncPipelineClient
20-
from corehttp.rest import HttpRequest, AsyncHttpResponse
21-
from corehttp.runtime.policies import (
22-
AsyncHTTPPolicy,
23-
SansIOHTTPPolicy,
24-
HeadersPolicy,
25-
UserAgentPolicy,
26-
AsyncRetryPolicy,
27-
)
28-
2919

3020
async def sample_pipeline_client():
3121
# [START build_async_pipeline_client]

0 commit comments

Comments
 (0)