Skip to content

Response headers support #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions openapi_core/deserializing/parameters/deserializers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.deserializing.parameters.exceptions import (
EmptyParameterValue,
Expand All @@ -7,19 +9,28 @@

class PrimitiveDeserializer(object):

def __init__(self, param, deserializer_callable):
self.param = param
def __init__(self, param_or_header, deserializer_callable):
self.param_or_header = param_or_header
self.deserializer_callable = deserializer_callable

self.aslist = get_aslist(self.param)
self.explode = get_explode(self.param)
self.style = get_style(self.param)
self.aslist = get_aslist(self.param_or_header)
self.explode = get_explode(self.param_or_header)
self.style = get_style(self.param_or_header)

def __call__(self, value):
if (self.param['in'] == 'query' and value == "" and
not self.param.getkey('allowEmptyValue', False)):
raise EmptyParameterValue(
value, self.style, self.param['name'])
# if "in" not defined then it's a Header
if 'allowEmptyValue' in self.param_or_header:
warnings.warn(
"Use of allowEmptyValue property is deprecated",
DeprecationWarning,
)
allow_empty_values = self.param_or_header.getkey(
'allowEmptyValue', False)
location_name = self.param_or_header.getkey('in', 'header')
if (location_name == 'query' and value == "" and
not allow_empty_values):
name = self.param_or_header.getkey('name', 'header')
raise EmptyParameterValue(value, self.style, name)

if not self.aslist or self.explode:
return value
Expand Down
26 changes: 26 additions & 0 deletions openapi_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,32 @@ class OpenAPIError(Exception):
pass


class OpenAPIHeaderError(OpenAPIError):
pass


class MissingHeaderError(OpenAPIHeaderError):
"""Missing header error"""
pass


@attr.s(hash=True)
class MissingHeader(MissingHeaderError):
name = attr.ib()

def __str__(self):
return "Missing header (without default value): {0}".format(
self.name)


@attr.s(hash=True)
class MissingRequiredHeader(MissingHeaderError):
name = attr.ib()

def __str__(self):
return "Missing required header: {0}".format(self.name)


class OpenAPIParameterError(OpenAPIError):
pass

Expand Down
50 changes: 35 additions & 15 deletions openapi_core/schema/parameters.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,54 @@
from __future__ import division


def get_aslist(param):
"""Checks if parameter is described as list for simpler scenarios"""
def get_aslist(param_or_header):
"""Checks if parameter/header is described as list for simpler scenarios"""
# if schema is not defined it's a complex scenario
if 'schema' not in param:
if 'schema' not in param_or_header:
return False

param_schema = param / 'schema'
schema_type = param_schema.getkey('type', 'any')
schema = param_or_header / 'schema'
schema_type = schema.getkey('type', 'any')
# TODO: resolve for 'any' schema type
return schema_type in ['array', 'object']


def get_style(param):
"""Checks parameter style for simpler scenarios"""
if 'style' in param:
return param['style']
def get_style(param_or_header):
"""Checks parameter/header style for simpler scenarios"""
if 'style' in param_or_header:
return param_or_header['style']

# if "in" not defined then it's a Header
location = param_or_header.getkey('in', 'header')

# determine default
return (
'simple' if param['in'] in ['path', 'header'] else 'form'
'simple' if location in ['path', 'header'] else 'form'
)


def get_explode(param):
"""Checks parameter explode for simpler scenarios"""
if 'explode' in param:
return param['explode']
def get_explode(param_or_header):
"""Checks parameter/header explode for simpler scenarios"""
if 'explode' in param_or_header:
return param_or_header['explode']

# determine default
style = get_style(param)
style = get_style(param_or_header)
return style == 'form'


def get_value(param_or_header, location, name=None):
"""Returns parameter/header value from specific location"""
name = name or param_or_header['name']

if name not in location:
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if hasattr(location, 'getall'):
return location.getall(name)
return location.getlist(name)

return location[name]
5 changes: 4 additions & 1 deletion openapi_core/testing/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
class MockResponseFactory(object):

@classmethod
def create(cls, data, status_code=200, mimetype='application/json'):
def create(
cls, data, status_code=200, headers=None,
mimetype='application/json'):
return OpenAPIResponse(
data=data,
status_code=status_code,
headers=headers or {},
mimetype=mimetype,
)
47 changes: 12 additions & 35 deletions openapi_core/validation/request/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.deserializing.parameters.factories import (
ParameterDeserializersFactory,
)
from openapi_core.exceptions import (
MissingRequiredParameter, MissingParameter,
MissingRequiredRequestBody, MissingRequestBody,
Expand Down Expand Up @@ -46,10 +43,6 @@ def schema_unmarshallers_factory(self):
def security_provider_factory(self):
return SecurityProviderFactory()

@property
def parameter_deserializers_factory(self):
return ParameterDeserializersFactory()

def validate(self, request):
try:
path, operation, _, path_result, _ = self._find_path(request)
Expand Down Expand Up @@ -177,35 +170,23 @@ def _get_parameters(self, request, params):
return RequestParameters(**locations), errors

def _get_parameter(self, param, request):
if param.getkey('deprecated', False):
name = param['name']
deprecated = param.getkey('deprecated', False)
if deprecated:
warnings.warn(
"{0} parameter is deprecated".format(param['name']),
"{0} parameter is deprecated".format(name),
DeprecationWarning,
)

param_location = param['in']
location = request.parameters[param_location]
try:
raw_value = self._get_parameter_value(param, request)
except MissingParameter:
if 'schema' not in param:
raise
schema = param / 'schema'
if 'default' not in schema:
raise
casted = schema['default']
else:
# Simple scenario
if 'content' not in param:
deserialised = self._deserialise_parameter(param, raw_value)
schema = param / 'schema'
# Complex scenario
else:
content = param / 'content'
mimetype, media_type = next(content.items())
deserialised = self._deserialise_data(mimetype, raw_value)
schema = media_type / 'schema'
casted = self._cast(schema, deserialised)
unmarshalled = self._unmarshal(schema, casted)
return unmarshalled
return self._get_param_or_header_value(param, location)
except KeyError:
required = param.getkey('required', False)
if required:
raise MissingRequiredParameter(name)
raise MissingParameter(name)

def _get_body(self, request, operation):
if 'requestBody' not in operation:
Expand Down Expand Up @@ -280,7 +261,3 @@ def _get_body_value(self, request_body, request):
raise MissingRequiredRequestBody(request)
raise MissingRequestBody(request)
return request.body

def _deserialise_parameter(self, param, value):
deserializer = self.parameter_deserializers_factory.create(param)
return deserializer(value)
6 changes: 4 additions & 2 deletions openapi_core/validation/response/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""OpenAPI core validation response datatypes module"""
import attr
from werkzeug.datastructures import Headers

from openapi_core.validation.datatypes import BaseValidationResult

Expand All @@ -13,14 +14,15 @@ class OpenAPIResponse(object):
The response body, as string.
status_code
The status code as integer.
headers
Response headers as Headers.
mimetype
Lowercase content type without charset.
"""

data = attr.ib()
status_code = attr.ib()

mimetype = attr.ib()
headers = attr.ib(factory=Headers, converter=Headers)


@attr.s
Expand Down
49 changes: 44 additions & 5 deletions openapi_core/validation/response/validators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""OpenAPI core validation response validators module"""
from __future__ import division
import warnings

from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.exceptions import MissingResponseContent
from openapi_core.exceptions import (
MissingHeader, MissingRequiredHeader, MissingResponseContent,
)
from openapi_core.templating.media_types.exceptions import MediaTypeFinderError
from openapi_core.templating.paths.exceptions import PathError
from openapi_core.templating.responses.exceptions import ResponseFinderError
Expand Down Expand Up @@ -117,12 +120,48 @@ def _get_data(self, response, operation_response):
return data, []

def _get_headers(self, response, operation_response):
errors = []
if 'headers' not in operation_response:
return {}, []

# @todo: implement
headers = {}
headers = operation_response / 'headers'

return headers, errors
errors = []
validated = {}
for name, header in headers.items():
# ignore Content-Type header
if name.lower() == "content-type":
continue
try:
value = self._get_header(name, header, response)
except MissingHeader:
continue
except (
MissingRequiredHeader, DeserializeError,
CastError, ValidateError, UnmarshalError,
) as exc:
errors.append(exc)
continue
else:
validated[name] = value

return validated, errors

def _get_header(self, name, header, response):
deprecated = header.getkey('deprecated', False)
if deprecated:
warnings.warn(
"{0} header is deprecated".format(name),
DeprecationWarning,
)

try:
return self._get_param_or_header_value(
header, response.headers, name=name)
except KeyError:
required = header.getkey('required', False)
if required:
raise MissingRequiredHeader(name)
raise MissingHeader(name)

def _get_data_value(self, response):
if not response.data:
Expand Down
38 changes: 38 additions & 0 deletions openapi_core/validation/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
from openapi_core.deserializing.parameters.factories import (
ParameterDeserializersFactory,
)
from openapi_core.schema.parameters import get_value
from openapi_core.templating.paths.finders import PathFinder
from openapi_core.unmarshalling.schemas.util import build_format_checker

Expand Down Expand Up @@ -36,6 +40,10 @@ def media_type_deserializers_factory(self):
return MediaTypeDeserializersFactory(
self.custom_media_type_deserializers)

@property
def parameter_deserializers_factory(self):
return ParameterDeserializersFactory()

@property
def schema_unmarshallers_factory(self):
raise NotImplementedError
Expand All @@ -52,10 +60,40 @@ def _deserialise_data(self, mimetype, value):
deserializer = self.media_type_deserializers_factory.create(mimetype)
return deserializer(value)

def _deserialise_parameter(self, param, value):
deserializer = self.parameter_deserializers_factory.create(param)
return deserializer(value)

def _cast(self, schema, value):
caster = self.schema_casters_factory.create(schema)
return caster(value)

def _unmarshal(self, schema, value):
unmarshaller = self.schema_unmarshallers_factory.create(schema)
return unmarshaller(value)

def _get_param_or_header_value(self, param_or_header, location, name=None):
try:
raw_value = get_value(param_or_header, location, name=name)
except KeyError:
if 'schema' not in param_or_header:
raise
schema = param_or_header / 'schema'
if 'default' not in schema:
raise
casted = schema['default']
else:
# Simple scenario
if 'content' not in param_or_header:
deserialised = self._deserialise_parameter(
param_or_header, raw_value)
schema = param_or_header / 'schema'
# Complex scenario
else:
content = param_or_header / 'content'
mimetype, media_type = next(content.items())
deserialised = self._deserialise_data(mimetype, raw_value)
schema = media_type / 'schema'
casted = self._cast(schema, deserialised)
unmarshalled = self._unmarshal(schema, casted)
return unmarshalled
Loading