Skip to content

Commit f95b55a

Browse files
committed
fixup change streams 1
1 parent b2f3526 commit f95b55a

File tree

2 files changed

+175
-29
lines changed

2 files changed

+175
-29
lines changed

test/unified_format.py

Lines changed: 171 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,22 @@
1818
"""
1919

2020
import copy
21+
import datetime
22+
import functools
2123
import os
2224
import sys
2325
import types
2426

2527
from bson import json_util, SON
28+
from bson.binary import Binary
29+
from bson.objectid import ObjectId
2630
from bson.py3compat import abc, iteritems, text_type
31+
from bson.regex import Regex, RE_TYPE
2732

2833
from pymongo import ASCENDING, MongoClient
34+
from pymongo.client_session import ClientSession, TransactionOptions
35+
from pymongo.change_stream import ChangeStream
36+
from pymongo.collection import Collection
2937
from pymongo.cursor import Cursor
3038
from pymongo.database import Database
3139
from pymongo.monitoring import (
@@ -41,7 +49,7 @@
4149
snake_to_camel, ScenarioDict)
4250

4351
from test.version import Version
44-
from test.utils import parse_spec_options, prepare_spec_arguments
52+
from test.utils import camel_to_snake_args, parse_spec_options, prepare_spec_arguments
4553

4654

4755
JSON_OPTS = json_util.JSONOptions(tz_aware=False)
@@ -188,10 +196,24 @@ def _create_entity(self, entity_spec):
188196
spec['database'], type(database)))
189197
self[spec['id']] = database.get_collection(spec['collectionName'])
190198
return
199+
elif entity_type == 'session':
200+
client = self[spec['client']]
201+
if not isinstance(client, MongoClient):
202+
self._test_class.fail(
203+
'Expected entity %s to be of type MongoClient, got %s' % (
204+
spec['client'], type(client)))
205+
opts = camel_to_snake_args(spec['sessionOptions'])
206+
if 'default_transaction_options' in opts:
207+
txn_opts = parse_spec_options(
208+
opts['default_transaction_options'])
209+
txn_opts = TransactionOptions(**txn_opts)
210+
opts['default_transaction_options'] = txn_opts
211+
session = client.start_session(**dict(opts))
212+
self[spec['id']] = session
213+
self._test_class.addCleanup(session.end_session())
191214
# elif ...
192215
# TODO
193216
# Implement the following entity types:
194-
# - session
195217
# - bucket
196218
self._test_class.fail(
197219
'Unable to create entity of unknown type %s' % (entity_type,))
@@ -215,6 +237,23 @@ def get_listener_for_client(self, client_name):
215237
return listener
216238

217239

240+
BSON_TYPE_ALIAS_MAP = {
241+
# https://docs.mongodb.com/manual/reference/operator/query/type/
242+
# https://pymongo.readthedocs.io/en/stable/api/bson/index.html
243+
'double': float,
244+
'string': text_type,
245+
'object': abc.Mapping,
246+
'array': abc.Sequence,
247+
'binData': (Binary, bytes),
248+
'objectId': ObjectId,
249+
'bool': bool,
250+
'date': datetime.datetime,
251+
'null': type(None),
252+
'regex': (Regex, RE_TYPE),
253+
# TODO: add all supported types
254+
}
255+
256+
218257
class MatchEvaluatorUtil(object):
219258
"""Utility class that implements methods for evaluating matches as per
220259
the unified test format specification."""
@@ -225,7 +264,9 @@ def _operation_exists(self, spec, actual):
225264
raise NotImplementedError
226265

227266
def _operation_type(self, spec, actual):
228-
raise NotImplementedError
267+
if spec not in BSON_TYPE_ALIAS_MAP:
268+
self._test_class.fail('Unrecognized BSON type alias %s' % (spec,))
269+
self._test_class.assertIsInstance(actual, BSON_TYPE_ALIAS_MAP[spec])
229270

230271
def _operation_matchesEntity(self, spec, actual):
231272
raise NotImplementedError
@@ -260,7 +301,7 @@ def _evaluate_if_special_operation(self, expectation, actual):
260301
return True
261302
return False
262303

263-
def _match_document(self, expectation, actual):
304+
def _match_document(self, expectation, actual, is_root):
264305
if self._evaluate_if_special_operation(expectation, actual):
265306
return
266307

@@ -270,20 +311,24 @@ def _match_document(self, expectation, actual):
270311
continue
271312

272313
self._test_class.assertIn(key, actual)
273-
self.match_result(value, actual[key])
314+
self.match_result(value, actual[key], is_root=False)
274315

275-
# TODO: handle if expected is not root
276-
return
316+
if not is_root:
317+
self._test_class.assertEqual(
318+
set(expectation.keys()), set(actual.keys()))
277319

278320
def _match_array(self, expectation, actual):
279321
self._test_class.assertIsInstance(actual, abc.Iterable)
280322

281323
for e, a in zip(expectation, actual):
282324
self.match_result(e, a)
283325

284-
def match_result(self, expectation, actual):
326+
def match_result(self, expectation, actual, is_root=True):
327+
if expectation is None:
328+
return
329+
285330
if isinstance(expectation, abc.Mapping):
286-
return self._match_document(expectation, actual)
331+
return self._match_document(expectation, actual, is_root=is_root)
287332

288333
if isinstance(expectation, abc.MutableSequence):
289334
return self._match_array(expectation, actual)
@@ -430,15 +475,111 @@ def process_error(self, exception, spec):
430475
if expect_result:
431476
raise NotImplementedError
432477

478+
def __raise_if_unsupported(self, opname, target, *target_types):
479+
if not isinstance(target, target_types):
480+
self.fail('Operation %s not supported for entity '
481+
'of type %s' % (opname, type(target)))
482+
483+
def __entityOperation_createChangeStream(self, target, *args, **kwargs):
484+
self.__raise_if_unsupported(
485+
'createChangeStream', target, MongoClient, Database, Collection)
486+
return target.watch(*args, **kwargs)
487+
488+
def _clientOperation_createChangeStream(self, target, *args, **kwargs):
489+
return self.__entityOperation_createChangeStream(
490+
target, *args, **kwargs)
491+
492+
def _databaseOperation_createChangeStream(self, target, *args, **kwargs):
493+
return self.__entityOperation_createChangeStream(
494+
target, *args, **kwargs)
495+
496+
def _collectionOperation_createChangeStream(self, target, *args, **kwargs):
497+
return self.__entityOperation_createChangeStream(
498+
target, *args, **kwargs)
499+
500+
def _databaseOperation_runCommand(self, target, *args, **kwargs):
501+
self.__raise_if_unsupported('runCommand', target, Database)
502+
return target.command(*args, **kwargs)
503+
504+
def _collectionOperation_aggregate(self, target, *args, **kwargs):
505+
self.__raise_if_unsupported('aggregate', target, Collection)
506+
agg_cursor = target.aggregate(*args, **kwargs)
507+
return list(agg_cursor)
508+
509+
def _collectionOperation_bulkWrite(self, target, *args, **kwargs):
510+
self.__raise_if_unsupported('bulkWrite', target, Collection)
511+
raise NotImplementedError
512+
513+
def _collectionOperation_find(self, target, *args, **kwargs):
514+
self.__raise_if_unsupported('find', target, Collection)
515+
find_cursor = target.find(*args, **kwargs)
516+
return list(find_cursor)
517+
518+
def _collectionOperation_findOneAndReplace(self, target, *args, **kwargs):
519+
self.__raise_if_unsupported('findOneAndReplace', target, Collection)
520+
find_cursor = target.find_one_and_replace(*args, **kwargs)
521+
return list(find_cursor)
522+
523+
def _collectionOperation_findOneAndUpdate(self, target, *args, **kwargs):
524+
self.__raise_if_unsupported('findOneAndReplace', target, Collection)
525+
find_cursor = target.find_one_and_update(*args, **kwargs)
526+
return list(find_cursor)
527+
528+
def _collectionOperation_insertMany(self, target, *args, **kwargs):
529+
self.__raise_if_unsupported('insertMany', target, Collection)
530+
return target.insert_many(*args, **kwargs)
531+
532+
def _collectionOperation_insertOne(self, target, *args, **kwargs):
533+
self.__raise_if_unsupported('insertOne', target, Collection)
534+
return target.insert_one(*args, **kwargs)
535+
536+
def _sessionOperation_withTransaction(self, target, *args, **kwargs):
537+
self.__raise_if_unsupported('withTransaction', target, ClientSession)
538+
raise NotImplementedError
539+
540+
def _changeStreamOperation_iterateUntilDocumentOrError(self, target,
541+
*args, **kwargs):
542+
self.__raise_if_unsupported(
543+
'iterateUntilDocumentOrError', target, ChangeStream)
544+
return next(target)
545+
433546
def run_entity_operation(self, spec):
434547
target = self.entity_map[spec['object']]
435-
opname = camel_to_snake(spec['name'])
548+
opname = spec['name']
436549
opargs = spec.get('arguments')
437550
expect_error = spec.get('expectError')
438-
arguments = parse_spec_options(copy.deepcopy(opargs))
439-
cmd = getattr(target, opname)
440-
prepare_spec_arguments(spec, arguments, opname, self.entity_map,
441-
None)
551+
if opargs:
552+
arguments = parse_spec_options(copy.deepcopy(opargs))
553+
prepare_spec_arguments(spec, arguments, opname, self.entity_map,
554+
None)
555+
else:
556+
arguments = tuple()
557+
558+
if isinstance(target, MongoClient):
559+
method_name = '_clientOperation_%s' % (opname,)
560+
elif isinstance(target, Database):
561+
method_name = '_databaseOperation_%s' % (opname,)
562+
elif isinstance(target, Collection):
563+
method_name = '_collectionOperation_%s' % (opname,)
564+
elif isinstance(target, ChangeStream):
565+
method_name = '_changeStreamOperation_%s' % (opname,)
566+
elif isinstance(target, ClientSession):
567+
method_name = '_sessionOperation_%s' % (opname,)
568+
#elif isinstance(target, GridFSBucket):
569+
# method_name = ...
570+
else:
571+
method_name = 'doesNotExist'
572+
573+
try:
574+
method = getattr(self, method_name)
575+
except AttributeError:
576+
try:
577+
cmd = getattr(target, camel_to_snake(opname))
578+
except AttributeError:
579+
self.fail('Unsupported operation %s on entity %s' % (
580+
opname, target))
581+
else:
582+
cmd = functools.partial(method, target)
442583

443584
try:
444585
result = cmd(**dict(arguments))
@@ -457,7 +598,7 @@ def run_entity_operation(self, spec):
457598
if save_as_entity:
458599
self.entity_map[save_as_entity] = result
459600

460-
def _operation_failPoint(self, spec):
601+
def _testOperation_failPoint(self, spec):
461602
client = self.entity_map[spec['client']]
462603
command_args = spec['failPoint']
463604
cmd_on = SON([('configureFailPoint', 'failCommand')])
@@ -467,53 +608,53 @@ def _operation_failPoint(self, spec):
467608
client.admin.command,
468609
'configureFailPoint', cmd_on['configureFailPoint'], mode='off')
469610

470-
def _operation_targetedFailPoint(self, spec):
611+
def _testOperation_targetedFailPoint(self, spec):
471612
raise NotImplementedError
472613

473-
def _operation_assertSessionTransactionState(self, spec):
614+
def _testOperation_assertSessionTransactionState(self, spec):
474615
raise NotImplementedError
475616

476-
def _operation_assertSessionPinned(self, spec):
617+
def _testOperation_assertSessionPinned(self, spec):
477618
raise NotImplementedError
478619

479-
def _operation_assertSessionUnpinned(self, spec):
620+
def _testOperation_assertSessionUnpinned(self, spec):
480621
raise NotImplementedError
481622

482-
def _operation_assertDifferentLsidOnLastTwoCommands(self, spec):
623+
def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec):
483624
raise NotImplementedError
484625

485-
def _operation_assertSameLsidOnLastTwoCommands(self, spec):
626+
def _testOperation_assertSameLsidOnLastTwoCommands(self, spec):
486627
raise NotImplementedError
487628

488-
def _operation_assertSessionDirty(self, spec):
629+
def _testOperation_assertSessionDirty(self, spec):
489630
raise NotImplementedError
490631

491-
def _operation_assertSessionNotDirty(self, spec):
632+
def _testOperation_assertSessionNotDirty(self, spec):
492633
raise NotImplementedError
493634

494-
def _operation_assertCollectionExists(self, spec):
635+
def _testOperation_assertCollectionExists(self, spec):
495636
database_name = spec['databaseName']
496637
collection_name = spec['collectionName']
497638
collection_name_list = list(
498639
self.client.get_database(database_name).list_collection_names())
499640
self.assertIn(collection_name, collection_name_list)
500641

501-
def _operation_assertCollectionNotExists(self, spec):
642+
def _testOperation_assertCollectionNotExists(self, spec):
502643
database_name = spec['databaseName']
503644
collection_name = spec['collectionName']
504645
collection_name_list = list(
505646
self.client.get_database(database_name).list_collection_names())
506647
self.assertNotIn(collection_name, collection_name_list)
507648

508-
def _operation_assertIndexExists(self, spec):
649+
def _testOperation_assertIndexExists(self, spec):
509650
raise NotImplementedError
510651

511-
def _operation_assertIndexNotExists(self, spec):
652+
def _testOperation_assertIndexNotExists(self, spec):
512653
raise NotImplementedError
513654

514655
def run_special_operation(self, spec):
515656
opname = spec['name']
516-
method_name = '_operation_%s' % (opname,)
657+
method_name = '_testOperation_%s' % (opname,)
517658
try:
518659
method = getattr(self, method_name)
519660
except AttributeError:
@@ -567,6 +708,7 @@ def verify_outcome(self, spec):
567708
actual_documents)
568709

569710
def run_scenario(self, spec):
711+
# import ipdb; ipdb.set_trace()
570712
# process test-level runOnRequirements
571713
run_on_spec = spec.get('runOnRequirements', [])
572714
if not self.should_run_on(run_on_spec):
@@ -600,7 +742,7 @@ def test_case(self):
600742
for test_spec in cls.TEST_SPEC['tests']:
601743
description = test_spec['description']
602744
test_name = 'test_%s' % (
603-
description.replace(' ', '_').replace('.', '_'),)
745+
description.strip('. ').replace(' ', '_').replace('.', '_'),)
604746
test_method = create_test(copy.deepcopy(test_spec))
605747
test_method.__name__ = test_name
606748
setattr(cls, test_name, test_method)

test/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,10 @@ def assertion_context(msg):
958958
py3compat.reraise(type(exc), msg, sys.exc_info()[2])
959959

960960

961+
class SpecParserUtil(object):
962+
pass
963+
964+
961965
def parse_spec_options(opts):
962966
if 'readPreference' in opts:
963967
opts['read_preference'] = parse_read_preference(

0 commit comments

Comments
 (0)