Skip to content

Add OAuth client credentials grant #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ async def main():
The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers:

```python
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.auth import OAuthClientProvider, ClientCredentialsProvider, TokenStorage
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
Expand Down Expand Up @@ -851,6 +851,9 @@ async def main():
callback_handler=lambda: ("auth_code", None),
)

# For machine-to-machine scenarios, use ClientCredentialsProvider
# instead of OAuthClientProvider.

# Use with streamable HTTP client
async with streamablehttp_client(
"https://api.example.com/mcp", auth=oauth_auth
Expand Down
204 changes: 204 additions & 0 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,207 @@ async def _refresh_access_token(self) -> bool:
except Exception:
logger.exception("Token refresh failed")
return False


class ClientCredentialsProvider(httpx.Auth):
"""HTTPX auth using the OAuth2 client credentials grant."""

def __init__(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
timeout: float = 300.0,
):
self.server_url = server_url
self.client_metadata = client_metadata
self.storage = storage
self.timeout = timeout

self._current_tokens: OAuthToken | None = None
self._metadata: OAuthMetadata | None = None
self._client_info: OAuthClientInformationFull | None = None
self._token_expiry_time: float | None = None

self._token_lock = anyio.Lock()

def _get_authorization_base_url(self, server_url: str) -> str:
from urllib.parse import urlparse, urlunparse

parsed = urlparse(server_url)
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))

async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
auth_base_url = self._get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}

async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code == 404:
return None
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except Exception:
try:
response = await client.get(url)
if response.status_code == 404:
return None
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except Exception:
logger.exception("Failed to discover OAuth metadata")
return None

async def _register_oauth_client(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
metadata: OAuthMetadata | None = None,
) -> OAuthClientInformationFull:
if not metadata:
metadata = await self._discover_oauth_metadata(server_url)

if metadata and metadata.registration_endpoint:
registration_url = str(metadata.registration_endpoint)
else:
auth_base_url = self._get_authorization_base_url(server_url)
registration_url = urljoin(auth_base_url, "/register")

if (
client_metadata.scope is None
and metadata
and metadata.scopes_supported is not None
):
client_metadata.scope = " ".join(metadata.scopes_supported)

registration_data = client_metadata.model_dump(
by_alias=True, mode="json", exclude_none=True
)

async with httpx.AsyncClient() as client:
response = await client.post(
registration_url,
json=registration_data,
headers={"Content-Type": "application/json"},
)

if response.status_code not in (200, 201):
raise httpx.HTTPStatusError(
f"Registration failed: {response.status_code}",
request=response.request,
response=response,
)

return OAuthClientInformationFull.model_validate(response.json())

def _has_valid_token(self) -> bool:
if not self._current_tokens or not self._current_tokens.access_token:
return False

if self._token_expiry_time and time.time() > self._token_expiry_time:
return False
return True

async def _validate_token_scopes(self, token_response: OAuthToken) -> None:
if not token_response.scope:
return

requested_scopes: set[str] = set()
if self.client_metadata.scope:
requested_scopes = set(self.client_metadata.scope.split())
returned_scopes = set(token_response.scope.split())
unauthorized_scopes = returned_scopes - requested_scopes
if unauthorized_scopes:
raise Exception(
f"Server granted unauthorized scopes: {unauthorized_scopes}."
)
else:
granted = set(token_response.scope.split())
logger.debug(
"No explicit scopes requested, accepting server-granted scopes: %s",
granted,
)

async def initialize(self) -> None:
self._current_tokens = await self.storage.get_tokens()
self._client_info = await self.storage.get_client_info()

async def _get_or_register_client(self) -> OAuthClientInformationFull:
if not self._client_info:
self._client_info = await self._register_oauth_client(
self.server_url, self.client_metadata, self._metadata
)
await self.storage.set_client_info(self._client_info)
return self._client_info

async def _request_token(self) -> None:
if not self._metadata:
self._metadata = await self._discover_oauth_metadata(self.server_url)

client_info = await self._get_or_register_client()

if self._metadata and self._metadata.token_endpoint:
token_url = str(self._metadata.token_endpoint)
else:
auth_base_url = self._get_authorization_base_url(self.server_url)
token_url = urljoin(auth_base_url, "/token")

token_data = {
"grant_type": "client_credentials",
"client_id": client_info.client_id,
}

if client_info.client_secret:
token_data["client_secret"] = client_info.client_secret

if self.client_metadata.scope:
token_data["scope"] = self.client_metadata.scope

async with httpx.AsyncClient() as client:
response = await client.post(
token_url,
data=token_data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
timeout=30.0,
)

if response.status_code != 200:
raise Exception(
f"Token request failed: {response.status_code} {response.text}"
)

token_response = OAuthToken.model_validate(response.json())
await self._validate_token_scopes(token_response)

if token_response.expires_in:
self._token_expiry_time = time.time() + token_response.expires_in
else:
self._token_expiry_time = None

await self.storage.set_tokens(token_response)
self._current_tokens = token_response

async def ensure_token(self) -> None:
async with self._token_lock:
if self._has_valid_token():
return
await self._request_token()

async def async_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
if not self._has_valid_token():
await self.initialize()
await self.ensure_token()

if self._current_tokens and self._current_tokens.access_token:
request.headers["Authorization"] = (
f"Bearer {self._current_tokens.access_token}"
)

response = yield request

if response.status_code == 401:
self._current_tokens = None
33 changes: 31 additions & 2 deletions src/mcp/server/auth/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,25 @@ class RefreshTokenRequest(BaseModel):
client_secret: str | None = None


class ClientCredentialsRequest(BaseModel):
"""Token request for the client credentials grant."""

grant_type: Literal["client_credentials"]
scope: str | None = Field(None, description="Optional scope parameter")
client_id: str
client_secret: str | None = None


class TokenRequest(
RootModel[
Annotated[
AuthorizationCodeRequest | RefreshTokenRequest,
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest,
Field(discriminator="grant_type"),
]
]
):
root: Annotated[
AuthorizationCodeRequest | RefreshTokenRequest,
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest,
Field(discriminator="grant_type"),
]

Expand Down Expand Up @@ -204,6 +213,26 @@ async def handle(self, request: Request):
)
)

case ClientCredentialsRequest():
scopes = (
token_request.scope.split(" ")
if token_request.scope
else client_info.scope.split(" ")
if client_info.scope
else []
)
try:
tokens = await self.provider.exchange_client_credentials(
client_info, scopes
)
except TokenError as e:
return self.response(
TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)
)

case RefreshTokenRequest():
refresh_token = await self.provider.load_refresh_token(
client_info, token_request.refresh_token
Expand Down
6 changes: 6 additions & 0 deletions src/mcp/server/auth/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ async def exchange_refresh_token(
"""
...

async def exchange_client_credentials(
self, client: OAuthClientInformationFull, scopes: list[str]
) -> OAuthToken:
"""Exchange client credentials for an access token."""
...

async def load_access_token(self, token: str) -> AccessTokenT | None:
"""
Loads an access token by its token.
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/server/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ def build_metadata(
scopes_supported=client_registration_options.valid_scopes,
response_types_supported=["code"],
response_modes_supported=None,
grant_types_supported=["authorization_code", "refresh_token"],
grant_types_supported=[
"authorization_code",
"refresh_token",
"client_credentials",
],
token_endpoint_auth_methods_supported=["client_secret_post"],
token_endpoint_auth_signing_alg_values_supported=None,
service_documentation=service_documentation_url,
Expand Down
15 changes: 12 additions & 3 deletions src/mcp/shared/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ class OAuthClientMetadata(BaseModel):
token_endpoint_auth_method: Literal["none", "client_secret_post"] = (
"client_secret_post"
)
# grant_types: this implementation only supports authorization_code & refresh_token
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
# grant_types: support authorization_code, refresh_token, client_credentials
grant_types: list[
Literal["authorization_code", "refresh_token", "client_credentials"]
] = [
"authorization_code",
"refresh_token",
]
Expand Down Expand Up @@ -114,7 +116,14 @@ class OAuthMetadata(BaseModel):
response_types_supported: list[Literal["code"]] = ["code"]
response_modes_supported: list[Literal["query", "fragment"]] | None = None
grant_types_supported: (
list[Literal["authorization_code", "refresh_token"]] | None
list[
Literal[
"authorization_code",
"refresh_token",
"client_credentials",
]
]
| None
) = None
token_endpoint_auth_methods_supported: (
list[Literal["none", "client_secret_post"]] | None
Expand Down
Loading