Skip to content

OpenAPI Schema Generation Fixes. #6827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion rest_framework/schemas/coreapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,18 @@
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')

# Generator #
# TODO: Pull some of this into base.


def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
for i, c in enumerate(s1):
if c != s2[i]:
common = s1[:i]
break
return '/' + '/'.join(common)


def is_custom_action(action):
Expand Down Expand Up @@ -209,6 +220,37 @@ def get_keys(self, subpath, method, view):
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]

def determine_path_prefix(self, paths):
"""
Given a list of all paths, return the common prefix which should be
discounted when generating a schema structure.

This will be the longest common string that does not include that last
component of the URL, or the last component before a path parameter.

For example:

/api/v1/users/
/api/v1/users/{pk}/

The path prefix is '/api/v1'
"""
prefixes = []
for path in paths:
components = path.strip('/').split('/')
initial_components = []
for component in components:
if '{' in component:
break
initial_components.append(component)
prefix = '/'.join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
return '/'
prefixes.append('/' + prefix + '/')
return common_path(prefixes)

# View Inspectors #


Expand Down
43 changes: 0 additions & 43 deletions rest_framework/schemas/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@
from rest_framework.utils.model_meta import _get_pk


def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
for i, c in enumerate(s1):
if c != s2[i]:
common = s1[:i]
break
return '/' + '/'.join(common)


def get_pk_name(model):
meta = model._meta.concrete_model._meta
return _get_pk(meta).name
Expand Down Expand Up @@ -236,37 +224,6 @@ def coerce_path(self, path, method, view):
def get_schema(self, request=None, public=False):
raise NotImplementedError(".get_schema() must be implemented in subclasses.")

def determine_path_prefix(self, paths):
"""
Given a list of all paths, return the common prefix which should be
discounted when generating a schema structure.

This will be the longest common string that does not include that last
component of the URL, or the last component before a path parameter.

For example:

/api/v1/users/
/api/v1/users/{pk}/

The path prefix is '/api/v1'
"""
prefixes = []
for path in paths:
components = path.strip('/').split('/')
initial_components = []
for component in components:
if '{' in component:
break
initial_components.append(component)
prefix = '/'.join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
return '/'
prefixes.append('/' + prefix + '/')
return common_path(prefixes)

def has_view_permissions(self, path, method, view):
"""
Return `True` if the incoming request has the correct view permissions.
Expand Down
14 changes: 8 additions & 6 deletions rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from urllib.parse import urljoin

from django.core.validators import (
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
Expand Down Expand Up @@ -39,17 +40,18 @@ def get_paths(self, request=None):
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
if prefix == '/': # no prefix
prefix = ''

for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
operation = view.schema.get_operation(path, method)
subpath = path[len(prefix):]
result.setdefault(subpath, {})
result[subpath][method.lower()] = operation
# Normalise path for any provided mount url.
if path.startswith('/'):
path = path[1:]
path = urljoin(self.url or '/', path)

result.setdefault(path, {})
result[path][method.lower()] = operation

return result

Expand Down
157 changes: 86 additions & 71 deletions tests/schemas/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,85 +190,36 @@ def test_repeat_operation_ids(self):
assert schema_str.count("newExample") == 1
assert schema_str.count("oldExample") == 1


@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'})
class TestGenerator(TestCase):

def test_override_settings(self):
assert isinstance(views.ExampleListView.schema, AutoSchema)

def test_paths_construction(self):
"""Construction of the `paths` key."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
generator._initialise_endpoints()

paths = generator.get_paths()

assert '/example/' in paths
example_operations = paths['/example/']
assert len(example_operations) == 2
assert 'get' in example_operations
assert 'post' in example_operations

def test_prefixed_paths_construction(self):
"""Construction of the `paths` key with a common prefix."""
patterns = [
url(r'^api/v1/example/?$', views.ExampleListView.as_view()),
url(r'^api/v1/example/{pk}/?$', views.ExampleDetailView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
generator._initialise_endpoints()

paths = generator.get_paths()

assert '/example/' in paths
assert '/example/{id}/' in paths

def test_schema_construction(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)

assert 'openapi' in schema
assert 'paths' in schema

def test_serializer_datefield(self):
patterns = [
url(r'^example/?$', views.ExampleGenericViewSet.as_view({"get": "get"})),
]
generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)

response = schema['paths']['/example/']['get']['responses']
response_schema = response['200']['content']['application/json']['schema']['properties']
path = '/'
method = 'GET'
view = create_view(
views.ExampleGenericAPIView,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view

responses = inspector._get_responses(path, method)
response_schema = responses['200']['content']['application/json']['schema']['properties']
assert response_schema['date']['type'] == response_schema['datetime']['type'] == 'string'

assert response_schema['date']['format'] == 'date'
assert response_schema['datetime']['format'] == 'date-time'

def test_serializer_validators(self):
patterns = [
url(r'^example/?$', views.ExampleValdidatedAPIView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)
path = '/'
method = 'GET'
view = create_view(
views.ExampleValidatedAPIView,
method,
create_request(path),
)
inspector = AutoSchema()
inspector.view = view

response = schema['paths']['/example/']['get']['responses']
response_schema = response['200']['content']['application/json']['schema']['properties']
responses = inspector._get_responses(path, method)
response_schema = responses['200']['content']['application/json']['schema']['properties']

assert response_schema['integer']['type'] == 'integer'
assert response_schema['integer']['maximum'] == 99
Expand Down Expand Up @@ -307,3 +258,67 @@ def test_serializer_validators(self):

assert response_schema['ip']['type'] == 'string'
assert 'format' not in response_schema['ip']


@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'})
class TestGenerator(TestCase):

def test_override_settings(self):
assert isinstance(views.ExampleListView.schema, AutoSchema)

def test_paths_construction(self):
"""Construction of the `paths` key."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
generator._initialise_endpoints()

paths = generator.get_paths()

assert '/example/' in paths
example_operations = paths['/example/']
assert len(example_operations) == 2
assert 'get' in example_operations
assert 'post' in example_operations

def test_prefixed_paths_construction(self):
"""Construction of the `paths` key maintains a common prefix."""
patterns = [
url(r'^v1/example/?$', views.ExampleListView.as_view()),
url(r'^v1/example/{pk}/?$', views.ExampleDetailView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
generator._initialise_endpoints()

paths = generator.get_paths()

assert '/v1/example/' in paths
assert '/v1/example/{id}/' in paths

def test_mount_url_prefixed_to_paths(self):
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
url(r'^example/{pk}/?$', views.ExampleDetailView.as_view()),
]
generator = SchemaGenerator(patterns=patterns, url='/api')
generator._initialise_endpoints()

paths = generator.get_paths()

assert '/api/example/' in paths
assert '/api/example/{id}/' in paths

def test_schema_construction(self):
"""Construction of the top level dictionary."""
patterns = [
url(r'^example/?$', views.ExampleListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)

assert 'openapi' in schema
assert 'paths' in schema
2 changes: 1 addition & 1 deletion tests/schemas/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ExampleValidatedSerializer(serializers.Serializer):
ip = serializers.IPAddressField()


class ExampleValdidatedAPIView(generics.GenericAPIView):
class ExampleValidatedAPIView(generics.GenericAPIView):
serializer_class = ExampleValidatedSerializer

def get(self, *args, **kwargs):
Expand Down