Skip to content

Commit c292b3d

Browse files
committed
Fix set_rollback on @transaction.non_atomic_requests.
1 parent 6651432 commit c292b3d

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

rest_framework/compat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ def set_rollback():
274274
if connection.settings_dict.get('ATOMIC_REQUESTS', False):
275275
# If running in >=1.6 then mark a rollback as required,
276276
# and allow it to be handled by Django.
277-
transaction.set_rollback(True)
277+
if connection.in_atomic_block:
278+
transaction.set_rollback(True)
278279
elif transaction.is_managed():
279280
# Otherwise handle it explicitly if in managed mode.
280281
if transaction.is_dirty():

tests/test_atomic_requests.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from django.db import connection, connections, transaction
44
from django.test import TestCase
5+
from django.utils.decorators import method_decorator
56
from django.utils.unittest import skipUnless
67
from rest_framework import status
7-
from rest_framework.exceptions import APIException
8+
from rest_framework.exceptions import APIException, PermissionDenied
89
from rest_framework.response import Response
910
from rest_framework.test import APIRequestFactory
1011
from rest_framework.views import APIView
@@ -32,6 +33,16 @@ def post(self, request, *args, **kwargs):
3233
raise APIException
3334

3435

36+
class NonAtomicAPIExceptionView(APIView):
37+
@method_decorator(transaction.non_atomic_requests)
38+
def dispatch(self, *args, **kwargs):
39+
return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs)
40+
41+
def post(self, request, *args, **kwargs):
42+
BasicModel.objects.create()
43+
raise PermissionDenied
44+
45+
3546
@skipUnless(connection.features.uses_savepoints,
3647
"'atomic' requires transactions and savepoints.")
3748
class DBTransactionTests(TestCase):
@@ -108,3 +119,24 @@ def test_api_exception_rollback_transaction(self):
108119
self.assertEqual(response.status_code,
109120
status.HTTP_500_INTERNAL_SERVER_ERROR)
110121
assert BasicModel.objects.count() == 0
122+
123+
124+
@skipUnless(connection.features.uses_savepoints,
125+
"'atomic' requires transactions and savepoints.")
126+
class NonAtomicDBTransactionAPIExceptionTests(TestCase):
127+
def setUp(self):
128+
self.view = NonAtomicAPIExceptionView.as_view()
129+
connections.databases['default']['ATOMIC_REQUESTS'] = True
130+
131+
def tearDown(self):
132+
connections.databases['default']['ATOMIC_REQUESTS'] = False
133+
134+
def test_api_exception_rollback_transaction_non_atomic_view(self):
135+
request = factory.post('/')
136+
137+
response = self.view(request)
138+
139+
# without checking connection.in_atomic_block view raises 500
140+
# due attempt to rollback without transaction
141+
self.assertEqual(response.status_code,
142+
status.HTTP_403_FORBIDDEN)

0 commit comments

Comments
 (0)