Skip to content

Commit 1e383f1

Browse files
authored
Check extra action func.__name__ (#7098)
1 parent 0d2bbd3 commit 1e383f1

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

rest_framework/viewsets.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ def _is_extra_action(attr):
3333
return hasattr(attr, 'mapping') and isinstance(attr.mapping, MethodMapper)
3434

3535

36+
def _check_attr_name(func, name):
37+
assert func.__name__ == name, (
38+
'Expected function (`{func.__name__}`) to match its attribute name '
39+
'(`{name}`). If using a decorator, ensure the inner function is '
40+
'decorated with `functools.wraps`, or that `{func.__name__}.__name__` '
41+
'is otherwise set to `{name}`.').format(func=func, name=name)
42+
return func
43+
44+
3645
class ViewSetMixin:
3746
"""
3847
This is the magic.
@@ -164,7 +173,9 @@ def get_extra_actions(cls):
164173
"""
165174
Get the methods that are marked as an extra ViewSet `@action`.
166175
"""
167-
return [method for _, method in getmembers(cls, _is_extra_action)]
176+
return [_check_attr_name(method, name)
177+
for name, method
178+
in getmembers(cls, _is_extra_action)]
168179

169180
def get_extra_action_url_map(self):
170181
"""

tests/test_viewsets.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import OrderedDict
2+
from functools import wraps
23

34
import pytest
45
from django.conf.urls import include, url
@@ -33,6 +34,13 @@ class Action(models.Model):
3334
pass
3435

3536

37+
def decorate(fn):
38+
@wraps(fn)
39+
def wrapper(self, request, *args, **kwargs):
40+
return fn(self, request, *args, **kwargs)
41+
return wrapper
42+
43+
3644
class ActionViewSet(GenericViewSet):
3745
queryset = Action.objects.all()
3846

@@ -68,6 +76,16 @@ def custom_detail_action(self, request, *args, **kwargs):
6876
def unresolvable_detail_action(self, request, *args, **kwargs):
6977
raise NotImplementedError
7078

79+
@action(detail=False)
80+
@decorate
81+
def wrapped_list_action(self, request, *args, **kwargs):
82+
raise NotImplementedError
83+
84+
@action(detail=True)
85+
@decorate
86+
def wrapped_detail_action(self, request, *args, **kwargs):
87+
raise NotImplementedError
88+
7189

7290
class ActionNamesViewSet(GenericViewSet):
7391

@@ -191,6 +209,8 @@ def test_extra_actions(self):
191209
'detail_action',
192210
'list_action',
193211
'unresolvable_detail_action',
212+
'wrapped_detail_action',
213+
'wrapped_list_action',
194214
]
195215

196216
self.assertEqual(actual, expected)
@@ -204,9 +224,35 @@ def test_should_only_return_decorated_methods(self):
204224
'detail_action',
205225
'list_action',
206226
'unresolvable_detail_action',
227+
'wrapped_detail_action',
228+
'wrapped_list_action',
207229
]
208230
self.assertEqual(actual, expected)
209231

232+
def test_attr_name_check(self):
233+
def decorate(fn):
234+
def wrapper(self, request, *args, **kwargs):
235+
return fn(self, request, *args, **kwargs)
236+
return wrapper
237+
238+
class ActionViewSet(GenericViewSet):
239+
queryset = Action.objects.all()
240+
241+
@action(detail=False)
242+
@decorate
243+
def wrapped_list_action(self, request, *args, **kwargs):
244+
raise NotImplementedError
245+
246+
view = ActionViewSet()
247+
with pytest.raises(AssertionError) as excinfo:
248+
view.get_extra_actions()
249+
250+
assert str(excinfo.value) == (
251+
'Expected function (`wrapper`) to match its attribute name '
252+
'(`wrapped_list_action`). If using a decorator, ensure the inner '
253+
'function is decorated with `functools.wraps`, or that '
254+
'`wrapper.__name__` is otherwise set to `wrapped_list_action`.')
255+
210256

211257
@override_settings(ROOT_URLCONF='tests.test_viewsets')
212258
class GetExtraActionUrlMapTests(TestCase):
@@ -218,6 +264,7 @@ def test_list_view(self):
218264
expected = OrderedDict([
219265
('Custom list action', 'http://testserver/api/actions/custom_list_action/'),
220266
('List action', 'http://testserver/api/actions/list_action/'),
267+
('Wrapped list action', 'http://testserver/api/actions/wrapped_list_action/'),
221268
])
222269

223270
self.assertEqual(view.get_extra_action_url_map(), expected)
@@ -229,6 +276,7 @@ def test_detail_view(self):
229276
expected = OrderedDict([
230277
('Custom detail action', 'http://testserver/api/actions/1/custom_detail_action/'),
231278
('Detail action', 'http://testserver/api/actions/1/detail_action/'),
279+
('Wrapped detail action', 'http://testserver/api/actions/1/wrapped_detail_action/'),
232280
# "Unresolvable detail action" excluded, since it's not resolvable
233281
])
234282

0 commit comments

Comments
 (0)