Skip to content

Commit b676446

Browse files
committed
further fleshing out
1 parent 493563b commit b676446

File tree

3 files changed

+70
-29
lines changed

3 files changed

+70
-29
lines changed

pymongo/helpers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ def _check_write_command_response(result):
224224
write_errors = result.get("writeErrors")
225225
if write_errors:
226226
_raise_last_write_error(write_errors)
227-
228227
error = result.get("writeConcernError")
229228
if error:
230229
error_labels = result.get("errorLabels")

test/test_unified_format.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
sys.path[0:0] = [""]
1919

20+
from bson import ObjectId, Timestamp
21+
2022
from test import unittest
21-
from test.unified_format import generate_test_classes
23+
from test.unified_format import generate_test_classes, MatchEvaluatorUtil
2224

2325

2426
_TEST_PATH = os.path.join(
@@ -30,5 +32,28 @@
3032
class_name_prefix='UnifiedTestFormat'))
3133

3234

35+
class TestMatchEvaluatorUtil(unittest.TestCase):
36+
def setUp(self):
37+
self.match_evaluator = MatchEvaluatorUtil(self)
38+
39+
def test_unsetOrMatches(self):
40+
spec = {'$$unsetOrMatches': {'y': {'$$unsetOrMatches': 2}}}
41+
for actual in [{}, {'y': 2}, None]:
42+
self.match_evaluator.match_result(spec, actual)
43+
44+
spec = {'x': {'$$unsetOrMatches': {'y': {'$$unsetOrMatches': 2}}}}
45+
for actual in [{}, {'x': {}}, {'x': {'y': 2}}]:
46+
self.match_evaluator.match_result(spec, actual)
47+
48+
def test_type(self):
49+
self.match_evaluator.match_result(
50+
{'operationType': 'insert',
51+
'ns': {'db': 'change-stream-tests', 'coll': 'test'},
52+
'fullDocument': {'_id': {'$$type': 'objectId'}, 'x': 1}},
53+
{'operationType': 'insert',
54+
'fullDocument': {'_id': ObjectId('5fc93511ac93941052098f0c'), 'x': 1},
55+
'ns': {'db': 'change-stream-tests', 'coll': 'test'}})
56+
57+
3358
if __name__ == "__main__":
3459
unittest.main()

test/unified_format.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
CommandSucceededEvent)
4343
from pymongo.read_concern import ReadConcern
4444
from pymongo.read_preferences import ReadPreference
45-
from pymongo.results import BulkWriteResult
45+
from pymongo.results import BulkWriteResult, InsertManyResult, InsertOneResult
4646
from pymongo.write_concern import WriteConcern
4747

4848
from test import client_context, unittest, IntegrationTest
@@ -189,11 +189,13 @@ def _create_entity(self, entity_spec):
189189
# Add logic to respect the following fields
190190
# - uriOptions
191191
# - useMultipleMongoses
192+
uri_options = spec.get('uriOptions', {})
192193
observe_events = spec.get('observeEvents')
193194
ignore_commands = spec.get('ignoreCommandMonitoringEvents', [])
194195
if observe_events:
195196
listener = EventListenerUtil(observe_events, ignore_commands)
196-
client = rs_or_single_client(event_listeners=[listener])
197+
client = rs_or_single_client(
198+
event_listeners=[listener], **uri_options)
197199
else:
198200
listener = None
199201
client = rs_or_single_client()
@@ -321,11 +323,15 @@ def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
321323
raise NotImplementedError
322324

323325
def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
326+
if key_to_compare is None and not actual:
327+
# top-level document can be None when unset
328+
return
329+
324330
if key_to_compare not in actual:
325331
# we add a dummy value for the compared key to pass map size check
326-
actual[key_to_compare] = None
332+
actual[key_to_compare] = 'dummyValue'
327333
return
328-
self._test_class.assertEqual(spec, actual[key_to_compare])
334+
self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
329335

330336
def _operation_sessionLsid(self, spec, actual):
331337
raise NotImplementedError
@@ -353,21 +359,30 @@ def _evaluate_if_special_operation(self, expectation, actual,
353359
if not isinstance(expectation, abc.Mapping):
354360
return False
355361

356-
if len(expectation) == 1:
357-
field_name, spec = next(iteritems(expectation))
358-
elif key_to_compare is not None:
359-
field_name, spec = key_to_compare, expectation[key_to_compare]
360-
else:
361-
return False
362-
363-
if not (isinstance(spec, abc.Mapping) and len(spec) == 1):
364-
return False
362+
is_special_op, opname, spec = False, False, False
365363

366-
opname, payload = next(iteritems(spec))
367-
if opname.startswith('$$'):
364+
if key_to_compare is not None:
365+
if key_to_compare.startswith('$$'):
366+
is_special_op = True
367+
opname = key_to_compare
368+
spec = expectation[key_to_compare]
369+
key_to_compare = None
370+
else:
371+
nested = expectation[key_to_compare]
372+
if isinstance(nested, abc.Mapping) and len(nested) == 1:
373+
opname, spec = next(iteritems(nested))
374+
if opname.startswith('$$'):
375+
is_special_op = True
376+
elif len(expectation) == 1:
377+
opname, spec = next(iteritems(expectation))
378+
if opname.startswith('$$'):
379+
is_special_op = True
380+
key_to_compare = None
381+
382+
if is_special_op:
368383
self._evaluate_special_operation(
369384
opname=opname,
370-
spec=payload,
385+
spec=spec,
371386
actual=actual,
372387
key_to_compare=key_to_compare)
373388
return True
@@ -546,10 +561,15 @@ def process_error(self, exception, spec):
546561
raise NotImplementedError
547562

548563
if error_labels_contain:
549-
raise NotImplementedError
564+
labels = [err_label for err_label in error_labels_contain
565+
if exception.has_error_label(err_label)]
566+
self.assertEqual(labels, error_labels_contain)
550567

551568
if error_labels_omit:
552-
raise NotImplementedError
569+
for err_label in error_labels_omit:
570+
if exception.has_error_label(err_label):
571+
self.fail("Exception '%s' unexpectedly had label '%s'" % (
572+
exception, err_label))
553573

554574
if expect_result:
555575
if isinstance(exception, BulkWriteError):
@@ -607,21 +627,21 @@ def _collectionOperation_find(self, target, *args, **kwargs):
607627

608628
def _collectionOperation_findOneAndReplace(self, target, *args, **kwargs):
609629
self.__raise_if_unsupported('findOneAndReplace', target, Collection)
610-
find_cursor = target.find_one_and_replace(*args, **kwargs)
611-
return list(find_cursor)
630+
return target.find_one_and_replace(*args, **kwargs)
612631

613632
def _collectionOperation_findOneAndUpdate(self, target, *args, **kwargs):
614633
self.__raise_if_unsupported('findOneAndReplace', target, Collection)
615-
find_cursor = target.find_one_and_update(*args, **kwargs)
616-
return list(find_cursor)
634+
return target.find_one_and_update(*args, **kwargs)
617635

618636
def _collectionOperation_insertMany(self, target, *args, **kwargs):
619637
self.__raise_if_unsupported('insertMany', target, Collection)
620-
return target.insert_many(*args, **kwargs)
638+
result = target.insert_many(*args, **kwargs)
639+
return {idx: _id for idx, _id in enumerate(result.inserted_ids)}
621640

622641
def _collectionOperation_insertOne(self, target, *args, **kwargs):
623642
self.__raise_if_unsupported('insertOne', target, Collection)
624-
return target.insert_one(*args, **kwargs)
643+
result = target.insert_one(*args, **kwargs)
644+
return {'insertedId': result.inserted_id}
625645

626646
def _sessionOperation_withTransaction(self, target, *args, **kwargs):
627647
self.__raise_if_unsupported('withTransaction', target, ClientSession)
@@ -677,9 +697,6 @@ def run_entity_operation(self, spec):
677697
if expect_error:
678698
return self.process_error(exc, expect_error)
679699
raise
680-
else:
681-
if isinstance(result, Cursor):
682-
result = list(result)
683700

684701
if 'expectResult' in spec:
685702
self.match_evaluator.match_result(spec['expectResult'], result)

0 commit comments

Comments
 (0)