Skip to content

Commit 573dcdf

Browse files
committed
Rebase on #7127 and remove _get_item_schema refactoring
1 parent e4a26ad commit 573dcdf

File tree

3 files changed

+102
-15
lines changed

3 files changed

+102
-15
lines changed

rest_framework/schemas/openapi.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from enum import Enum
23
from operator import attrgetter
34
from urllib.parse import urljoin
45

@@ -37,16 +38,21 @@ def get_schema(self, request=None, public=False):
3738
Generate a OpenAPI schema.
3839
"""
3940
self._initialise_endpoints()
41+
components_schemas = {}
4042

4143
# Iterate endpoints generating per method path operations.
42-
# TODO: …and reference components.
4344
paths = {}
4445
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
4546
for path, method, view in view_endpoints:
4647
if not self.has_view_permissions(path, method, view):
4748
continue
4849

4950
operation = view.schema.get_operation(path, method)
51+
component = view.schema.get_components(path, method)
52+
53+
if component is not None:
54+
components_schemas.update(component)
55+
5056
# Normalise path for any provided mount url.
5157
if path.startswith('/'):
5258
path = path[1:]
@@ -59,9 +65,14 @@ def get_schema(self, request=None, public=False):
5965
schema = {
6066
'openapi': '3.0.2',
6167
'info': self.get_info(),
62-
'paths': paths,
68+
'paths': paths
6369
}
6470

71+
if len(components_schemas) > 0:
72+
schema['components'] = {
73+
'schemas': components_schemas
74+
}
75+
6576
return schema
6677

6778
# View Inspectors
@@ -99,6 +110,21 @@ def get_operation(self, path, method):
99110

100111
return operation
101112

113+
def get_components(self, path, method):
114+
serializer = self._get_serializer(path, method)
115+
116+
if not isinstance(serializer, serializers.Serializer):
117+
return None
118+
119+
# If the model has no model, then the serializer will be inlined
120+
if not hasattr(serializer, 'Meta') or not hasattr(serializer.Meta, 'model'):
121+
return None
122+
123+
model_name = serializer.Meta.model.__name__
124+
content = self._map_serializer(serializer)
125+
126+
return {model_name: content}
127+
102128
def _get_operation_id(self, path, method):
103129
"""
104130
Compute an operation ID from the model, serializer or view name.
@@ -464,6 +490,10 @@ def _get_serializer(self, path, method):
464490
.format(view.__class__.__name__, method, path))
465491
return None
466492

493+
def _get_reference(self, serializer):
494+
model_name = serializer.Meta.model.__name__
495+
return {'$ref': '#/components/schemas/{}'.format(model_name)}
496+
467497
def _get_request_body(self, path, method):
468498
if method not in ('PUT', 'PATCH', 'POST'):
469499
return {}
@@ -473,20 +503,30 @@ def _get_request_body(self, path, method):
473503
serializer = self._get_serializer(path, method)
474504

475505
if not isinstance(serializer, serializers.Serializer):
476-
return {}
477-
478-
content = self._map_serializer(serializer)
479-
# No required fields for PATCH
480-
if method == 'PATCH':
481-
content.pop('required', None)
482-
# No read_only fields for request.
483-
for name, schema in content['properties'].copy().items():
484-
if 'readOnly' in schema:
485-
del content['properties'][name]
506+
item_schema = {}
507+
elif hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
508+
# If the serializer uses a model, we should use a reference
509+
item_schema = self._get_reference(serializer)
510+
else:
511+
# There is no model, we'll map the serializer's fields
512+
item_schema = self._map_serializer(serializer)
513+
# No required fields for PATCH
514+
if method == 'PATCH':
515+
item_schema.pop('required', None)
516+
# No read_only fields for request.
517+
# No write_only fields for response.
518+
for name, schema in item_schema['properties'].copy().items():
519+
if 'writeOnly' in schema:
520+
del item_schema['properties'][name]
521+
if 'required' in item_schema:
522+
item_schema['required'] = [f for f in item_schema['required'] if f != name]
523+
for name, schema in item_schema['properties'].copy().items():
524+
if 'readOnly' in schema:
525+
del item_schema['properties'][name]
486526

487527
return {
488528
'content': {
489-
ct: {'schema': content}
529+
ct: {'schema': item_schema}
490530
for ct in self.request_media_types
491531
}
492532
}
@@ -502,10 +542,15 @@ def _get_responses(self, path, method):
502542

503543
self.response_media_types = self.map_renderers(path, method)
504544

505-
item_schema = {}
506545
serializer = self._get_serializer(path, method)
507546

508-
if isinstance(serializer, serializers.Serializer):
547+
if not isinstance(serializer, serializers.Serializer):
548+
item_schema = {}
549+
elif hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
550+
# If the serializer uses a model, we should use a reference
551+
item_schema = self._get_reference(serializer)
552+
else:
553+
# There is no model, we'll map the serializer's fields
509554
item_schema = self._map_serializer(serializer)
510555
# No write_only fields for response.
511556
for name, schema in item_schema['properties'].copy().items():

tests/schemas/test_openapi.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,3 +744,18 @@ def test_schema_information_empty(self):
744744

745745
assert schema['info']['title'] == ''
746746
assert schema['info']['version'] == ''
747+
748+
def test_serializer_model(self):
749+
"""Construction of the top level dictionary."""
750+
patterns = [
751+
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
752+
]
753+
generator = SchemaGenerator(patterns=patterns)
754+
755+
request = create_request('/')
756+
schema = generator.get_schema(request=request)
757+
758+
print(schema)
759+
assert 'components' in schema
760+
assert 'schemas' in schema['components']
761+
assert 'OpenAPIExample' in schema['components']['schemas']

tests/schemas/views.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
DecimalValidator, MaxLengthValidator, MaxValueValidator,
55
MinLengthValidator, MinValueValidator, RegexValidator
66
)
7+
from django.db import models
78

89
from rest_framework import generics, permissions, serializers
910
from rest_framework.decorators import action
@@ -137,3 +138,29 @@ def get(self, *args, **kwargs):
137138
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
138139
ip='192.168.1.1')
139140
return Response(serializer.data)
141+
142+
143+
# Serializer with model.
144+
class OpenAPIExample(models.Model):
145+
first_name = models.CharField(max_length=30)
146+
147+
148+
class ExampleSerializerModel(serializers.Serializer):
149+
date = serializers.DateField()
150+
datetime = serializers.DateTimeField()
151+
hstore = serializers.HStoreField()
152+
uuid_field = serializers.UUIDField(default=uuid.uuid4)
153+
154+
class Meta:
155+
model = OpenAPIExample
156+
157+
158+
class ExampleGenericAPIViewModel(generics.GenericAPIView):
159+
serializer_class = ExampleSerializerModel
160+
161+
def get(self, *args, **kwargs):
162+
from datetime import datetime
163+
now = datetime.now()
164+
165+
serializer = self.get_serializer(data=now.date(), datetime=now)
166+
return Response(serializer.data)

0 commit comments

Comments
 (0)