Skip to content

Commit 66c7e67

Browse files
authored
Merge pull request #1 from sacha-development-stuff/codex/add-support-for-oauth-client-credentials
Add OAuth client credentials grant
2 parents 5441767 + 833a105 commit 66c7e67

File tree

8 files changed

+386
-10
lines changed

8 files changed

+386
-10
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ async def main():
814814
The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers:
815815

816816
```python
817-
from mcp.client.auth import OAuthClientProvider, TokenStorage
817+
from mcp.client.auth import OAuthClientProvider, ClientCredentialsProvider, TokenStorage
818818
from mcp.client.session import ClientSession
819819
from mcp.client.streamable_http import streamablehttp_client
820820
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
@@ -851,6 +851,9 @@ async def main():
851851
callback_handler=lambda: ("auth_code", None),
852852
)
853853

854+
# For machine-to-machine scenarios, use ClientCredentialsProvider
855+
# instead of OAuthClientProvider.
856+
854857
# Use with streamable HTTP client
855858
async with streamablehttp_client(
856859
"https://api.example.com/mcp", auth=oauth_auth

src/mcp/client/auth.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,207 @@ async def _refresh_access_token(self) -> bool:
499499
except Exception:
500500
logger.exception("Token refresh failed")
501501
return False
502+
503+
504+
class ClientCredentialsProvider(httpx.Auth):
505+
"""HTTPX auth using the OAuth2 client credentials grant."""
506+
507+
def __init__(
508+
self,
509+
server_url: str,
510+
client_metadata: OAuthClientMetadata,
511+
storage: TokenStorage,
512+
timeout: float = 300.0,
513+
):
514+
self.server_url = server_url
515+
self.client_metadata = client_metadata
516+
self.storage = storage
517+
self.timeout = timeout
518+
519+
self._current_tokens: OAuthToken | None = None
520+
self._metadata: OAuthMetadata | None = None
521+
self._client_info: OAuthClientInformationFull | None = None
522+
self._token_expiry_time: float | None = None
523+
524+
self._token_lock = anyio.Lock()
525+
526+
def _get_authorization_base_url(self, server_url: str) -> str:
527+
from urllib.parse import urlparse, urlunparse
528+
529+
parsed = urlparse(server_url)
530+
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
531+
532+
async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
533+
auth_base_url = self._get_authorization_base_url(server_url)
534+
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
535+
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
536+
537+
async with httpx.AsyncClient() as client:
538+
try:
539+
response = await client.get(url, headers=headers)
540+
if response.status_code == 404:
541+
return None
542+
response.raise_for_status()
543+
return OAuthMetadata.model_validate(response.json())
544+
except Exception:
545+
try:
546+
response = await client.get(url)
547+
if response.status_code == 404:
548+
return None
549+
response.raise_for_status()
550+
return OAuthMetadata.model_validate(response.json())
551+
except Exception:
552+
logger.exception("Failed to discover OAuth metadata")
553+
return None
554+
555+
async def _register_oauth_client(
556+
self,
557+
server_url: str,
558+
client_metadata: OAuthClientMetadata,
559+
metadata: OAuthMetadata | None = None,
560+
) -> OAuthClientInformationFull:
561+
if not metadata:
562+
metadata = await self._discover_oauth_metadata(server_url)
563+
564+
if metadata and metadata.registration_endpoint:
565+
registration_url = str(metadata.registration_endpoint)
566+
else:
567+
auth_base_url = self._get_authorization_base_url(server_url)
568+
registration_url = urljoin(auth_base_url, "/register")
569+
570+
if (
571+
client_metadata.scope is None
572+
and metadata
573+
and metadata.scopes_supported is not None
574+
):
575+
client_metadata.scope = " ".join(metadata.scopes_supported)
576+
577+
registration_data = client_metadata.model_dump(
578+
by_alias=True, mode="json", exclude_none=True
579+
)
580+
581+
async with httpx.AsyncClient() as client:
582+
response = await client.post(
583+
registration_url,
584+
json=registration_data,
585+
headers={"Content-Type": "application/json"},
586+
)
587+
588+
if response.status_code not in (200, 201):
589+
raise httpx.HTTPStatusError(
590+
f"Registration failed: {response.status_code}",
591+
request=response.request,
592+
response=response,
593+
)
594+
595+
return OAuthClientInformationFull.model_validate(response.json())
596+
597+
def _has_valid_token(self) -> bool:
598+
if not self._current_tokens or not self._current_tokens.access_token:
599+
return False
600+
601+
if self._token_expiry_time and time.time() > self._token_expiry_time:
602+
return False
603+
return True
604+
605+
async def _validate_token_scopes(self, token_response: OAuthToken) -> None:
606+
if not token_response.scope:
607+
return
608+
609+
requested_scopes: set[str] = set()
610+
if self.client_metadata.scope:
611+
requested_scopes = set(self.client_metadata.scope.split())
612+
returned_scopes = set(token_response.scope.split())
613+
unauthorized_scopes = returned_scopes - requested_scopes
614+
if unauthorized_scopes:
615+
raise Exception(
616+
f"Server granted unauthorized scopes: {unauthorized_scopes}."
617+
)
618+
else:
619+
granted = set(token_response.scope.split())
620+
logger.debug(
621+
"No explicit scopes requested, accepting server-granted scopes: %s",
622+
granted,
623+
)
624+
625+
async def initialize(self) -> None:
626+
self._current_tokens = await self.storage.get_tokens()
627+
self._client_info = await self.storage.get_client_info()
628+
629+
async def _get_or_register_client(self) -> OAuthClientInformationFull:
630+
if not self._client_info:
631+
self._client_info = await self._register_oauth_client(
632+
self.server_url, self.client_metadata, self._metadata
633+
)
634+
await self.storage.set_client_info(self._client_info)
635+
return self._client_info
636+
637+
async def _request_token(self) -> None:
638+
if not self._metadata:
639+
self._metadata = await self._discover_oauth_metadata(self.server_url)
640+
641+
client_info = await self._get_or_register_client()
642+
643+
if self._metadata and self._metadata.token_endpoint:
644+
token_url = str(self._metadata.token_endpoint)
645+
else:
646+
auth_base_url = self._get_authorization_base_url(self.server_url)
647+
token_url = urljoin(auth_base_url, "/token")
648+
649+
token_data = {
650+
"grant_type": "client_credentials",
651+
"client_id": client_info.client_id,
652+
}
653+
654+
if client_info.client_secret:
655+
token_data["client_secret"] = client_info.client_secret
656+
657+
if self.client_metadata.scope:
658+
token_data["scope"] = self.client_metadata.scope
659+
660+
async with httpx.AsyncClient() as client:
661+
response = await client.post(
662+
token_url,
663+
data=token_data,
664+
headers={"Content-Type": "application/x-www-form-urlencoded"},
665+
timeout=30.0,
666+
)
667+
668+
if response.status_code != 200:
669+
raise Exception(
670+
f"Token request failed: {response.status_code} {response.text}"
671+
)
672+
673+
token_response = OAuthToken.model_validate(response.json())
674+
await self._validate_token_scopes(token_response)
675+
676+
if token_response.expires_in:
677+
self._token_expiry_time = time.time() + token_response.expires_in
678+
else:
679+
self._token_expiry_time = None
680+
681+
await self.storage.set_tokens(token_response)
682+
self._current_tokens = token_response
683+
684+
async def ensure_token(self) -> None:
685+
async with self._token_lock:
686+
if self._has_valid_token():
687+
return
688+
await self._request_token()
689+
690+
async def async_auth_flow(
691+
self, request: httpx.Request
692+
) -> AsyncGenerator[httpx.Request, httpx.Response]:
693+
if not self._has_valid_token():
694+
await self.initialize()
695+
await self.ensure_token()
696+
697+
if self._current_tokens and self._current_tokens.access_token:
698+
request.headers["Authorization"] = (
699+
f"Bearer {self._current_tokens.access_token}"
700+
)
701+
702+
response = yield request
703+
704+
if response.status_code == 401:
705+
self._current_tokens = None

src/mcp/server/auth/handlers/token.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,25 @@ class RefreshTokenRequest(BaseModel):
4747
client_secret: str | None = None
4848

4949

50+
class ClientCredentialsRequest(BaseModel):
51+
"""Token request for the client credentials grant."""
52+
53+
grant_type: Literal["client_credentials"]
54+
scope: str | None = Field(None, description="Optional scope parameter")
55+
client_id: str
56+
client_secret: str | None = None
57+
58+
5059
class TokenRequest(
5160
RootModel[
5261
Annotated[
53-
AuthorizationCodeRequest | RefreshTokenRequest,
62+
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest,
5463
Field(discriminator="grant_type"),
5564
]
5665
]
5766
):
5867
root: Annotated[
59-
AuthorizationCodeRequest | RefreshTokenRequest,
68+
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest,
6069
Field(discriminator="grant_type"),
6170
]
6271

@@ -204,6 +213,26 @@ async def handle(self, request: Request):
204213
)
205214
)
206215

216+
case ClientCredentialsRequest():
217+
scopes = (
218+
token_request.scope.split(" ")
219+
if token_request.scope
220+
else client_info.scope.split(" ")
221+
if client_info.scope
222+
else []
223+
)
224+
try:
225+
tokens = await self.provider.exchange_client_credentials(
226+
client_info, scopes
227+
)
228+
except TokenError as e:
229+
return self.response(
230+
TokenErrorResponse(
231+
error=e.error,
232+
error_description=e.error_description,
233+
)
234+
)
235+
207236
case RefreshTokenRequest():
208237
refresh_token = await self.provider.load_refresh_token(
209238
client_info, token_request.refresh_token

src/mcp/server/auth/provider.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ async def exchange_refresh_token(
247247
"""
248248
...
249249

250+
async def exchange_client_credentials(
251+
self, client: OAuthClientInformationFull, scopes: list[str]
252+
) -> OAuthToken:
253+
"""Exchange client credentials for an access token."""
254+
...
255+
250256
async def load_access_token(self, token: str) -> AccessTokenT | None:
251257
"""
252258
Loads an access token by its token.

src/mcp/server/auth/routes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ def build_metadata(
164164
scopes_supported=client_registration_options.valid_scopes,
165165
response_types_supported=["code"],
166166
response_modes_supported=None,
167-
grant_types_supported=["authorization_code", "refresh_token"],
167+
grant_types_supported=[
168+
"authorization_code",
169+
"refresh_token",
170+
"client_credentials",
171+
],
168172
token_endpoint_auth_methods_supported=["client_secret_post"],
169173
token_endpoint_auth_signing_alg_values_supported=None,
170174
service_documentation=service_documentation_url,

src/mcp/shared/auth.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ class OAuthClientMetadata(BaseModel):
3939
token_endpoint_auth_method: Literal["none", "client_secret_post"] = (
4040
"client_secret_post"
4141
)
42-
# grant_types: this implementation only supports authorization_code & refresh_token
43-
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
42+
# grant_types: support authorization_code, refresh_token, client_credentials
43+
grant_types: list[
44+
Literal["authorization_code", "refresh_token", "client_credentials"]
45+
] = [
4446
"authorization_code",
4547
"refresh_token",
4648
]
@@ -114,7 +116,14 @@ class OAuthMetadata(BaseModel):
114116
response_types_supported: list[Literal["code"]] = ["code"]
115117
response_modes_supported: list[Literal["query", "fragment"]] | None = None
116118
grant_types_supported: (
117-
list[Literal["authorization_code", "refresh_token"]] | None
119+
list[
120+
Literal[
121+
"authorization_code",
122+
"refresh_token",
123+
"client_credentials",
124+
]
125+
]
126+
| None
118127
) = None
119128
token_endpoint_auth_methods_supported: (
120129
list[Literal["none", "client_secret_post"]] | None

0 commit comments

Comments
 (0)