Skip to content

Commit ac3e849

Browse files
committed
Support atomic transaction views in multiple database connections
1 parent 0618fa8 commit ac3e849

File tree

1 file changed

+80
-21
lines changed

1 file changed

+80
-21
lines changed

tests/test_atomic_requests.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,35 @@
1616

1717

1818
class BasicView(APIView):
19+
database = 'default'
20+
21+
def get_queryset(self):
22+
return BasicModel.objects.using(self.database).all()
23+
1924
def post(self, request, *args, **kwargs):
20-
BasicModel.objects.create()
25+
self.get_queryset().create()
2126
return Response({'method': 'GET'})
2227

2328

24-
class ErrorView(APIView):
29+
class ErrorView(BasicView):
2530
def post(self, request, *args, **kwargs):
26-
BasicModel.objects.create()
31+
self.get_queryset().create()
2732
raise Exception
2833

2934

30-
class APIExceptionView(APIView):
35+
class APIExceptionView(BasicView):
3136
def post(self, request, *args, **kwargs):
32-
BasicModel.objects.create()
37+
self.get_queryset().create()
3338
raise APIException
3439

3540

36-
class NonAtomicAPIExceptionView(APIView):
41+
class NonAtomicAPIExceptionView(BasicView):
3742
@transaction.non_atomic_requests
3843
def dispatch(self, *args, **kwargs):
3944
return super().dispatch(*args, **kwargs)
4045

4146
def get(self, request, *args, **kwargs):
42-
BasicModel.objects.all()
47+
self.get_queryset()
4348
raise Http404
4449

4550

@@ -53,34 +58,52 @@ def get(self, request, *args, **kwargs):
5358
"'atomic' requires transactions and savepoints."
5459
)
5560
class DBTransactionTests(TestCase):
61+
databases = '__all__'
62+
5663
def setUp(self):
57-
self.view = BasicView.as_view()
58-
connections.databases['default']['ATOMIC_REQUESTS'] = True
64+
self.view = BasicView
65+
for database in connections.databases:
66+
connections.databases[database]['ATOMIC_REQUESTS'] = True
5967

6068
def tearDown(self):
61-
connections.databases['default']['ATOMIC_REQUESTS'] = False
69+
for database in connections.databases:
70+
connections.databases[database]['ATOMIC_REQUESTS'] = False
6271

6372
def test_no_exception_commit_transaction(self):
6473
request = factory.post('/')
6574

6675
with self.assertNumQueries(1):
67-
response = self.view(request)
76+
response = self.view.as_view()(request)
6877
assert not transaction.get_rollback()
6978
assert response.status_code == status.HTTP_200_OK
7079
assert BasicModel.objects.count() == 1
7180

81+
def test_no_exception_commit_transaction_spare_connection(self):
82+
request = factory.post('/')
83+
84+
with self.assertNumQueries(1, using='spare'):
85+
view = self.view.as_view(database='spare')
86+
response = view(request)
87+
assert not transaction.get_rollback(using='spare')
88+
assert response.status_code == status.HTTP_200_OK
89+
assert BasicModel.objects.using('spare').count() == 1
90+
7291

7392
@unittest.skipUnless(
7493
connection.features.uses_savepoints,
7594
"'atomic' requires transactions and savepoints."
7695
)
7796
class DBTransactionErrorTests(TestCase):
97+
databases = '__all__'
98+
7899
def setUp(self):
79-
self.view = ErrorView.as_view()
80-
connections.databases['default']['ATOMIC_REQUESTS'] = True
100+
self.view = ErrorView
101+
for database in connections.databases:
102+
connections.databases[database]['ATOMIC_REQUESTS'] = True
81103

82104
def tearDown(self):
83-
connections.databases['default']['ATOMIC_REQUESTS'] = False
105+
for database in connections.databases:
106+
connections.databases[database]['ATOMIC_REQUESTS'] = False
84107

85108
def test_generic_exception_delegate_transaction_management(self):
86109
"""
@@ -95,22 +118,37 @@ def test_generic_exception_delegate_transaction_management(self):
95118
# 2 - insert
96119
# 3 - release savepoint
97120
with transaction.atomic():
98-
self.assertRaises(Exception, self.view, request)
121+
self.assertRaises(Exception, self.view.as_view(), request)
99122
assert not transaction.get_rollback()
100123
assert BasicModel.objects.count() == 1
101124

125+
def test_generic_exception_delegate_transaction_management_spare_connections(self):
126+
request = factory.post('/')
127+
with self.assertNumQueries(3, using='spare'):
128+
# 1 - begin savepoint
129+
# 2 - insert
130+
# 3 - release savepoint
131+
with transaction.atomic(using='spare'):
132+
self.assertRaises(Exception, self.view.as_view(database='spare'), request)
133+
assert not transaction.get_rollback(using='spare')
134+
assert BasicModel.objects.using('spare').count() == 1
135+
102136

103137
@unittest.skipUnless(
104138
connection.features.uses_savepoints,
105139
"'atomic' requires transactions and savepoints."
106140
)
107141
class DBTransactionAPIExceptionTests(TestCase):
142+
databases = '__all__'
143+
108144
def setUp(self):
109-
self.view = APIExceptionView.as_view()
110-
connections.databases['default']['ATOMIC_REQUESTS'] = True
145+
self.view = APIExceptionView
146+
for database in connections.databases:
147+
connections.databases[database]['ATOMIC_REQUESTS'] = True
111148

112149
def tearDown(self):
113-
connections.databases['default']['ATOMIC_REQUESTS'] = False
150+
for database in connections.databases:
151+
connections.databases[database]['ATOMIC_REQUESTS'] = False
114152

115153
def test_api_exception_rollback_transaction(self):
116154
"""
@@ -124,11 +162,28 @@ def test_api_exception_rollback_transaction(self):
124162
# 3 - rollback savepoint
125163
# 4 - release savepoint
126164
with transaction.atomic():
127-
response = self.view(request)
165+
response = self.view.as_view()(request)
128166
assert transaction.get_rollback()
129167
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
130168
assert BasicModel.objects.count() == 0
131169

170+
def test_api_exception_rollback_transaction_spare_connection(self):
171+
"""
172+
Transaction is rollbacked by our transaction atomic block.
173+
"""
174+
request = factory.post('/')
175+
num_queries = 4 if connections['spare'].features.can_release_savepoints else 3
176+
with self.assertNumQueries(num_queries, using='spare'):
177+
# 1 - begin savepoint
178+
# 2 - insert
179+
# 3 - rollback savepoint
180+
# 4 - release savepoint
181+
with transaction.atomic(using='spare'):
182+
response = self.view.as_view(database='spare')(request)
183+
assert transaction.get_rollback(using='spare')
184+
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
185+
assert BasicModel.objects.using('spare').count() == 0
186+
132187

133188
@unittest.skipUnless(
134189
connection.features.uses_savepoints,
@@ -171,11 +226,15 @@ def test_api_exception_rollback_transaction(self):
171226
)
172227
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
173228
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
229+
databases = '__all__'
230+
174231
def setUp(self):
175-
connections.databases['default']['ATOMIC_REQUESTS'] = True
232+
for database in connections.databases:
233+
connections.databases[database]['ATOMIC_REQUESTS'] = True
176234

177235
def tearDown(self):
178-
connections.databases['default']['ATOMIC_REQUESTS'] = False
236+
for database in connections.databases:
237+
connections.databases[database]['ATOMIC_REQUESTS'] = False
179238

180239
def test_api_exception_rollback_transaction_non_atomic_view(self):
181240
response = self.client.get('/')

0 commit comments

Comments
 (0)