Skip to content

Support atomic requests in multiple database connections #7617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 80 additions & 21 deletions tests/test_atomic_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,35 @@


class BasicView(APIView):
database = 'default'

def get_queryset(self):
return BasicModel.objects.using(self.database).all()

def post(self, request, *args, **kwargs):
BasicModel.objects.create()
self.get_queryset().create()
return Response({'method': 'GET'})


class ErrorView(APIView):
class ErrorView(BasicView):
def post(self, request, *args, **kwargs):
BasicModel.objects.create()
self.get_queryset().create()
raise Exception


class APIExceptionView(APIView):
class APIExceptionView(BasicView):
def post(self, request, *args, **kwargs):
BasicModel.objects.create()
self.get_queryset().create()
raise APIException


class NonAtomicAPIExceptionView(APIView):
class NonAtomicAPIExceptionView(BasicView):
@transaction.non_atomic_requests
def dispatch(self, *args, **kwargs):
return super().dispatch(*args, **kwargs)

def get(self, request, *args, **kwargs):
BasicModel.objects.all()
self.get_queryset()
raise Http404


Expand All @@ -53,34 +58,52 @@ def get(self, request, *args, **kwargs):
"'atomic' requires transactions and savepoints."
)
class DBTransactionTests(TestCase):
databases = '__all__'

def setUp(self):
self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
self.view = BasicView
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_no_exception_commit_transaction(self):
request = factory.post('/')

with self.assertNumQueries(1):
response = self.view(request)
response = self.view.as_view()(request)
assert not transaction.get_rollback()
assert response.status_code == status.HTTP_200_OK
assert BasicModel.objects.count() == 1

def test_no_exception_commit_transaction_spare_connection(self):
request = factory.post('/')

with self.assertNumQueries(1, using='spare'):
view = self.view.as_view(database='spare')
response = view(request)
assert not transaction.get_rollback(using='spare')
assert response.status_code == status.HTTP_200_OK
assert BasicModel.objects.using('spare').count() == 1


@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
class DBTransactionErrorTests(TestCase):
databases = '__all__'

def setUp(self):
self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
self.view = ErrorView
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_generic_exception_delegate_transaction_management(self):
"""
Expand All @@ -95,22 +118,37 @@ def test_generic_exception_delegate_transaction_management(self):
# 2 - insert
# 3 - release savepoint
with transaction.atomic():
self.assertRaises(Exception, self.view, request)
self.assertRaises(Exception, self.view.as_view(), request)
assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1

def test_generic_exception_delegate_transaction_management_spare_connections(self):
request = factory.post('/')
with self.assertNumQueries(3, using='spare'):
# 1 - begin savepoint
# 2 - insert
# 3 - release savepoint
with transaction.atomic(using='spare'):
self.assertRaises(Exception, self.view.as_view(database='spare'), request)
assert not transaction.get_rollback(using='spare')
assert BasicModel.objects.using('spare').count() == 1


@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
class DBTransactionAPIExceptionTests(TestCase):
databases = '__all__'

def setUp(self):
self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
self.view = APIExceptionView
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_api_exception_rollback_transaction(self):
"""
Expand All @@ -124,11 +162,28 @@ def test_api_exception_rollback_transaction(self):
# 3 - rollback savepoint
# 4 - release savepoint
with transaction.atomic():
response = self.view(request)
response = self.view.as_view()(request)
assert transaction.get_rollback()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.count() == 0

def test_api_exception_rollback_transaction_spare_connection(self):
"""
Transaction is rollbacked by our transaction atomic block.
"""
request = factory.post('/')
num_queries = 4 if connections['spare'].features.can_release_savepoints else 3
with self.assertNumQueries(num_queries, using='spare'):
# 1 - begin savepoint
# 2 - insert
# 3 - rollback savepoint
# 4 - release savepoint
with transaction.atomic(using='spare'):
response = self.view.as_view(database='spare')(request)
assert transaction.get_rollback(using='spare')
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.using('spare').count() == 0


@unittest.skipUnless(
connection.features.uses_savepoints,
Expand Down Expand Up @@ -171,11 +226,15 @@ def test_api_exception_rollback_transaction(self):
)
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
databases = '__all__'

def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/')
Expand Down