Skip to content

Commit 8a5e963

Browse files
committed
Separate Generator code by schema type.
1 parent 96f1dec commit 8a5e963

File tree

6 files changed

+259
-255
lines changed

6 files changed

+259
-255
lines changed

rest_framework/management/commands/generateschema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from django.core.management.base import BaseCommand
22

33
from rest_framework.compat import yaml
4-
from rest_framework.schemas.generators import OpenAPISchemaGenerator
4+
from rest_framework.schemas.openapi import SchemaGenerator
55
from rest_framework.utils import json
66

77

@@ -15,7 +15,7 @@ def add_arguments(self, parser):
1515
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
1616

1717
def handle(self, *args, **options):
18-
generator = OpenAPISchemaGenerator(
18+
generator = SchemaGenerator(
1919
url=options['url'],
2020
title=options['title'],
2121
description=options['description']

rest_framework/schemas/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
"""
2323
from rest_framework.settings import api_settings
2424

25-
from .generators import SchemaGenerator
2625
from .inspectors import DefaultSchema # noqa
27-
from .coreapi import AutoSchema, ManualSchema # noqa
26+
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
2827

2928

3029
def get_schema_view(

rest_framework/schemas/coreapi.py

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
import warnings
3-
from collections import OrderedDict
3+
from collections import Counter, OrderedDict
44

55
from django.db import models
66
from django.utils.encoding import force_text, smart_text
@@ -11,13 +11,206 @@
1111
from rest_framework.settings import api_settings
1212
from rest_framework.utils import formatting
1313

14+
from .generators import BaseSchemaGenerator
1415
from .inspectors import ViewInspector
1516
from .utils import get_pk_description, is_list_view
1617

1718
# Used in _get_description_section()
1819
# TODO: ???: move up to base.
1920
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
2021

22+
# Generator #
23+
# TODO: Pull some of this into base.
24+
25+
26+
def is_custom_action(action):
27+
return action not in {
28+
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
29+
}
30+
31+
32+
def distribute_links(obj):
33+
for key, value in obj.items():
34+
distribute_links(value)
35+
36+
for preferred_key, link in obj.links:
37+
key = obj.get_available_key(preferred_key)
38+
obj[key] = link
39+
40+
41+
INSERT_INTO_COLLISION_FMT = """
42+
Schema Naming Collision.
43+
44+
coreapi.Link for URL path {value_url} cannot be inserted into schema.
45+
Position conflicts with coreapi.Link for URL path {target_url}.
46+
47+
Attempted to insert link with keys: {keys}.
48+
49+
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
50+
to customise schema structure.
51+
"""
52+
53+
54+
class LinkNode(OrderedDict):
55+
def __init__(self):
56+
self.links = []
57+
self.methods_counter = Counter()
58+
super(LinkNode, self).__init__()
59+
60+
def get_available_key(self, preferred_key):
61+
if preferred_key not in self:
62+
return preferred_key
63+
64+
while True:
65+
current_val = self.methods_counter[preferred_key]
66+
self.methods_counter[preferred_key] += 1
67+
68+
key = '{}_{}'.format(preferred_key, current_val)
69+
if key not in self:
70+
return key
71+
72+
73+
def insert_into(target, keys, value):
74+
"""
75+
Nested dictionary insertion.
76+
77+
>>> example = {}
78+
>>> insert_into(example, ['a', 'b', 'c'], 123)
79+
>>> example
80+
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
81+
"""
82+
for key in keys[:-1]:
83+
if key not in target:
84+
target[key] = LinkNode()
85+
target = target[key]
86+
87+
try:
88+
target.links.append((keys[-1], value))
89+
except TypeError:
90+
msg = INSERT_INTO_COLLISION_FMT.format(
91+
value_url=value.url,
92+
target_url=target.url,
93+
keys=keys
94+
)
95+
raise ValueError(msg)
96+
97+
98+
class SchemaGenerator(BaseSchemaGenerator):
99+
"""
100+
Original CoreAPI version.
101+
"""
102+
# Map HTTP methods onto actions.
103+
default_mapping = {
104+
'get': 'retrieve',
105+
'post': 'create',
106+
'put': 'update',
107+
'patch': 'partial_update',
108+
'delete': 'destroy',
109+
}
110+
111+
# Map the method names we use for viewset actions onto external schema names.
112+
# These give us names that are more suitable for the external representation.
113+
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
114+
coerce_method_names = None
115+
116+
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
117+
assert coreapi, '`coreapi` must be installed for schema support.'
118+
assert coreschema, '`coreschema` must be installed for schema support.'
119+
120+
super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf)
121+
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
122+
123+
def get_links(self, request=None):
124+
"""
125+
Return a dictionary containing all the links that should be
126+
included in the API schema.
127+
"""
128+
links = LinkNode()
129+
130+
paths, view_endpoints = self._get_paths_and_endpoints(request)
131+
132+
# Only generate the path prefix for paths that will be included
133+
if not paths:
134+
return None
135+
prefix = self.determine_path_prefix(paths)
136+
137+
for path, method, view in view_endpoints:
138+
if not self.has_view_permissions(path, method, view):
139+
continue
140+
link = view.schema.get_link(path, method, base_url=self.url)
141+
subpath = path[len(prefix):]
142+
keys = self.get_keys(subpath, method, view)
143+
insert_into(links, keys, link)
144+
145+
return links
146+
147+
def get_schema(self, request=None, public=False):
148+
"""
149+
Generate a `coreapi.Document` representing the API schema.
150+
"""
151+
self._initialise_endpoints()
152+
153+
links = self.get_links(None if public else request)
154+
if not links:
155+
return None
156+
157+
url = self.url
158+
if not url and request is not None:
159+
url = request.build_absolute_uri()
160+
161+
distribute_links(links)
162+
return coreapi.Document(
163+
title=self.title, description=self.description,
164+
url=url, content=links
165+
)
166+
167+
# Method for generating the link layout....
168+
def get_keys(self, subpath, method, view):
169+
"""
170+
Return a list of keys that should be used to layout a link within
171+
the schema document.
172+
173+
/users/ ("users", "list"), ("users", "create")
174+
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
175+
/users/enabled/ ("users", "enabled") # custom viewset list action
176+
/users/{pk}/star/ ("users", "star") # custom viewset detail action
177+
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
178+
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
179+
"""
180+
if hasattr(view, 'action'):
181+
# Viewsets have explicitly named actions.
182+
action = view.action
183+
else:
184+
# Views have no associated action, so we determine one from the method.
185+
if is_list_view(subpath, method, view):
186+
action = 'list'
187+
else:
188+
action = self.default_mapping[method.lower()]
189+
190+
named_path_components = [
191+
component for component
192+
in subpath.strip('/').split('/')
193+
if '{' not in component
194+
]
195+
196+
if is_custom_action(action):
197+
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
198+
if len(view.action_map) > 1:
199+
action = self.default_mapping[method.lower()]
200+
if action in self.coerce_method_names:
201+
action = self.coerce_method_names[action]
202+
return named_path_components + [action]
203+
else:
204+
return named_path_components[:-1] + [action]
205+
206+
if action in self.coerce_method_names:
207+
action = self.coerce_method_names[action]
208+
209+
# Default action, eg "/users/", "/users/{pk}/"
210+
return named_path_components + [action]
211+
212+
# View Inspectors #
213+
21214

22215
def field_to_schema(field):
23216
title = force_text(field.label) if field.label else ''

0 commit comments

Comments
 (0)