30
30
from bson .py3compat import abc , integer_types , iteritems , text_type
31
31
from bson .regex import Regex , RE_TYPE
32
32
33
+ from gridfs import GridFSBucket
34
+
33
35
from pymongo import ASCENDING , MongoClient
34
- from pymongo .client_session import ClientSession , TransactionOptions
36
+ from pymongo .client_session import ClientSession , TransactionOptions , _TxnState
35
37
from pymongo .change_stream import ChangeStream
36
38
from pymongo .collection import Collection
37
39
from pymongo .cursor import Cursor
38
40
from pymongo .database import Database
39
- from pymongo .errors import BulkWriteError , PyMongoError
41
+ from pymongo .errors import BulkWriteError , InvalidOperation , PyMongoError
40
42
from pymongo .monitoring import (
41
43
CommandFailedEvent , CommandListener , CommandStartedEvent ,
42
44
CommandSucceededEvent )
@@ -162,6 +164,7 @@ class EntityMapUtil(object):
162
164
def __init__ (self , test_class ):
163
165
self ._entities = {}
164
166
self ._listeners = {}
167
+ self ._session_lsids = {}
165
168
self ._test_class = test_class
166
169
167
170
def __getitem__ (self , item ):
@@ -236,15 +239,17 @@ def _create_entity(self, entity_spec):
236
239
txn_opts = parse_spec_options (
237
240
opts ['default_transaction_options' ])
238
241
txn_opts = TransactionOptions (** txn_opts )
242
+ opts = copy .deepcopy (opts )
239
243
opts ['default_transaction_options' ] = txn_opts
240
244
session = client .start_session (** dict (opts ))
241
245
self [spec ['id' ]] = session
242
- self ._test_class .addCleanup (session .end_session ())
246
+ self ._session_lsids [spec ['id' ]] = copy .deepcopy (session .session_id )
247
+ self ._test_class .addCleanup (session .end_session )
243
248
return
244
- # elif ...
245
- # TODO
246
- # Implement the following entity types:
247
- # - bucket
249
+ elif entity_type == 'bucket' :
250
+ # TODO: implement the 'bucket' entity type
251
+ self . _test_class . skipTest (
252
+ 'GridFS entity types are not currently supported.' )
248
253
self ._test_class .fail (
249
254
'Unable to create entity of unknown type %s' % (entity_type ,))
250
255
@@ -266,6 +271,19 @@ def get_listener_for_client(self, client_name):
266
271
267
272
return listener
268
273
274
+ def get_lsid_for_session (self , session_name ):
275
+ session = self [session_name ]
276
+ if not isinstance (session , ClientSession ):
277
+ self ._test_class .fail (
278
+ 'Expected entity %s to be of type ClientSession, got %s' % (
279
+ session_name , type (session )))
280
+
281
+ try :
282
+ return session .session_id
283
+ except InvalidOperation :
284
+ # session has been closed.
285
+ return self ._session_lsids [session_name ]
286
+
269
287
270
288
BSON_TYPE_ALIAS_MAP = {
271
289
# https://docs.mongodb.com/manual/reference/operator/query/type/
@@ -333,8 +351,9 @@ def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
333
351
return
334
352
self .match_result (spec , actual [key_to_compare ], in_recursive_call = True )
335
353
336
- def _operation_sessionLsid (self , spec , actual ):
337
- raise NotImplementedError
354
+ def _operation_sessionLsid (self , spec , actual , key_to_compare ):
355
+ expected_lsid = self ._test_class .entity_map .get_lsid_for_session (spec )
356
+ self ._test_class .assertEqual (expected_lsid , actual [key_to_compare ])
338
357
339
358
def _evaluate_special_operation (self , opname , spec , actual ,
340
359
key_to_compare ):
@@ -552,7 +571,7 @@ def process_error(self, exception, spec):
552
571
self .assertNotIsInstance (exception , PyMongoError )
553
572
554
573
if error_contains :
555
- raise RuntimeError
574
+ raise NotImplementedError
556
575
557
576
if error_code :
558
577
raise NotImplementedError
@@ -577,7 +596,8 @@ def process_error(self, exception, spec):
577
596
exception )
578
597
self .match_evaluator .match_result (expect_result , result )
579
598
else :
580
- raise NotImplementedError
599
+ self .fail ("expectResult can only be specified with %s "
600
+ "exceptions" % (BulkWriteError ,))
581
601
582
602
def __raise_if_unsupported (self , opname , target , * target_types ):
583
603
if not isinstance (target , target_types ):
@@ -645,7 +665,7 @@ def _collectionOperation_insertOne(self, target, *args, **kwargs):
645
665
646
666
def _sessionOperation_withTransaction (self , target , * args , ** kwargs ):
647
667
self .__raise_if_unsupported ('withTransaction' , target , ClientSession )
648
- raise NotImplementedError
668
+ return target . with_transaction ( * args , ** kwargs )
649
669
650
670
def _changeStreamOperation_iterateUntilDocumentOrError (self , target ,
651
671
* args , ** kwargs ):
@@ -660,8 +680,8 @@ def run_entity_operation(self, spec):
660
680
expect_error = spec .get ('expectError' )
661
681
if opargs :
662
682
arguments = parse_spec_options (copy .deepcopy (opargs ))
663
- prepare_spec_arguments (spec , arguments , opname , self . entity_map ,
664
- None )
683
+ prepare_spec_arguments (spec , arguments , camel_to_snake ( opname ) ,
684
+ self . entity_map , self . run_operations )
665
685
else :
666
686
arguments = tuple ()
667
687
@@ -675,8 +695,8 @@ def run_entity_operation(self, spec):
675
695
method_name = '_changeStreamOperation_%s' % (opname ,)
676
696
elif isinstance (target , ClientSession ):
677
697
method_name = '_sessionOperation_%s' % (opname ,)
678
- # elif isinstance(target, GridFSBucket):
679
- # method_name = ...
698
+ elif isinstance (target , GridFSBucket ):
699
+ raise NotImplementedError
680
700
else :
681
701
method_name = 'doesNotExist'
682
702
@@ -719,25 +739,43 @@ def _testOperation_targetedFailPoint(self, spec):
719
739
raise NotImplementedError
720
740
721
741
def _testOperation_assertSessionTransactionState (self , spec ):
722
- raise NotImplementedError
742
+ session = self .entity_map [spec ['session' ]]
743
+ expected_state = getattr (_TxnState , spec ['state' ].upper ())
744
+ self .assertEqual (expected_state , session ._transaction .state )
723
745
724
746
def _testOperation_assertSessionPinned (self , spec ):
725
- raise NotImplementedError
747
+ session = self .entity_map [spec ['session' ]]
748
+ self .assertIsNotNone (session ._pinned_address )
726
749
727
750
def _testOperation_assertSessionUnpinned (self , spec ):
728
- raise NotImplementedError
751
+ session = self .entity_map [spec ['session' ]]
752
+ self .assertIsNone (session ._pinned_address )
753
+
754
+ def __get_last_two_command_lsids (self , listener ):
755
+ cmd_started_events = []
756
+ for event in reversed (listener .results ):
757
+ if isinstance (event , CommandStartedEvent ):
758
+ cmd_started_events .append (event )
759
+ if len (cmd_started_events ) < 2 :
760
+ self .fail ('Needed 2 CommandStartedEvents to compare lsids, '
761
+ 'got %s' % (len (cmd_started_events )))
762
+ return tuple ([e .command ['lsid' ] for e in cmd_started_events ][:2 ])
729
763
730
764
def _testOperation_assertDifferentLsidOnLastTwoCommands (self , spec ):
731
- raise NotImplementedError
765
+ listener = self .entity_map .get_listener_for_client (spec ['client' ])
766
+ self .assertNotEqual (* self .__get_last_two_command_lsids (listener ))
732
767
733
768
def _testOperation_assertSameLsidOnLastTwoCommands (self , spec ):
734
- raise NotImplementedError
769
+ listener = self .entity_map .get_listener_for_client (spec ['client' ])
770
+ self .assertEqual (* self .__get_last_two_command_lsids (listener ))
735
771
736
772
def _testOperation_assertSessionDirty (self , spec ):
737
- raise NotImplementedError
773
+ session = self .entity_map [spec ['session' ]]
774
+ self .assertTrue (session ._server_session .dirty )
738
775
739
776
def _testOperation_assertSessionNotDirty (self , spec ):
740
- raise NotImplementedError
777
+ session = self .entity_map [spec ['session' ]]
778
+ return self .assertFalse (session ._server_session .dirty )
741
779
742
780
def _testOperation_assertCollectionExists (self , spec ):
743
781
database_name = spec ['databaseName' ]
@@ -754,10 +792,14 @@ def _testOperation_assertCollectionNotExists(self, spec):
754
792
self .assertNotIn (collection_name , collection_name_list )
755
793
756
794
def _testOperation_assertIndexExists (self , spec ):
757
- raise NotImplementedError
795
+ collection = self .client [spec ['databaseName' ]][spec ['collectionName' ]]
796
+ index_names = [idx ['name' ] for idx in collection .list_indexes ()]
797
+ self .assertIn (spec ['indexName' ], index_names )
758
798
759
799
def _testOperation_assertIndexNotExists (self , spec ):
760
- raise NotImplementedError
800
+ collection = self .client [spec ['databaseName' ]][spec ['collectionName' ]]
801
+ for index in collection .list_indexes ():
802
+ self .assertNotEqual (spec ['indexName' ], index ['name' ])
761
803
762
804
def run_special_operation (self , spec ):
763
805
opname = spec ['name' ]
0 commit comments