Skip to content

Commit 397acac

Browse files
add default CORS middleware (#441)
* add default CORS middleware * test that default cors middleware is working
1 parent 0d36b76 commit 397acac

File tree

3 files changed

+60
-3
lines changed

3 files changed

+60
-3
lines changed

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from starlette.responses import JSONResponse, Response
1515

1616
from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
17-
from stac_fastapi.api.middleware import ProxyHeaderMiddleware
17+
from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware
1818
from stac_fastapi.api.models import (
1919
APIRequest,
2020
CollectionUri,
@@ -93,7 +93,9 @@ class StacApi:
9393
pagination_extension = attr.ib(default=TokenPaginationExtension)
9494
response_class: Type[Response] = attr.ib(default=JSONResponse)
9595
middlewares: List = attr.ib(
96-
default=attr.Factory(lambda: [BrotliMiddleware, ProxyHeaderMiddleware])
96+
default=attr.Factory(
97+
lambda: [BrotliMiddleware, CORSMiddleware, ProxyHeaderMiddleware]
98+
)
9799
)
98100
route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[])
99101

stac_fastapi/api/stac_fastapi/api/middleware.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,48 @@
11
"""api middleware."""
2-
32
import re
3+
import typing
44
from http.client import HTTP_PORT, HTTPS_PORT
55
from typing import List, Tuple
66

7+
from starlette.middleware.cors import CORSMiddleware as _CORSMiddleware
78
from starlette.types import ASGIApp, Receive, Scope, Send
89

910

11+
class CORSMiddleware(_CORSMiddleware):
12+
"""
13+
Subclass of Starlette's standard CORS middleware with default values set to those reccomended by the STAC API spec.
14+
15+
https://github.com/radiantearth/stac-api-spec/blob/914cf8108302e2ec734340080a45aaae4859bb63/implementation.md#cors
16+
"""
17+
18+
def __init__(
19+
self,
20+
app: ASGIApp,
21+
allow_origins: typing.Sequence[str] = ("*",),
22+
allow_methods: typing.Sequence[str] = (
23+
"OPTIONS",
24+
"POST",
25+
"GET",
26+
),
27+
allow_headers: typing.Sequence[str] = ("Content-Type",),
28+
allow_credentials: bool = False,
29+
allow_origin_regex: typing.Optional[str] = None,
30+
expose_headers: typing.Sequence[str] = (),
31+
max_age: int = 600,
32+
) -> None:
33+
"""Create CORS middleware."""
34+
super().__init__(
35+
app,
36+
allow_origins,
37+
allow_methods,
38+
allow_headers,
39+
allow_credentials,
40+
allow_origin_regex,
41+
expose_headers,
42+
max_age,
43+
)
44+
45+
1046
class ProxyHeaderMiddleware:
1147
"""
1248
Account for forwarding headers when deriving base URL.

stac_fastapi/api/tests/test_middleware.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
from unittest import mock
2+
13
import pytest
24
from starlette.applications import Starlette
5+
from starlette.testclient import TestClient
36

7+
from stac_fastapi.api.app import StacApi
48
from stac_fastapi.api.middleware import ProxyHeaderMiddleware
9+
from stac_fastapi.types.config import ApiSettings
10+
from stac_fastapi.types.core import BaseCoreClient
511

612

713
@pytest.fixture
@@ -10,6 +16,13 @@ def proxy_header_middleware() -> ProxyHeaderMiddleware:
1016
return ProxyHeaderMiddleware(app)
1117

1218

19+
@pytest.fixture
20+
def test_client() -> TestClient:
21+
app = StacApi(settings=ApiSettings(), client=mock.create_autospec(BaseCoreClient))
22+
with TestClient(app.app) as client:
23+
yield client
24+
25+
1326
@pytest.mark.parametrize(
1427
"headers,key,expected",
1528
[
@@ -138,3 +151,9 @@ def test_get_forwarded_url_parts(
138151
):
139152
actual = proxy_header_middleware._get_forwarded_url_parts(scope)
140153
assert actual == expected
154+
155+
156+
def test_cors_middleware(test_client):
157+
resp = test_client.get("/_mgmt/ping", headers={"Origin": "http://netloc"})
158+
assert resp.status_code == 200
159+
assert resp.headers["access-control-allow-origin"] == "*"

0 commit comments

Comments
 (0)