Skip to content

Commit 97ea843

Browse files
committed
get transactions tests working
1 parent b676446 commit 97ea843

File tree

2 files changed

+73
-26
lines changed

2 files changed

+73
-26
lines changed

test/unified_format.py

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030
from bson.py3compat import abc, integer_types, iteritems, text_type
3131
from bson.regex import Regex, RE_TYPE
3232

33+
from gridfs import GridFSBucket
34+
3335
from pymongo import ASCENDING, MongoClient
34-
from pymongo.client_session import ClientSession, TransactionOptions
36+
from pymongo.client_session import ClientSession, TransactionOptions, _TxnState
3537
from pymongo.change_stream import ChangeStream
3638
from pymongo.collection import Collection
3739
from pymongo.cursor import Cursor
3840
from pymongo.database import Database
39-
from pymongo.errors import BulkWriteError, PyMongoError
41+
from pymongo.errors import BulkWriteError, InvalidOperation, PyMongoError
4042
from pymongo.monitoring import (
4143
CommandFailedEvent, CommandListener, CommandStartedEvent,
4244
CommandSucceededEvent)
@@ -162,6 +164,7 @@ class EntityMapUtil(object):
162164
def __init__(self, test_class):
163165
self._entities = {}
164166
self._listeners = {}
167+
self._session_lsids = {}
165168
self._test_class = test_class
166169

167170
def __getitem__(self, item):
@@ -236,15 +239,17 @@ def _create_entity(self, entity_spec):
236239
txn_opts = parse_spec_options(
237240
opts['default_transaction_options'])
238241
txn_opts = TransactionOptions(**txn_opts)
242+
opts = copy.deepcopy(opts)
239243
opts['default_transaction_options'] = txn_opts
240244
session = client.start_session(**dict(opts))
241245
self[spec['id']] = session
242-
self._test_class.addCleanup(session.end_session())
246+
self._session_lsids[spec['id']] = copy.deepcopy(session.session_id)
247+
self._test_class.addCleanup(session.end_session)
243248
return
244-
# elif ...
245-
# TODO
246-
# Implement the following entity types:
247-
# - bucket
249+
elif entity_type == 'bucket':
250+
# TODO: implement the 'bucket' entity type
251+
self._test_class.skipTest(
252+
'GridFS entity types are not currently supported.')
248253
self._test_class.fail(
249254
'Unable to create entity of unknown type %s' % (entity_type,))
250255

@@ -266,6 +271,19 @@ def get_listener_for_client(self, client_name):
266271

267272
return listener
268273

274+
def get_lsid_for_session(self, session_name):
275+
session = self[session_name]
276+
if not isinstance(session, ClientSession):
277+
self._test_class.fail(
278+
'Expected entity %s to be of type ClientSession, got %s' % (
279+
session_name, type(session)))
280+
281+
try:
282+
return session.session_id
283+
except InvalidOperation:
284+
# session has been closed.
285+
return self._session_lsids[session_name]
286+
269287

270288
BSON_TYPE_ALIAS_MAP = {
271289
# https://docs.mongodb.com/manual/reference/operator/query/type/
@@ -333,8 +351,9 @@ def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
333351
return
334352
self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
335353

336-
def _operation_sessionLsid(self, spec, actual):
337-
raise NotImplementedError
354+
def _operation_sessionLsid(self, spec, actual, key_to_compare):
355+
expected_lsid = self._test_class.entity_map.get_lsid_for_session(spec)
356+
self._test_class.assertEqual(expected_lsid, actual[key_to_compare])
338357

339358
def _evaluate_special_operation(self, opname, spec, actual,
340359
key_to_compare):
@@ -552,7 +571,7 @@ def process_error(self, exception, spec):
552571
self.assertNotIsInstance(exception, PyMongoError)
553572

554573
if error_contains:
555-
raise RuntimeError
574+
raise NotImplementedError
556575

557576
if error_code:
558577
raise NotImplementedError
@@ -577,7 +596,8 @@ def process_error(self, exception, spec):
577596
exception)
578597
self.match_evaluator.match_result(expect_result, result)
579598
else:
580-
raise NotImplementedError
599+
self.fail("expectResult can only be specified with %s "
600+
"exceptions" % (BulkWriteError,))
581601

582602
def __raise_if_unsupported(self, opname, target, *target_types):
583603
if not isinstance(target, target_types):
@@ -645,7 +665,7 @@ def _collectionOperation_insertOne(self, target, *args, **kwargs):
645665

646666
def _sessionOperation_withTransaction(self, target, *args, **kwargs):
647667
self.__raise_if_unsupported('withTransaction', target, ClientSession)
648-
raise NotImplementedError
668+
return target.with_transaction(*args, **kwargs)
649669

650670
def _changeStreamOperation_iterateUntilDocumentOrError(self, target,
651671
*args, **kwargs):
@@ -660,8 +680,8 @@ def run_entity_operation(self, spec):
660680
expect_error = spec.get('expectError')
661681
if opargs:
662682
arguments = parse_spec_options(copy.deepcopy(opargs))
663-
prepare_spec_arguments(spec, arguments, opname, self.entity_map,
664-
None)
683+
prepare_spec_arguments(spec, arguments, camel_to_snake(opname),
684+
self.entity_map, self.run_operations)
665685
else:
666686
arguments = tuple()
667687

@@ -675,8 +695,8 @@ def run_entity_operation(self, spec):
675695
method_name = '_changeStreamOperation_%s' % (opname,)
676696
elif isinstance(target, ClientSession):
677697
method_name = '_sessionOperation_%s' % (opname,)
678-
#elif isinstance(target, GridFSBucket):
679-
# method_name = ...
698+
elif isinstance(target, GridFSBucket):
699+
raise NotImplementedError
680700
else:
681701
method_name = 'doesNotExist'
682702

@@ -719,25 +739,43 @@ def _testOperation_targetedFailPoint(self, spec):
719739
raise NotImplementedError
720740

721741
def _testOperation_assertSessionTransactionState(self, spec):
722-
raise NotImplementedError
742+
session = self.entity_map[spec['session']]
743+
expected_state = getattr(_TxnState, spec['state'].upper())
744+
self.assertEqual(expected_state, session._transaction.state)
723745

724746
def _testOperation_assertSessionPinned(self, spec):
725-
raise NotImplementedError
747+
session = self.entity_map[spec['session']]
748+
self.assertIsNotNone(session._pinned_address)
726749

727750
def _testOperation_assertSessionUnpinned(self, spec):
728-
raise NotImplementedError
751+
session = self.entity_map[spec['session']]
752+
self.assertIsNone(session._pinned_address)
753+
754+
def __get_last_two_command_lsids(self, listener):
755+
cmd_started_events = []
756+
for event in reversed(listener.results):
757+
if isinstance(event, CommandStartedEvent):
758+
cmd_started_events.append(event)
759+
if len(cmd_started_events) < 2:
760+
self.fail('Needed 2 CommandStartedEvents to compare lsids, '
761+
'got %s' % (len(cmd_started_events)))
762+
return tuple([e.command['lsid'] for e in cmd_started_events][:2])
729763

730764
def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec):
731-
raise NotImplementedError
765+
listener = self.entity_map.get_listener_for_client(spec['client'])
766+
self.assertNotEqual(*self.__get_last_two_command_lsids(listener))
732767

733768
def _testOperation_assertSameLsidOnLastTwoCommands(self, spec):
734-
raise NotImplementedError
769+
listener = self.entity_map.get_listener_for_client(spec['client'])
770+
self.assertEqual(*self.__get_last_two_command_lsids(listener))
735771

736772
def _testOperation_assertSessionDirty(self, spec):
737-
raise NotImplementedError
773+
session = self.entity_map[spec['session']]
774+
self.assertTrue(session._server_session.dirty)
738775

739776
def _testOperation_assertSessionNotDirty(self, spec):
740-
raise NotImplementedError
777+
session = self.entity_map[spec['session']]
778+
return self.assertFalse(session._server_session.dirty)
741779

742780
def _testOperation_assertCollectionExists(self, spec):
743781
database_name = spec['databaseName']
@@ -754,10 +792,14 @@ def _testOperation_assertCollectionNotExists(self, spec):
754792
self.assertNotIn(collection_name, collection_name_list)
755793

756794
def _testOperation_assertIndexExists(self, spec):
757-
raise NotImplementedError
795+
collection = self.client[spec['databaseName']][spec['collectionName']]
796+
index_names = [idx['name'] for idx in collection.list_indexes()]
797+
self.assertIn(spec['indexName'], index_names)
758798

759799
def _testOperation_assertIndexNotExists(self, spec):
760-
raise NotImplementedError
800+
collection = self.client[spec['databaseName']][spec['collectionName']]
801+
for index in collection.list_indexes():
802+
self.assertNotEqual(spec['indexName'], index['name'])
761803

762804
def run_special_operation(self, spec):
763805
opname = spec['name']

test/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,12 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map,
10531053
# camelCase maxTimeMS. See PYTHON-1855.
10541054
arguments['maxTimeMS'] = arguments.pop('max_time_ms')
10551055
elif opname == 'with_transaction' and arg_name == 'callback':
1056-
callback_ops = arguments[arg_name]['operations']
1056+
if 'operations' in arguments[arg_name]:
1057+
# CRUD v2 format
1058+
callback_ops = arguments[arg_name]['operations']
1059+
else:
1060+
# Unified test format
1061+
callback_ops = arguments[arg_name]
10571062
arguments['callback'] = lambda _: with_txn_callback(
10581063
copy.deepcopy(callback_ops))
10591064
elif opname == 'drop_collection' and arg_name == 'collection':

0 commit comments

Comments
 (0)