Skip to content

Add support for DNS rebinding protections #861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from mcp.server.stdio import stdio_server
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import (
AnyFunction,
Expand Down Expand Up @@ -119,6 +120,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
) = Field(None, description="Lifespan context manager")

auth: AuthSettings | None = None

# Transport security settings (DNS rebinding protection)
transport_security: TransportSecuritySettings | None = None


def lifespan_wrapper(
Expand Down Expand Up @@ -672,6 +676,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette:

sse = SseServerTransport(
normalized_message_endpoint,
security_settings=self.settings.transport_security,
)

async def handle_sse(scope: Scope, receive: Receive, send: Send):
Expand Down Expand Up @@ -779,6 +784,7 @@ def streamable_http_app(self) -> Starlette:
event_store=self._event_store,
json_response=self.settings.json_response,
stateless=self.settings.stateless_http, # Use the stateless setting
security_settings=self.settings.transport_security,
)

# Create the ASGI handler
Expand Down
26 changes: 25 additions & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.server.transport_security import (
TransportSecurityMiddleware,
TransportSecuritySettings,
)
from mcp.shared.message import ServerMessageMetadata, SessionMessage

logger = logging.getLogger(__name__)
Expand All @@ -71,16 +75,24 @@ class SseServerTransport:

_endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
_security: TransportSecurityMiddleware

def __init__(self, endpoint: str) -> None:
def __init__(
self, endpoint: str, security_settings: TransportSecuritySettings | None = None
) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.

Args:
endpoint: The relative or absolute URL for POST messages.
security_settings: Optional security settings for DNS rebinding protection.
"""

super().__init__()
self._endpoint = endpoint
self._read_stream_writers = {}
self._security = TransportSecurityMiddleware(security_settings)
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

@asynccontextmanager
Expand All @@ -89,6 +101,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
logger.error("connect_sse received non-HTTP request")
raise ValueError("connect_sse can only handle HTTP requests")

# Validate request headers for DNS rebinding protection
request = Request(scope, receive)
error_response = await self._security.validate_request(request, is_post=False)
if error_response:
await error_response(scope, receive, send)
raise ValueError("Request validation failed")

logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
Expand Down Expand Up @@ -169,6 +188,11 @@ async def handle_post_message(
) -> None:
logger.debug("Handling POST message")
request = Request(scope, receive)

# Validate request headers for DNS rebinding protection
error_response = await self._security.validate_request(request, is_post=True)
if error_response:
return await error_response(scope, receive, send)

session_id_param = request.query_params.get("session_id")
if session_id_param is None:
Expand Down
16 changes: 16 additions & 0 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

from mcp.server.transport_security import (
TransportSecurityMiddleware,
TransportSecuritySettings,
)
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import (
INTERNAL_ERROR,
Expand Down Expand Up @@ -131,12 +135,14 @@ class StreamableHTTPServerTransport:
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
_security: TransportSecurityMiddleware

def __init__(
self,
mcp_session_id: str | None,
is_json_response_enabled: bool = False,
event_store: EventStore | None = None,
security_settings: TransportSecuritySettings | None = None,
) -> None:
"""
Initialize a new StreamableHTTP server transport.
Expand All @@ -149,6 +155,7 @@ def __init__(
event_store: Event store for resumability support. If provided,
resumability will be enabled, allowing clients to
reconnect and resume messages.
security_settings: Optional security settings for DNS rebinding protection.

Raises:
ValueError: If the session ID contains invalid characters.
Expand All @@ -163,6 +170,7 @@ def __init__(
self.mcp_session_id = mcp_session_id
self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store
self._security = TransportSecurityMiddleware(security_settings)
self._request_streams: dict[
RequestId,
tuple[
Expand Down Expand Up @@ -260,6 +268,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests"""
request = Request(scope, receive)

# Validate request headers for DNS rebinding protection
is_post = request.method == "POST"
error_response = await self._security.validate_request(request, is_post=is_post)
if error_response:
await error_response(scope, receive, send)
return

if self._terminated:
# If the session has been terminated, return 404 Not Found
response = self._create_error_response(
Expand Down
5 changes: 5 additions & 0 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
EventStore,
StreamableHTTPServerTransport,
)
from mcp.server.transport_security import TransportSecuritySettings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,11 +61,13 @@ def __init__(
event_store: EventStore | None = None,
json_response: bool = False,
stateless: bool = False,
security_settings: TransportSecuritySettings | None = None,
):
self.app = app
self.event_store = event_store
self.json_response = json_response
self.stateless = stateless
self.security_settings = security_settings

# Session tracking (only used if not stateless)
self._session_creation_lock = anyio.Lock()
Expand Down Expand Up @@ -162,6 +165,7 @@ async def _handle_stateless_request(
mcp_session_id=None, # No session tracking in stateless mode
is_json_response_enabled=self.json_response,
event_store=None, # No event store in stateless mode
security_settings=self.security_settings,
)

# Start server in a new task
Expand Down Expand Up @@ -222,6 +226,7 @@ async def _handle_stateful_request(
mcp_session_id=new_session_id,
is_json_response_enabled=self.json_response,
event_store=self.event_store, # May be None (no resumability)
security_settings=self.security_settings,
)

assert http_transport.mcp_session_id is not None
Expand Down
133 changes: 133 additions & 0 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""DNS rebinding protection for MCP server transports."""

import logging

from pydantic import BaseModel, Field
from starlette.requests import Request
from starlette.responses import Response

logger = logging.getLogger(__name__)


class TransportSecuritySettings(BaseModel):
"""Settings for MCP transport security features.

These settings help protect against DNS rebinding attacks by validating
incoming request headers.
"""

enable_dns_rebinding_protection: bool = Field(
default=True,
description="Enable DNS rebinding protection (recommended for production)"
)

allowed_hosts: list[str] = Field(
default=[],
description="List of allowed Host header values. Only applies when " +
"enable_dns_rebinding_protection is True."
)

allowed_origins: list[str] = Field(
default=[],
description="List of allowed Origin header values. Only applies when " +
"enable_dns_rebinding_protection is True."
)


class TransportSecurityMiddleware:
"""Middleware to enforce DNS rebinding protection for MCP transport endpoints."""

def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default
# for backwards compatibility
self.settings = settings or TransportSecuritySettings(
enable_dns_rebinding_protection=False
)

def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not self.settings.enable_dns_rebinding_protection:
return True

if not host:
logger.warning("Missing Host header in request")
return False

# Check exact match first
if host in self.settings.allowed_hosts:
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_hosts:
if allowed.endswith(":*"):
# Extract base host from pattern
base_host = allowed[:-2]
# Check if the actual host starts with base host and has a port
if host.startswith(base_host + ":"):
return True

logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
if not self.settings.enable_dns_rebinding_protection:
return True

# Origin can be absent for same-origin requests
if not origin:
return True

# Check exact match first
if origin in self.settings.allowed_origins:
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_origins:
if allowed.endswith(":*"):
# Extract base origin from pattern
base_origin = allowed[:-2]
# Check if the actual origin starts with base origin and has a port
if origin.startswith(base_origin + ":"):
return True

logger.warning(f"Invalid Origin header: {origin}")
return False

def _validate_content_type(self, content_type: str | None) -> bool:
"""Validate the Content-Type header for POST requests."""
if not content_type:
logger.warning("Missing Content-Type header in POST request")
return False

# Content-Type must start with application/json
if not content_type.lower().startswith("application/json"):
logger.warning(f"Invalid Content-Type header: {content_type}")
return False

return True

async def validate_request(
self, request: Request, is_post: bool = False
) -> Response | None:
"""Validate request headers for DNS rebinding protection.

Returns None if validation passes, or an error Response if validation fails.
"""
# Validate Host header
host = request.headers.get("host")
if not self._validate_host(host):
return Response("Invalid Host header", status_code=400)

# Validate Origin header
origin = request.headers.get("origin")
if not self._validate_origin(origin):
return Response("Invalid Origin header", status_code=400)

# Validate Content-Type for POST requests
if is_post:
content_type = request.headers.get("content-type")
if not self._validate_content_type(content_type):
return Response("Invalid Content-Type header", status_code=400)

return None
Loading
Loading