Skip to content

Commit c82f6aa

Browse files
committed
Merge remote-tracking branch 'origin/master' into documentation/include_translations_in_process
2 parents e8e1908 + 0c97dd1 commit c82f6aa

File tree

10 files changed

+297
-16
lines changed

10 files changed

+297
-16
lines changed

docs/api-guide/authentication.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ Unauthenticated responses that are denied permission will result in an `HTTP 403
247247

248248
If you're using an AJAX style API with SessionAuthentication, you'll need to make sure you include a valid CSRF token for any "unsafe" HTTP method calls, such as `PUT`, `PATCH`, `POST` or `DELETE` requests. See the [Django CSRF documentation][csrf-ajax] for more details.
249249

250+
**Warning**: Always use Django's standard login view when creating login pages. This will ensure your login views are properly protected.
251+
252+
CSRF validation in REST framework works slightly differently to standard Django due to the need to support both session and non-session based authentication to the same views. This means that only authenticated requests require CSRF tokens, and anonymous requests may be sent without CSRF tokens. This behaviour is not suitable for login views, which should always have CSRF validation applied.
253+
250254
# Custom authentication
251255

252256
To implement a custom authentication scheme, subclass `BaseAuthentication` and override the `.authenticate(self, request)` method. The method should return a two-tuple of `(user, auth)` if authentication succeeds, or `None` otherwise.

docs/api-guide/fields.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,18 @@ Corresponds to `django.db.models.fields.TimeField`
302302

303303
Format strings may either be [Python strftime formats][strftime] which explicitly specify the format, or the special string `'iso-8601'`, which indicates that [ISO 8601][iso8601] style times should be used. (eg `'12:34:56.000000'`)
304304

305+
## DurationField
306+
307+
A Duration representation.
308+
Corresponds to `django.db.models.fields.DurationField`
309+
310+
The `validated_data` for these fields will contain a `datetime.timedelta` instance.
311+
The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`.
312+
313+
**Note:** This field is only available with Django versions >= 1.8.
314+
315+
**Signature:** `DurationField()`
316+
305317
---
306318

307319
# Choice selection fields

rest_framework/compat.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import unicode_literals
88
from django.core.exceptions import ImproperlyConfigured
99
from django.conf import settings
10+
from django.db import connection, transaction
1011
from django.utils.encoding import force_text
1112
from django.utils.six.moves.urllib.parse import urlparse as _urlparse
1213
from django.utils import six
@@ -258,3 +259,27 @@ def apply_markdown(text):
258259
SHORT_SEPARATORS = (b',', b':')
259260
LONG_SEPARATORS = (b', ', b': ')
260261
INDENT_SEPARATORS = (b',', b': ')
262+
263+
264+
if django.VERSION >= (1, 8):
265+
from django.db.models import DurationField
266+
from django.utils.dateparse import parse_duration
267+
from django.utils.duration import duration_string
268+
else:
269+
DurationField = duration_string = parse_duration = None
270+
271+
272+
def set_rollback():
273+
if hasattr(transaction, 'set_rollback'):
274+
if connection.settings_dict.get('ATOMIC_REQUESTS', False):
275+
# If running in >=1.6 then mark a rollback as required,
276+
# and allow it to be handled by Django.
277+
transaction.set_rollback(True)
278+
elif transaction.is_managed():
279+
# Otherwise handle it explicitly if in managed mode.
280+
if transaction.is_dirty():
281+
transaction.rollback()
282+
transaction.leave_transaction_management()
283+
else:
284+
# transaction not managed
285+
pass

rest_framework/fields.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from rest_framework.compat import (
1313
EmailValidator, MinValueValidator, MaxValueValidator,
1414
MinLengthValidator, MaxLengthValidator, URLValidator, OrderedDict,
15-
unicode_repr, unicode_to_repr
15+
unicode_repr, unicode_to_repr, parse_duration, duration_string,
1616
)
1717
from rest_framework.exceptions import ValidationError
1818
from rest_framework.settings import api_settings
@@ -1003,6 +1003,29 @@ def to_representation(self, value):
10031003
return value.strftime(self.format)
10041004

10051005

1006+
class DurationField(Field):
1007+
default_error_messages = {
1008+
'invalid': _('Duration has wrong format. Use one of these formats instead: {format}.'),
1009+
}
1010+
1011+
def __init__(self, *args, **kwargs):
1012+
if parse_duration is None:
1013+
raise NotImplementedError(
1014+
'DurationField not supported for django versions prior to 1.8')
1015+
return super(DurationField, self).__init__(*args, **kwargs)
1016+
1017+
def to_internal_value(self, value):
1018+
if isinstance(value, datetime.timedelta):
1019+
return value
1020+
parsed = parse_duration(value)
1021+
if parsed is not None:
1022+
return parsed
1023+
self.fail('invalid', format='[DD] [HH:[MM:]]ss[.uuuuuu]')
1024+
1025+
def to_representation(self, value):
1026+
return duration_string(value)
1027+
1028+
10061029
# Choice types...
10071030

10081031
class ChoiceField(Field):
@@ -1060,7 +1083,11 @@ def get_value(self, dictionary):
10601083
# We override the default field access in order to support
10611084
# lists in HTML forms.
10621085
if html.is_html_input(dictionary):
1063-
return dictionary.getlist(self.field_name)
1086+
ret = dictionary.getlist(self.field_name)
1087+
if getattr(self.root, 'partial', False) and not ret:
1088+
ret = empty
1089+
return ret
1090+
10641091
return dictionary.get(self.field_name, empty)
10651092

10661093
def to_internal_value(self, data):

rest_framework/serializers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField
1616
from django.db.models import query
1717
from django.utils.translation import ugettext_lazy as _
18-
from rest_framework.compat import postgres_fields, unicode_to_repr
18+
from rest_framework.compat import (
19+
postgres_fields,
20+
unicode_to_repr,
21+
DurationField as ModelDurationField,
22+
)
1923
from rest_framework.utils import model_meta
2024
from rest_framework.utils.field_mapping import (
2125
get_url_kwargs, get_field_kwargs,
@@ -731,6 +735,8 @@ class ModelSerializer(Serializer):
731735
models.TimeField: TimeField,
732736
models.URLField: URLField,
733737
}
738+
if ModelDurationField is not None:
739+
serializer_field_mapping[ModelDurationField] = DurationField
734740
serializer_related_field = PrimaryKeyRelatedField
735741
serializer_url_field = HyperlinkedIdentityField
736742
serializer_choice_field = ChoiceField
@@ -1088,6 +1094,9 @@ def include_extra_kwargs(self, kwargs, extra_kwargs):
10881094
if extra_kwargs.get('default') and kwargs.get('required') is False:
10891095
kwargs.pop('required')
10901096

1097+
if kwargs.get('read_only', False):
1098+
extra_kwargs.pop('required', None)
1099+
10911100
kwargs.update(extra_kwargs)
10921101

10931102
return kwargs

rest_framework/views.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from django.utils.translation import ugettext_lazy as _
1010
from django.views.decorators.csrf import csrf_exempt
1111
from rest_framework import status, exceptions
12-
from rest_framework.compat import HttpResponseBase, View
12+
from rest_framework.compat import HttpResponseBase, View, set_rollback
1313
from rest_framework.request import Request
1414
from rest_framework.response import Response
1515
from rest_framework.settings import api_settings
@@ -71,16 +71,21 @@ def exception_handler(exc, context):
7171
else:
7272
data = {'detail': exc.detail}
7373

74+
set_rollback()
7475
return Response(data, status=exc.status_code, headers=headers)
7576

7677
elif isinstance(exc, Http404):
7778
msg = _('Not found.')
7879
data = {'detail': six.text_type(msg)}
80+
81+
set_rollback()
7982
return Response(data, status=status.HTTP_404_NOT_FOUND)
8083

8184
elif isinstance(exc, PermissionDenied):
8285
msg = _('Permission denied.')
8386
data = {'detail': six.text_type(msg)}
87+
88+
set_rollback()
8489
return Response(data, status=status.HTTP_403_FORBIDDEN)
8590

8691
# Note: Unhandled exceptions will raise a 500 error.

tests/test_atomic_requests.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from __future__ import unicode_literals
2+
3+
from django.db import connection, connections, transaction
4+
from django.test import TestCase
5+
from django.utils.unittest import skipUnless
6+
from rest_framework import status
7+
from rest_framework.exceptions import APIException
8+
from rest_framework.response import Response
9+
from rest_framework.test import APIRequestFactory
10+
from rest_framework.views import APIView
11+
from tests.models import BasicModel
12+
13+
14+
factory = APIRequestFactory()
15+
16+
17+
class BasicView(APIView):
18+
def post(self, request, *args, **kwargs):
19+
BasicModel.objects.create()
20+
return Response({'method': 'GET'})
21+
22+
23+
class ErrorView(APIView):
24+
def post(self, request, *args, **kwargs):
25+
BasicModel.objects.create()
26+
raise Exception
27+
28+
29+
class APIExceptionView(APIView):
30+
def post(self, request, *args, **kwargs):
31+
BasicModel.objects.create()
32+
raise APIException
33+
34+
35+
@skipUnless(connection.features.uses_savepoints,
36+
"'atomic' requires transactions and savepoints.")
37+
class DBTransactionTests(TestCase):
38+
def setUp(self):
39+
self.view = BasicView.as_view()
40+
connections.databases['default']['ATOMIC_REQUESTS'] = True
41+
42+
def tearDown(self):
43+
connections.databases['default']['ATOMIC_REQUESTS'] = False
44+
45+
def test_no_exception_conmmit_transaction(self):
46+
request = factory.post('/')
47+
48+
with self.assertNumQueries(1):
49+
response = self.view(request)
50+
self.assertFalse(transaction.get_rollback())
51+
self.assertEqual(response.status_code, status.HTTP_200_OK)
52+
assert BasicModel.objects.count() == 1
53+
54+
55+
@skipUnless(connection.features.uses_savepoints,
56+
"'atomic' requires transactions and savepoints.")
57+
class DBTransactionErrorTests(TestCase):
58+
def setUp(self):
59+
self.view = ErrorView.as_view()
60+
connections.databases['default']['ATOMIC_REQUESTS'] = True
61+
62+
def tearDown(self):
63+
connections.databases['default']['ATOMIC_REQUESTS'] = False
64+
65+
def test_generic_exception_delegate_transaction_management(self):
66+
"""
67+
Transaction is eventually managed by outer-most transaction atomic
68+
block. DRF do not try to interfere here.
69+
70+
We let django deal with the transaction when it will catch the Exception.
71+
"""
72+
request = factory.post('/')
73+
with self.assertNumQueries(3):
74+
# 1 - begin savepoint
75+
# 2 - insert
76+
# 3 - release savepoint
77+
with transaction.atomic():
78+
self.assertRaises(Exception, self.view, request)
79+
self.assertFalse(transaction.get_rollback())
80+
assert BasicModel.objects.count() == 1
81+
82+
83+
@skipUnless(connection.features.uses_savepoints,
84+
"'atomic' requires transactions and savepoints.")
85+
class DBTransactionAPIExceptionTests(TestCase):
86+
def setUp(self):
87+
self.view = APIExceptionView.as_view()
88+
connections.databases['default']['ATOMIC_REQUESTS'] = True
89+
90+
def tearDown(self):
91+
connections.databases['default']['ATOMIC_REQUESTS'] = False
92+
93+
def test_api_exception_rollback_transaction(self):
94+
"""
95+
Transaction is rollbacked by our transaction atomic block.
96+
"""
97+
request = factory.post('/')
98+
num_queries = (4 if getattr(connection.features,
99+
'can_release_savepoints', False) else 3)
100+
with self.assertNumQueries(num_queries):
101+
# 1 - begin savepoint
102+
# 2 - insert
103+
# 3 - rollback savepoint
104+
# 4 - release savepoint (django>=1.8 only)
105+
with transaction.atomic():
106+
response = self.view(request)
107+
self.assertTrue(transaction.get_rollback())
108+
self.assertEqual(response.status_code,
109+
status.HTTP_500_INTERNAL_SERVER_ERROR)
110+
assert BasicModel.objects.count() == 0

tests/test_fields.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from decimal import Decimal
22
from django.utils import timezone
33
from rest_framework import serializers
4+
import rest_framework
45
import datetime
56
import django
67
import pytest
@@ -221,6 +222,14 @@ def test_invalid_error_key(self):
221222
assert str(exc_info.value) == expected
222223

223224

225+
class MockHTMLDict(dict):
226+
"""
227+
This class mocks up a dictionary like object, that behaves
228+
as if it was returned for multipart or urlencoded data.
229+
"""
230+
getlist = None
231+
232+
224233
class TestBooleanHTMLInput:
225234
def setup(self):
226235
class TestSerializer(serializers.Serializer):
@@ -234,21 +243,11 @@ def test_empty_html_checkbox(self):
234243
"""
235244
# This class mocks up a dictionary like object, that behaves
236245
# as if it was returned for multipart or urlencoded data.
237-
class MockHTMLDict(dict):
238-
getlist = None
239246
serializer = self.Serializer(data=MockHTMLDict())
240247
assert serializer.is_valid()
241248
assert serializer.validated_data == {'archived': False}
242249

243250

244-
class MockHTMLDict(dict):
245-
"""
246-
This class mocks up a dictionary like object, that behaves
247-
as if it was returned for multipart or urlencoded data.
248-
"""
249-
getlist = None
250-
251-
252251
class TestHTMLInput:
253252
def test_empty_html_charfield(self):
254253
class TestSerializer(serializers.Serializer):
@@ -905,6 +904,29 @@ class TestNoOutputFormatTimeField(FieldValues):
905904
field = serializers.TimeField(format=None)
906905

907906

907+
@pytest.mark.skipif(django.VERSION < (1, 8),
908+
reason='DurationField is only available for django1.8+')
909+
class TestDurationField(FieldValues):
910+
"""
911+
Valid and invalid values for `DurationField`.
912+
"""
913+
valid_inputs = {
914+
'13': datetime.timedelta(seconds=13),
915+
'3 08:32:01.000123': datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
916+
'08:01': datetime.timedelta(minutes=8, seconds=1),
917+
datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
918+
}
919+
invalid_inputs = {
920+
'abc': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
921+
'3 08:32 01.123': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
922+
}
923+
outputs = {
924+
datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): '3 08:32:01.000123',
925+
}
926+
if django.VERSION >= (1, 8):
927+
field = serializers.DurationField()
928+
929+
908930
# Choice types...
909931

910932
class TestChoiceField(FieldValues):
@@ -1017,6 +1039,15 @@ class TestMultipleChoiceField(FieldValues):
10171039
]
10181040
)
10191041

1042+
def test_against_partial_and_full_updates(self):
1043+
# serializer = self.Serializer(data=MockHTMLDict())
1044+
from django.http import QueryDict
1045+
field = serializers.MultipleChoiceField(choices=(('a', 'a'), ('b', 'b')))
1046+
field.partial = False
1047+
assert field.get_value(QueryDict({})) == []
1048+
field.partial = True
1049+
assert field.get_value(QueryDict({})) == rest_framework.fields.empty
1050+
10201051

10211052
# File serializers...
10221053

0 commit comments

Comments
 (0)