Skip to content

Support for http request injection propagation to tools #816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from mcp.server.stdio import stdio_server
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import (
AnyFunction,
EmbeddedResource,
Expand Down Expand Up @@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
) -> Callable[
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context

Expand Down Expand Up @@ -260,7 +262,7 @@ async def list_tools(self) -> list[MCPTool]:
for info in tools
]

def get_context(self) -> Context[ServerSession, object]:
def get_context(self) -> Context[ServerSession, object, Request]:
"""
Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error.
Expand Down Expand Up @@ -893,7 +895,7 @@ def _convert_to_content(
return [TextContent(type="text", text=result)]


class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
"""Context object providing access to MCP capabilities.

This provides a cleaner interface to MCP's RequestContext functionality.
Expand Down Expand Up @@ -927,13 +929,15 @@ def my_tool(x: int, ctx: Context) -> str:
The context is optional - tools that don't need it can omit the parameter.
"""

_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
_fastmcp: FastMCP | None

def __init__(
self,
*,
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
request_context: (
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
) = None,
fastmcp: FastMCP | None = None,
**kwargs: Any,
):
Expand All @@ -949,7 +953,9 @@ def fastmcp(self) -> FastMCP:
return self._fastmcp

@property
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
def request_context(
self,
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
"""Access to the underlying request context."""
if self._request_context is None:
raise ValueError("Context is not available outside of a request")
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT


class Tool(BaseModel):
Expand Down Expand Up @@ -85,7 +85,7 @@ def from_function(
async def run(
self,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any:
"""Run the tool with arguments."""
try:
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/server/fastmcp/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.shared.context import LifespanContextT
from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import ToolAnnotations

if TYPE_CHECKING:
Expand Down Expand Up @@ -65,7 +65,7 @@ async def call_tool(
self,
name: str,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Any:
"""Call a tool by name with arguments."""
tool = self.get_tool(name)
Expand Down
27 changes: 20 additions & 7 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ async def main():
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, Generic, TypeVar
from typing import Any, Generic

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from typing_extensions import TypeVar

import mcp.types as types
from mcp.server.lowlevel.helper_types import ReadResourceContents
Expand All @@ -85,15 +86,16 @@ async def main():
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder

logger = logging.getLogger(__name__)

LifespanResultT = TypeVar("LifespanResultT")
RequestT = TypeVar("RequestT", default=Any)

# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
contextvars.ContextVar("request_ctx")
)

Expand All @@ -111,7 +113,7 @@ def __init__(


@asynccontextmanager
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.

Args:
Expand All @@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
yield {}


class Server(Generic[LifespanResultT]):
class Server(Generic[LifespanResultT, RequestT]):
def __init__(
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
[Server[LifespanResultT, RequestT]],
AbstractAsyncContextManager[LifespanResultT],
] = lifespan,
):
self.name = name
Expand Down Expand Up @@ -215,7 +218,9 @@ def get_capabilities(
)

@property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
def request_context(
self,
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()

Expand Down Expand Up @@ -555,6 +560,13 @@ async def _handle_request(

token = None
try:
# Extract request context from message metadata
request_data = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context

# Set our global state that can be retrieved via
# app.get_request_context()
token = request_ctx.set(
Expand All @@ -563,6 +575,7 @@ async def _handle_request(
message.request_meta,
session,
lifespan_context,
request=request_data,
)
)
response = await handler(req)
Expand Down
6 changes: 4 additions & 2 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -203,7 +203,9 @@ async def handle_post_message(
await writer.send(err)
return

session_message = SessionMessage(message)
# Pass the ASGI scope for framework-agnostic access to request data
metadata = ServerMessageMetadata(request_context=request)
session_message = SessionMessage(message, metadata=metadata)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:

def __init__(
self,
app: MCPServer[Any],
app: MCPServer[Any, Any],
event_store: EventStore | None = None,
json_response: bool = False,
stateless: bool = False,
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/shared/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
RequestT = TypeVar("RequestT", default=Any)


@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestT | None = None
2 changes: 2 additions & 0 deletions src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class ServerMessageMetadata:
"""Metadata specific to server messages."""

related_request_id: RequestId | None = None
# Request-specific context (e.g., headers, auth info)
request_context: object | None = None


MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
Expand Down
3 changes: 3 additions & 0 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def __init__(
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
message_metadata: MessageMetadata = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
Expand Down Expand Up @@ -364,6 +366,7 @@ async def _receive_loop(self) -> None:
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)

self._in_flight[responder.request_id] = responder
Expand Down
Loading
Loading