Skip to content

Refactor auth helper methods #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 48 additions & 85 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,44 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
...


def _get_authorization_base_url(server_url: str) -> str:
"""Return the authorization base URL for ``server_url``.

Per MCP spec 2.3.2, the path component must be discarded so that
``https://api.example.com/v1/mcp`` becomes ``https://api.example.com``.
"""
from urllib.parse import urlparse, urlunparse

parsed = urlparse(server_url)
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))


async def _discover_oauth_metadata(server_url: str) -> OAuthMetadata | None:
"""Discover OAuth metadata from the server's well-known endpoint."""

auth_base_url = _get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}

async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code == 404:
return None
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except Exception:
try:
response = await client.get(url)
if response.status_code == 404:
return None
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except Exception:
logger.exception("Failed to discover OAuth metadata")
return None


class OAuthClientProvider(httpx.Auth):
"""
Authentication for httpx using anyio.
Expand Down Expand Up @@ -110,52 +148,6 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
digest = hashlib.sha256(code_verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")

def _get_authorization_base_url(self, server_url: str) -> str:
"""
Extract base URL by removing path component.

Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com
"""
from urllib.parse import urlparse, urlunparse

parsed = urlparse(server_url)
# Remove path component
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))

async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
"""
Discover OAuth metadata from server's well-known endpoint.
"""
# Extract base URL per MCP spec
auth_base_url = self._get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}

async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code == 404:
return None
response.raise_for_status()
metadata_json = response.json()
logger.debug(f"OAuth metadata discovered: {metadata_json}")
return OAuthMetadata.model_validate(metadata_json)
except Exception:
# Retry without MCP header for CORS compatibility
try:
response = await client.get(url)
if response.status_code == 404:
return None
response.raise_for_status()
metadata_json = response.json()
logger.debug(
f"OAuth metadata discovered (no MCP header): {metadata_json}"
)
return OAuthMetadata.model_validate(metadata_json)
except Exception:
logger.exception("Failed to discover OAuth metadata")
return None

async def _register_oauth_client(
self,
server_url: str,
Expand All @@ -166,13 +158,13 @@ async def _register_oauth_client(
Register OAuth client with server.
"""
if not metadata:
metadata = await self._discover_oauth_metadata(server_url)
metadata = await _discover_oauth_metadata(server_url)

if metadata and metadata.registration_endpoint:
registration_url = str(metadata.registration_endpoint)
else:
# Use fallback registration endpoint
auth_base_url = self._get_authorization_base_url(server_url)
auth_base_url = _get_authorization_base_url(server_url)
registration_url = urljoin(auth_base_url, "/register")

# Handle default scope
Expand Down Expand Up @@ -321,7 +313,7 @@ async def _perform_oauth_flow(self) -> None:

# Discover OAuth metadata
if not self._metadata:
self._metadata = await self._discover_oauth_metadata(self.server_url)
self._metadata = await _discover_oauth_metadata(self.server_url)

# Ensure client registration
client_info = await self._get_or_register_client()
Expand All @@ -335,7 +327,7 @@ async def _perform_oauth_flow(self) -> None:
auth_url_base = str(self._metadata.authorization_endpoint)
else:
# Use fallback authorization endpoint
auth_base_url = self._get_authorization_base_url(self.server_url)
auth_base_url = _get_authorization_base_url(self.server_url)
auth_url_base = urljoin(auth_base_url, "/authorize")

# Build authorization URL
Expand Down Expand Up @@ -386,7 +378,7 @@ async def _exchange_code_for_token(
token_url = str(self._metadata.token_endpoint)
else:
# Use fallback token endpoint
auth_base_url = self._get_authorization_base_url(self.server_url)
auth_base_url = _get_authorization_base_url(self.server_url)
token_url = urljoin(auth_base_url, "/token")

token_data = {
Expand Down Expand Up @@ -453,7 +445,7 @@ async def _refresh_access_token(self) -> bool:
token_url = str(self._metadata.token_endpoint)
else:
# Use fallback token endpoint
auth_base_url = self._get_authorization_base_url(self.server_url)
auth_base_url = _get_authorization_base_url(self.server_url)
token_url = urljoin(auth_base_url, "/token")

refresh_data = {
Expand Down Expand Up @@ -523,48 +515,19 @@ def __init__(

self._token_lock = anyio.Lock()

def _get_authorization_base_url(self, server_url: str) -> str:
from urllib.parse import urlparse, urlunparse

parsed = urlparse(server_url)
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))

async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
auth_base_url = self._get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}

async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code == 404:
return None
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except Exception:
try:
response = await client.get(url)
if response.status_code == 404:
return None
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except Exception:
logger.exception("Failed to discover OAuth metadata")
return None

async def _register_oauth_client(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
metadata: OAuthMetadata | None = None,
) -> OAuthClientInformationFull:
if not metadata:
metadata = await self._discover_oauth_metadata(server_url)
metadata = await _discover_oauth_metadata(server_url)

if metadata and metadata.registration_endpoint:
registration_url = str(metadata.registration_endpoint)
else:
auth_base_url = self._get_authorization_base_url(server_url)
auth_base_url = _get_authorization_base_url(server_url)
registration_url = urljoin(auth_base_url, "/register")

if (
Expand Down Expand Up @@ -636,14 +599,14 @@ async def _get_or_register_client(self) -> OAuthClientInformationFull:

async def _request_token(self) -> None:
if not self._metadata:
self._metadata = await self._discover_oauth_metadata(self.server_url)
self._metadata = await _discover_oauth_metadata(self.server_url)

client_info = await self._get_or_register_client()

if self._metadata and self._metadata.token_endpoint:
token_url = str(self._metadata.token_endpoint)
else:
auth_base_url = self._get_authorization_base_url(self.server_url)
auth_base_url = _get_authorization_base_url(self.server_url)
token_url = urljoin(auth_base_url, "/token")

token_data = {
Expand Down