Skip to content

Commit 7706515

Browse files
committed
feat: add route-specific custom response validation and tests
1 parent 4ca1461 commit 7706515

File tree

4 files changed

+113
-6
lines changed

4 files changed

+113
-6
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def __init__(
320320
openapi_extensions: dict[str, Any] | None = None,
321321
deprecated: bool = False,
322322
middlewares: list[Callable[..., Response]] | None = None,
323+
custom_response_validation_http_code: int | HTTPStatus | None = None,
323324
):
324325
"""
325326
Internally used Route Configuration
@@ -362,6 +363,7 @@ def __init__(
362363
Whether or not to mark this route as deprecated in the OpenAPI schema
363364
middlewares: list[Callable[..., Response]] | None
364365
The list of route middlewares to be called in order.
366+
# TODO
365367
"""
366368
self.method = method.upper()
367369
self.path = "/" if path.strip() == "" else path
@@ -397,6 +399,8 @@ def __init__(
397399
# _body_field is used to cache the dependant model for the body field
398400
self._body_field: ModelField | None = None
399401

402+
self.custom_response_validation_http_code: int | HTTPStatus | None = custom_response_validation_http_code
403+
400404
def __call__(
401405
self,
402406
router_middlewares: list[Callable],
@@ -565,6 +569,8 @@ def _get_openapi_path(
565569
},
566570
}
567571

572+
# TODO update responses
573+
568574
# Add the response to the OpenAPI operation
569575
if self.responses:
570576
for status_code in list(self.responses):
@@ -943,6 +949,7 @@ def route(
943949
openapi_extensions: dict[str, Any] | None = None,
944950
deprecated: bool = False,
945951
middlewares: list[Callable[..., Any]] | None = None,
952+
custom_response_validation_http_code: int | HTTPStatus | None = None,
946953
) -> Callable[[AnyCallableT], AnyCallableT]:
947954
raise NotImplementedError()
948955

@@ -1004,6 +1011,7 @@ def get(
10041011
openapi_extensions: dict[str, Any] | None = None,
10051012
deprecated: bool = False,
10061013
middlewares: list[Callable[..., Any]] | None = None,
1014+
custom_response_validation_http_code: int | HTTPStatus | None = None,
10071015
) -> Callable[[AnyCallableT], AnyCallableT]:
10081016
"""Get route decorator with GET `method`
10091017
@@ -1044,6 +1052,7 @@ def lambda_handler(event, context):
10441052
openapi_extensions,
10451053
deprecated,
10461054
middlewares,
1055+
custom_response_validation_http_code,
10471056
)
10481057

10491058
def post(
@@ -1063,6 +1072,7 @@ def post(
10631072
openapi_extensions: dict[str, Any] | None = None,
10641073
deprecated: bool = False,
10651074
middlewares: list[Callable[..., Any]] | None = None,
1075+
custom_response_validation_http_code: int | HTTPStatus | None = None,
10661076
) -> Callable[[AnyCallableT], AnyCallableT]:
10671077
"""Post route decorator with POST `method`
10681078
@@ -1104,6 +1114,7 @@ def lambda_handler(event, context):
11041114
openapi_extensions,
11051115
deprecated,
11061116
middlewares,
1117+
custom_response_validation_http_code,
11071118
)
11081119

11091120
def put(
@@ -1123,6 +1134,7 @@ def put(
11231134
openapi_extensions: dict[str, Any] | None = None,
11241135
deprecated: bool = False,
11251136
middlewares: list[Callable[..., Any]] | None = None,
1137+
custom_response_validation_http_code: int | HTTPStatus | None = None,
11261138
) -> Callable[[AnyCallableT], AnyCallableT]:
11271139
"""Put route decorator with PUT `method`
11281140
@@ -1164,6 +1176,7 @@ def lambda_handler(event, context):
11641176
openapi_extensions,
11651177
deprecated,
11661178
middlewares,
1179+
custom_response_validation_http_code,
11671180
)
11681181

11691182
def delete(
@@ -1183,6 +1196,7 @@ def delete(
11831196
openapi_extensions: dict[str, Any] | None = None,
11841197
deprecated: bool = False,
11851198
middlewares: list[Callable[..., Any]] | None = None,
1199+
custom_response_validation_http_code: int | HTTPStatus | None = None,
11861200
) -> Callable[[AnyCallableT], AnyCallableT]:
11871201
"""Delete route decorator with DELETE `method`
11881202
@@ -1223,6 +1237,7 @@ def lambda_handler(event, context):
12231237
openapi_extensions,
12241238
deprecated,
12251239
middlewares,
1240+
custom_response_validation_http_code,
12261241
)
12271242

12281243
def patch(
@@ -1242,6 +1257,7 @@ def patch(
12421257
openapi_extensions: dict[str, Any] | None = None,
12431258
deprecated: bool = False,
12441259
middlewares: list[Callable] | None = None,
1260+
custom_response_validation_http_code: int | HTTPStatus | None = None,
12451261
) -> Callable[[AnyCallableT], AnyCallableT]:
12461262
"""Patch route decorator with PATCH `method`
12471263
@@ -1285,6 +1301,7 @@ def lambda_handler(event, context):
12851301
openapi_extensions,
12861302
deprecated,
12871303
middlewares,
1304+
custom_response_validation_http_code,
12881305
)
12891306

12901307
def head(
@@ -1304,6 +1321,7 @@ def head(
13041321
openapi_extensions: dict[str, Any] | None = None,
13051322
deprecated: bool = False,
13061323
middlewares: list[Callable] | None = None,
1324+
custom_response_validation_http_code: int | HTTPStatus | None = None,
13071325
) -> Callable[[AnyCallableT], AnyCallableT]:
13081326
"""Head route decorator with HEAD `method`
13091327
@@ -1346,6 +1364,7 @@ def lambda_handler(event, context):
13461364
openapi_extensions,
13471365
deprecated,
13481366
middlewares,
1367+
custom_response_validation_http_code,
13491368
)
13501369

13511370
def _push_processed_stack_frame(self, frame: str):
@@ -2126,9 +2145,14 @@ def route(
21262145
openapi_extensions: dict[str, Any] | None = None,
21272146
deprecated: bool = False,
21282147
middlewares: list[Callable[..., Any]] | None = None,
2148+
custom_response_validation_http_code: int | HTTPStatus | None = None,
21292149
) -> Callable[[AnyCallableT], AnyCallableT]:
21302150
"""Route decorator includes parameter `method`"""
21312151

2152+
custom_response_validation_http_code = self._validate_route_response_validation_error_http_code(
2153+
custom_response_validation_http_code,
2154+
)
2155+
21322156
def register_resolver(func: AnyCallableT) -> AnyCallableT:
21332157
methods = (method,) if isinstance(method, str) else method
21342158
logger.debug(f"Adding route using rule {rule} and methods: {','.join(m.upper() for m in methods)}")
@@ -2155,6 +2179,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT:
21552179
openapi_extensions,
21562180
deprecated,
21572181
middlewares,
2182+
custom_response_validation_http_code,
21582183
)
21592184

21602185
# The more specific route wins.
@@ -2523,15 +2548,20 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
25232548
)
25242549

25252550
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2526-
# 'self._response_validation_error_http_code' is not None
2551+
# 'self._response_validation_error_http_code' is not None or
2552+
# when route has custom_response_validation_http_code
25272553
if isinstance(exp, ResponseValidationError):
2528-
http_code = self._response_validation_error_http_code
2554+
http_code = (
2555+
self._response_validation_error_http_code
2556+
if exp.source == "app"
2557+
else route.custom_response_validation_http_code
2558+
)
25292559
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
25302560
return self._response_builder_class(
25312561
response=Response(
25322562
status_code=http_code.value,
25332563
content_type=content_types.APPLICATION_JSON,
2534-
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
2564+
body={"statusCode": http_code, "detail": errors},
25352565
),
25362566
serializer=self._serializer,
25372567
route=route,
@@ -2683,6 +2713,7 @@ def route(
26832713
openapi_extensions: dict[str, Any] | None = None,
26842714
deprecated: bool = False,
26852715
middlewares: list[Callable[..., Any]] | None = None,
2716+
custom_response_validation_http_code: int | HTTPStatus | None = None,
26862717
) -> Callable[[AnyCallableT], AnyCallableT]:
26872718
def register_route(func: AnyCallableT) -> AnyCallableT:
26882719
# All dict keys needs to be hashable. So we'll need to do some conversions:
@@ -2708,6 +2739,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT:
27082739
frozen_security,
27092740
frozen_openapi_extensions,
27102741
deprecated,
2742+
custom_response_validation_http_code,
27112743
)
27122744

27132745
# Collate Middleware for routes
@@ -2795,6 +2827,7 @@ def route(
27952827
openapi_extensions: dict[str, Any] | None = None,
27962828
deprecated: bool = False,
27972829
middlewares: list[Callable[..., Any]] | None = None,
2830+
custom_response_validation_http_code: int | HTTPStatus | None = None,
27982831
) -> Callable[[AnyCallableT], AnyCallableT]:
27992832
# NOTE: see #1552 for more context.
28002833
return super().route(
@@ -2814,6 +2847,7 @@ def route(
28142847
openapi_extensions,
28152848
deprecated,
28162849
middlewares,
2850+
custom_response_validation_http_code,
28172851
)
28182852

28192853
# Override _compile_regex to exclude trailing slashes for route resolution

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _handle_response(self, *, route: Route, response: Response):
150150
response.body = self._serialize_response(
151151
field=route.dependant.return_param,
152152
response_content=response.body,
153+
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
153154
)
154155

155156
return response
@@ -165,6 +166,7 @@ def _serialize_response(
165166
exclude_unset: bool = False,
166167
exclude_defaults: bool = False,
167168
exclude_none: bool = False,
169+
has_route_custom_response_validation: bool = False,
168170
) -> Any:
169171
"""
170172
Serialize the response content according to the field type.
@@ -174,7 +176,13 @@ def _serialize_response(
174176
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
175177
if errors:
176178
if self._has_response_validation_error:
177-
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content)
179+
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")
180+
if has_route_custom_response_validation:
181+
raise ResponseValidationError(
182+
errors=_normalize_errors(errors),
183+
body=response_content,
184+
source="route",
185+
)
178186
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
179187

180188
if hasattr(field, "serialize"):

aws_lambda_powertools/event_handler/openapi/exceptions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Sequence
1+
from typing import Any, Literal, Sequence
22

33

44
class ValidationException(Exception):
@@ -28,9 +28,10 @@ class ResponseValidationError(ValidationException):
2828
Raised when the response body does not match the OpenAPI schema
2929
"""
3030

31-
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
31+
def __init__(self, errors: Sequence[Any], *, body: Any = None, source: Literal["route", "app"] = "app") -> None:
3232
super().__init__(errors)
3333
self.body = body
34+
self.source = source
3435

3536

3637
class SerializationError(Exception):

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,3 +1378,67 @@ def test_custom_response_validation_error_bad_http_code(response_validation_erro
13781378
str(exception_info.value)
13791379
== f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code."
13801380
)
1381+
1382+
1383+
def test_custom_route_response_validation_error_http_code_invalid_response_incomplete_model(gw_event):
1384+
# GIVEN an APIGatewayRestResolver with custom response validation enabled
1385+
app = APIGatewayRestResolver(enable_validation=True)
1386+
1387+
class Model(BaseModel):
1388+
name: str
1389+
age: int
1390+
1391+
@app.get("/incomplete_model_not_allowed")
1392+
def handler_incomplete_model_not_allowed() -> Model:
1393+
return {"age": 18} # type: ignore
1394+
1395+
@app.get(
1396+
"/custom_incomplete_model_not_allowed",
1397+
custom_response_validation_http_code=500,
1398+
)
1399+
def handler_custom_route_response_validation_error() -> Model:
1400+
return {"age": 18} # type: ignore
1401+
1402+
# WHEN returning incomplete model for a non-Optional type
1403+
gw_event["path"] = "/incomplete_model_not_allowed"
1404+
result = app(gw_event, {})
1405+
1406+
gw_event["path"] = "/custom_incomplete_model_not_allowed"
1407+
custom_result = app(gw_event, {})
1408+
1409+
# THEN it should return a validation error with the custom status code provided
1410+
assert result["statusCode"] == 422
1411+
assert custom_result["statusCode"] == 500
1412+
assert json.loads(result["body"])["detail"] == json.loads(custom_result["body"])["detail"]
1413+
1414+
1415+
def test_custom_route_response_validation_error_sanitized_response(gw_event):
1416+
# GIVEN an APIGatewayRestResolver with custom response validation enabled
1417+
# with a sanitized response validation error response
1418+
app = APIGatewayRestResolver(enable_validation=True)
1419+
1420+
class Model(BaseModel):
1421+
name: str
1422+
age: int
1423+
1424+
@app.get(
1425+
"/custom_incomplete_model_not_allowed",
1426+
custom_response_validation_http_code=422,
1427+
)
1428+
def handler_custom_route_response_validation_error() -> Model:
1429+
return {"age": 18} # type: ignore
1430+
1431+
@app.exception_handler(ResponseValidationError)
1432+
def handle_response_validation_error(ex: ResponseValidationError):
1433+
return Response(
1434+
status_code=500,
1435+
body="Unexpected response.",
1436+
)
1437+
1438+
# WHEN returning incomplete model for a non-Optional type
1439+
gw_event["path"] = "/custom_incomplete_model_not_allowed"
1440+
result = app(gw_event, {})
1441+
1442+
# THEN it should return the sanitized response
1443+
assert result["statusCode"] == 500
1444+
assert result["body"] == "Unexpected response."

0 commit comments

Comments
 (0)