Skip to content

Commit a5e7efe

Browse files
committed
generic
1 parent 0612dcb commit a5e7efe

File tree

12 files changed

+54
-51
lines changed

12 files changed

+54
-51
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from mcp.server.stdio import stdio_server
5050
from mcp.server.streamable_http import EventStore
5151
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
52-
from mcp.shared.context import LifespanContextT, RequestContext
52+
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
5353
from mcp.types import (
5454
AnyFunction,
5555
EmbeddedResource,
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
124124
def lifespan_wrapper(
125125
app: FastMCP,
126126
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
127-
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
127+
) -> Callable[
128+
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
129+
]:
128130
@asynccontextmanager
129-
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
131+
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
130132
async with lifespan(app) as context:
131133
yield context
132134

@@ -147,7 +149,7 @@ def __init__(
147149
):
148150
self.settings = Settings(**settings)
149151

150-
self._mcp_server = MCPServer(
152+
self._mcp_server: MCPServer[object, Request] = MCPServer(
151153
name=name or "FastMCP",
152154
instructions=instructions,
153155
lifespan=(
@@ -260,7 +262,7 @@ async def list_tools(self) -> list[MCPTool]:
260262
for info in tools
261263
]
262264

263-
def get_context(self) -> Context[ServerSession, object]:
265+
def get_context(self) -> Context[ServerSession, object, Request]:
264266
"""
265267
Returns a Context object. Note that the context will only be valid
266268
during a request; outside a request, most methods will error.
@@ -893,7 +895,7 @@ def _convert_to_content(
893895
return [TextContent(type="text", text=result)]
894896

895897

896-
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
898+
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
897899
"""Context object providing access to MCP capabilities.
898900
899901
This provides a cleaner interface to MCP's RequestContext functionality.
@@ -927,13 +929,15 @@ def my_tool(x: int, ctx: Context) -> str:
927929
The context is optional - tools that don't need it can omit the parameter.
928930
"""
929931

930-
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
932+
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
931933
_fastmcp: FastMCP | None
932934

933935
def __init__(
934936
self,
935937
*,
936-
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
938+
request_context: (
939+
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
940+
) = None,
937941
fastmcp: FastMCP | None = None,
938942
**kwargs: Any,
939943
):
@@ -949,7 +953,9 @@ def fastmcp(self) -> FastMCP:
949953
return self._fastmcp
950954

951955
@property
952-
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
956+
def request_context(
957+
self,
958+
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
953959
"""Access to the underlying request context."""
954960
if self._request_context is None:
955961
raise ValueError("Context is not available outside of a request")

src/mcp/server/fastmcp/tools/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
if TYPE_CHECKING:
1515
from mcp.server.fastmcp.server import Context
1616
from mcp.server.session import ServerSessionT
17-
from mcp.shared.context import LifespanContextT
17+
from mcp.shared.context import LifespanContextT, RequestT
1818

1919

2020
class Tool(BaseModel):
@@ -85,7 +85,7 @@ def from_function(
8585
async def run(
8686
self,
8787
arguments: dict[str, Any],
88-
context: Context[ServerSessionT, LifespanContextT] | None = None,
88+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
8989
) -> Any:
9090
"""Run the tool with arguments."""
9191
try:

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mcp.server.fastmcp.exceptions import ToolError
77
from mcp.server.fastmcp.tools.base import Tool
88
from mcp.server.fastmcp.utilities.logging import get_logger
9-
from mcp.shared.context import LifespanContextT
9+
from mcp.shared.context import LifespanContextT, RequestT
1010
from mcp.types import ToolAnnotations
1111

1212
if TYPE_CHECKING:
@@ -65,7 +65,7 @@ async def call_tool(
6565
self,
6666
name: str,
6767
arguments: dict[str, Any],
68-
context: Context[ServerSessionT, LifespanContextT] | None = None,
68+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
6969
) -> Any:
7070
"""Call a tool by name with arguments."""
7171
tool = self.get_tool(name)

src/mcp/server/lowlevel/server.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ async def main():
7272
import warnings
7373
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
7474
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
75-
from typing import Any, Generic, TypeVar
75+
from typing import Any, Generic
7676

7777
import anyio
7878
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7979
from pydantic import AnyUrl
80+
from typing_extensions import TypeVar
8081

8182
import mcp.types as types
8283
from mcp.server.lowlevel.helper_types import ReadResourceContents
@@ -91,9 +92,10 @@ async def main():
9192
logger = logging.getLogger(__name__)
9293

9394
LifespanResultT = TypeVar("LifespanResultT")
95+
RequestT = TypeVar("RequestT", default=Any)
9496

9597
# This will be properly typed in each Server instance's context
96-
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
98+
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
9799
contextvars.ContextVar("request_ctx")
98100
)
99101

@@ -111,7 +113,7 @@ def __init__(
111113

112114

113115
@asynccontextmanager
114-
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
116+
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
115117
"""Default lifespan context manager that does nothing.
116118
117119
Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
123125
yield {}
124126

125127

126-
class Server(Generic[LifespanResultT]):
128+
class Server(Generic[LifespanResultT, RequestT]):
127129
def __init__(
128130
self,
129131
name: str,
130132
version: str | None = None,
131133
instructions: str | None = None,
132134
lifespan: Callable[
133-
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
135+
[Server[LifespanResultT, RequestT]],
136+
AbstractAsyncContextManager[LifespanResultT],
134137
] = lifespan,
135138
):
136139
self.name = name
@@ -215,7 +218,9 @@ def get_capabilities(
215218
)
216219

217220
@property
218-
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
221+
def request_context(
222+
self,
223+
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
219224
"""If called outside of a request context, this will raise a LookupError."""
220225
return request_ctx.get()
221226

src/mcp/server/sse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ async def handle_post_message(
204204
return
205205

206206
# Pass the ASGI scope for framework-agnostic access to request data
207-
metadata = ServerMessageMetadata(request_context=dict(request.scope))
207+
metadata = ServerMessageMetadata(request_context=request)
208208
session_message = SessionMessage(message, metadata=metadata)
209209
logger.debug(f"Sending session message to writer: {session_message}")
210210
response = Response("Accepted", status_code=202)

src/mcp/server/streamable_http_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class StreamableHTTPSessionManager:
5656

5757
def __init__(
5858
self,
59-
app: MCPServer[Any],
59+
app: MCPServer[Any, Any],
6060
event_store: EventStore | None = None,
6161
json_response: bool = False,
6262
stateless: bool = False,

src/mcp/shared/context.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
from typing_extensions import TypeVar
55

66
from mcp.shared.session import BaseSession
7-
from mcp.types import RequestData, RequestId, RequestParams
7+
from mcp.types import RequestId, RequestParams
88

99
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
1010
LifespanContextT = TypeVar("LifespanContextT")
11+
RequestT = TypeVar("RequestT", default=Any)
1112

1213

1314
@dataclass
14-
class RequestContext(Generic[SessionT, LifespanContextT]):
15+
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1516
request_id: RequestId
1617
meta: RequestParams.Meta | None
1718
session: SessionT
1819
lifespan_context: LifespanContextT
19-
request: RequestData | None = None
20+
request: RequestT | None = None

src/mcp/shared/message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Awaitable, Callable
99
from dataclasses import dataclass
1010

11-
from mcp.types import JSONRPCMessage, RequestData, RequestId
11+
from mcp.types import JSONRPCMessage, RequestId
1212

1313
ResumptionToken = str
1414

@@ -31,7 +31,7 @@ class ServerMessageMetadata:
3131

3232
related_request_id: RequestId | None = None
3333
# Request-specific context (e.g., headers, auth info)
34-
request_context: RequestData | None = None
34+
request_context: object | None = None
3535

3636

3737
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

src/mcp/types.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@
3636
Role = Literal["user", "assistant"]
3737
RequestId = str | int
3838
AnyFunction: TypeAlias = Callable[..., Any]
39-
# Dictionary containing request metadata (headers, path, method, etc.
40-
# based on ASGI scope for most of the trasport implementations)
41-
RequestData: TypeAlias = dict[str, Any]
4239

4340

4441
class RequestParams(BaseModel):

tests/server/fastmcp/test_integration.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
import socket
1111
import time
1212
from collections.abc import Generator
13+
from typing import Any
1314

1415
import pytest
1516
import uvicorn
1617
from pydantic import AnyUrl
1718
from starlette.applications import Starlette
19+
from starlette.requests import Request
1820

1921
import mcp.types as types
2022
from mcp.client.session import ClientSession
@@ -437,20 +439,17 @@ def make_fastmcp_with_context_app():
437439

438440
# Tool that echoes request headers
439441
@mcp.tool(description="Echo request headers from context")
440-
def echo_headers(ctx: Context) -> str:
442+
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
441443
"""Returns the request headers as JSON."""
442444
headers_info = {}
443445
if ctx.request_context.request:
444-
# Extract headers from ASGI scope
445-
headers_list = ctx.request_context.request.get("headers", [])
446-
headers_info = {
447-
k.decode("latin-1"): v.decode("latin-1") for k, v in headers_list
448-
}
446+
# Now the type system knows request is a Starlette Request object
447+
headers_info = dict(ctx.request_context.request.headers)
449448
return json.dumps(headers_info)
450449

451450
# Tool that returns full request context
452451
@mcp.tool(description="Echo request context with custom data")
453-
def echo_context(custom_request_id: str, ctx: Context) -> str:
452+
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
454453
"""Returns request context including headers and custom data."""
455454
context_data = {
456455
"custom_request_id": custom_request_id,
@@ -459,12 +458,11 @@ def echo_context(custom_request_id: str, ctx: Context) -> str:
459458
"path": None,
460459
}
461460
if ctx.request_context.request:
462-
# Extract data from ASGI scope
463-
headers_list = ctx.request_context.request.get("headers", [])
464-
context_data["headers"] = {
465-
k.decode("latin-1"): v.decode("latin-1") for k, v in headers_list
466-
}
467-
context_data["method"] = ctx.request_context.request.get("method")
461+
# Now we can access Request attributes directly
462+
request = ctx.request_context.request
463+
context_data["headers"] = dict(request.headers)
464+
context_data["method"] = request.method
465+
context_data["path"] = request.url.path
468466
return json.dumps(context_data)
469467

470468
# Create the SSE app

tests/server/fastmcp/test_tool_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mcp.server.fastmcp.tools import Tool, ToolManager
1010
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
1111
from mcp.server.session import ServerSessionT
12-
from mcp.shared.context import LifespanContextT
12+
from mcp.shared.context import LifespanContextT, RequestT
1313
from mcp.types import ToolAnnotations
1414

1515

@@ -347,7 +347,7 @@ def tool_without_context(x: int) -> str:
347347
assert tool.context_kwarg is None
348348

349349
def tool_with_parametrized_context(
350-
x: int, ctx: Context[ServerSessionT, LifespanContextT]
350+
x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
351351
) -> str:
352352
return str(x)
353353

tests/shared/test_sse.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ async def test_sse_client_basic_connection_mounted_app(
322322

323323

324324
# Test server with request context that returns headers in the response
325-
class RequestContextServer(Server):
325+
class RequestContextServer(Server[object, Request]):
326326
def __init__(self):
327327
super().__init__("request_context_server")
328328

@@ -333,12 +333,8 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
333333
try:
334334
context = self.request_context
335335
if context.request:
336-
# Extract headers from ASGI scope
337-
headers_list = context.request.get("headers", [])
338-
headers_info = {
339-
k.decode("latin-1"): v.decode("latin-1")
340-
for k, v in headers_list
341-
}
336+
# The request is a Starlette Request object
337+
headers_info = dict(context.request.headers)
342338
except LookupError:
343339
pass # No request context available
344340

0 commit comments

Comments
 (0)