Skip to content

[Corehttp] Update credential classes #38346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions sdk/core/corehttp/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
# Release History

## 1.0.0b6 (Unreleased)

### Features Added

- The `TokenCredential` and `AsyncTokenCredential` protocols have been updated to include a new `get_token_info` method. This method should be used to acquire tokens and return an `AccessTokenInfo` object. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
- Added a new `TokenRequestOptions` class, which is a `TypedDict` with optional parameters, that can be used to define options for token requests through the `get_token_info` method. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
- Added a new `AccessTokenInfo` class, which is returned by `get_token_info` implementations. This class contains the token, its expiration time, and optional additional information like when a token should be refreshed. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
- These policies now also check the `refresh_on` attribute when determining if a new token request should be made.

### Breaking Changes

- The `get_token` method has been removed from the `TokenCredential` and `AsyncTokenCredential` protocols. Implementations should now use the new `get_token_info` method to acquire tokens. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
- The `AccessToken` class has been removed and replaced with a new `AccessTokenInfo` class. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now rely on credentials having the `get_token_info` method defined. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)

### Bugs Fixed

### Other Changes

## 1.0.0b5 (2024-02-29)

### Other Changes
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/corehttp/corehttp/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.0.0b5"
VERSION = "1.0.0b6"
76 changes: 54 additions & 22 deletions sdk/core/corehttp/corehttp/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,71 @@
# -------------------------------------------------------------------------
from __future__ import annotations
from types import TracebackType
from typing import Any, NamedTuple, Optional, AsyncContextManager, Type
from typing import NamedTuple, Optional, AsyncContextManager, Type, TypedDict, ContextManager
from typing_extensions import Protocol, runtime_checkable


class AccessToken(NamedTuple):
"""Represents an OAuth access token."""
class AccessTokenInfo:
"""Information about an OAuth access token.

:param str token: The token string.
:param int expires_on: The token's expiration time in Unix time.
:keyword str token_type: The type of access token. Defaults to 'Bearer'.
:keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively
refreshed. Optional.
"""

token: str
"""The token string."""
expires_on: int
"""The token's expiration time in Unix time."""
token_type: str
"""The type of access token."""
refresh_on: Optional[int]
"""Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional."""

def __init__(
self, token: str, expires_on: int, *, token_type: str = "Bearer", refresh_on: Optional[int] = None
) -> None:
self.token = token
self.expires_on = expires_on
self.token_type = token_type
self.refresh_on = refresh_on

def __repr__(self) -> str:
return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format(
self.token, self.expires_on, self.token_type, self.refresh_on
)

AccessToken.token.__doc__ = """The token string."""
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""

class TokenRequestOptions(TypedDict, total=False):
"""Options to use for access token requests. All parameters are optional."""

@runtime_checkable
class TokenCredential(Protocol):
"""Protocol for classes able to provide OAuth tokens."""
claims: str
"""Additional claims required in the token, such as those returned in a resource provider's claims
challenge following an authorization failure."""
tenant_id: str
"""The tenant ID to include in the token request."""

def get_token(self, *scopes: str, claims: Optional[str] = None, **kwargs: Any) -> AccessToken:
"""Request an access token for `scopes`.

:param str scopes: The type of access needed.
class TokenCredential(Protocol, ContextManager["TokenCredential"]):
"""Protocol for classes able to provide OAuth access tokens."""

:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
"""Request an access token for `scopes`.

:param str scopes: The type of access needed.
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
:paramtype options: TokenRequestOptions

:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
:rtype: AccessTokenInfo
:return: An AccessTokenInfo instance containing information about the token.
"""
...

def close(self) -> None:
pass


class ServiceNamedKey(NamedTuple):
"""Represents a name and key pair."""
Expand All @@ -47,10 +79,11 @@ class ServiceNamedKey(NamedTuple):


__all__ = [
"AccessToken",
"AccessTokenInfo",
"ServiceKeyCredential",
"ServiceNamedKeyCredential",
"TokenCredential",
"TokenRequestOptions",
"AsyncTokenCredential",
]

Expand Down Expand Up @@ -134,16 +167,15 @@ def update(self, name: str, key: str) -> None:
class AsyncTokenCredential(Protocol, AsyncContextManager["AsyncTokenCredential"]):
"""Protocol for classes able to provide OAuth tokens."""

async def get_token(self, *scopes: str, claims: Optional[str] = None, **kwargs: Any) -> AccessToken:
async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
"""Request an access token for `scopes`.

:param str scopes: The type of access needed.
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
:paramtype options: TokenRequestOptions

:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.

:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
:rtype: AccessTokenInfo
:return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time.
"""
...

Expand Down
21 changes: 16 additions & 5 deletions sdk/core/corehttp/corehttp/runtime/policies/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import time
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any

from ...credentials import TokenRequestOptions
from ...rest import HttpResponse, HttpRequest
from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError

if TYPE_CHECKING:

from ...credentials import (
AccessToken,
AccessTokenInfo,
TokenCredential,
ServiceKeyCredential,
)
Expand All @@ -39,7 +40,7 @@ def __init__(
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessToken"] = None
self._token: Optional["AccessTokenInfo"] = None

@staticmethod
def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
Expand Down Expand Up @@ -68,7 +69,12 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:

@property
def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or (self._token.expires_on - now < 300)
)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand All @@ -90,7 +96,7 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
self._enforce_https(request)

if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._token = self._credential.get_token_info(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
Expand All @@ -102,7 +108,12 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
:param ~corehttp.runtime.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
self._token = self._credential.get_token(*scopes, **kwargs)
options: TokenRequestOptions = {}
# Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
for key in list(kwargs.keys()):
if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]
self._token = self._credential.get_token_info(*scopes, options=options)
self._update_headers(request.http_request.headers, self._token.token)

def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar

from ...credentials import AccessToken
from ...credentials import AccessTokenInfo, TokenRequestOptions
from ..pipeline import PipelineRequest, PipelineResponse
from ..pipeline._tools_async import await_result
from ._base_async import AsyncHTTPPolicy
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
self._credential = credential
self._lock_instance = None
self._scopes = scopes
self._token: Optional["AccessToken"] = None
self._token: Optional[AccessTokenInfo] = None

@property
def _lock(self):
Expand All @@ -55,12 +55,12 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
"""
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access

if self._token is None or self._need_new_token():
if self._token is None or self._need_new_token:
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token():
self._token = await await_result(self._credential.get_token, *self._scopes)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
if self._token is None or self._need_new_token:
self._token = await await_result(self._credential.get_token_info, *self._scopes)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token

async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.
Expand All @@ -71,9 +71,15 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
:param ~corehttp.runtime.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
options: TokenRequestOptions = {}
# Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
for key in list(kwargs.keys()):
if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]

async with self._lock:
self._token = await await_result(self._credential.get_token, *scopes, **kwargs)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
self._token = await await_result(self._credential.get_token_info, *scopes, options=options)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token

async def send(
self, request: PipelineRequest[HTTPRequestType]
Expand Down Expand Up @@ -149,5 +155,11 @@ def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
# pylint: disable=unused-argument
return

@property
def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or (self._token.expires_on - now < 300)
)
10 changes: 0 additions & 10 deletions sdk/core/corehttp/samples/sample_async_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,6 @@
import asyncio
from typing import Iterable, Union

from corehttp.runtime import AsyncPipelineClient
from corehttp.rest import HttpRequest, AsyncHttpResponse
from corehttp.runtime.policies import (
AsyncHTTPPolicy,
SansIOHTTPPolicy,
HeadersPolicy,
UserAgentPolicy,
AsyncRetryPolicy,
)


async def sample_pipeline_client():
# [START build_async_pipeline_client]
Expand Down
Loading