|
| 1 | +import collections.abc |
1 | 2 | import re
|
2 | 3 | from collections import deque
|
3 | 4 | from collections.abc import Sequence
|
4 |
| -from typing import Any, Dict |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import Any, Dict, Iterator, List, Union |
5 | 7 |
|
6 | 8 | import pytest
|
7 |
| -from dirty_equals import HasRepr, IsInstance, IsList, IsStr |
| 9 | +from dirty_equals import Contains, HasRepr, IsInstance, IsList, IsStr |
8 | 10 |
|
9 |
| -from pydantic_core import SchemaValidator, ValidationError |
| 11 | +from pydantic_core import SchemaValidator, ValidationError, core_schema |
10 | 12 |
|
11 | 13 | from ..conftest import Err, PyAndJson, infinite_generator
|
12 | 14 |
|
@@ -411,3 +413,144 @@ def __next__(self):
|
411 | 413 | 'ctx': {'error': 'RuntimeError: broken'},
|
412 | 414 | }
|
413 | 415 | ]
|
| 416 | + |
| 417 | + |
| 418 | +@pytest.mark.parametrize('error_in_func', [True, False]) |
| 419 | +def test_max_length_fail_fast(error_in_func: bool) -> None: |
| 420 | + calls: list[int] = [] |
| 421 | + |
| 422 | + def f(v: int) -> int: |
| 423 | + calls.append(v) |
| 424 | + if error_in_func: |
| 425 | + assert v < 10 |
| 426 | + return v |
| 427 | + |
| 428 | + s = core_schema.list_schema( |
| 429 | + core_schema.no_info_after_validator_function(f, core_schema.int_schema()), max_length=10 |
| 430 | + ) |
| 431 | + |
| 432 | + v = SchemaValidator(s) |
| 433 | + |
| 434 | + data = list(range(15)) |
| 435 | + |
| 436 | + with pytest.raises(ValidationError) as exc_info: |
| 437 | + v.validate_python(data) |
| 438 | + |
| 439 | + assert len(calls) <= 10, len(calls) |
| 440 | + |
| 441 | + assert exc_info.value.errors(include_url=False) == Contains( |
| 442 | + { |
| 443 | + 'type': 'too_long', |
| 444 | + 'loc': (), |
| 445 | + 'msg': 'List should have at most 10 items after validation, not 11', |
| 446 | + 'input': data, |
| 447 | + 'ctx': {'field_type': 'List', 'max_length': 10, 'actual_length': 11}, |
| 448 | + } |
| 449 | + ) |
| 450 | + |
| 451 | + |
| 452 | +class MySequence(collections.abc.Sequence): |
| 453 | + def __init__(self, data: List[Any]): |
| 454 | + self._data = data |
| 455 | + |
| 456 | + def __getitem__(self, index: int) -> Any: |
| 457 | + return self._data[index] |
| 458 | + |
| 459 | + def __len__(self): |
| 460 | + return len(self._data) |
| 461 | + |
| 462 | + def __repr__(self) -> str: |
| 463 | + return f'MySequence({repr(self._data)})' |
| 464 | + |
| 465 | + |
| 466 | +class MyMapping(collections.abc.Mapping): |
| 467 | + def __init__(self, data: Dict[Any, Any]) -> None: |
| 468 | + self._data = data |
| 469 | + |
| 470 | + def __getitem__(self, key: Any) -> Any: |
| 471 | + return self._data[key] |
| 472 | + |
| 473 | + def __iter__(self) -> Iterator[Any]: |
| 474 | + return iter(self._data) |
| 475 | + |
| 476 | + def __len__(self) -> int: |
| 477 | + return len(self._data) |
| 478 | + |
| 479 | + def __repr__(self) -> str: |
| 480 | + return f'MyMapping({repr(self._data)})' |
| 481 | + |
| 482 | + |
| 483 | +@dataclass |
| 484 | +class ListInputTestCase: |
| 485 | + input: Any |
| 486 | + output: Union[Any, Err] |
| 487 | + strict: Union[bool, None] = None |
| 488 | + |
| 489 | + |
| 490 | +LAX_MODE_INPUTS: List[Any] = [ |
| 491 | + (1, 2, 3), |
| 492 | + frozenset((1, 2, 3)), |
| 493 | + set((1, 2, 3)), |
| 494 | + deque([1, 2, 3]), |
| 495 | + {1: 'a', 2: 'b', 3: 'c'}.keys(), |
| 496 | + {'a': 1, 'b': 2, 'c': 3}.values(), |
| 497 | + MySequence([1, 2, 3]), |
| 498 | + MyMapping({1: 'a', 2: 'b', 3: 'c'}).keys(), |
| 499 | + MyMapping({'a': 1, 'b': 2, 'c': 3}).values(), |
| 500 | + (x for x in [1, 2, 3]), |
| 501 | +] |
| 502 | + |
| 503 | + |
| 504 | +@pytest.mark.parametrize( |
| 505 | + 'testcase', |
| 506 | + [ |
| 507 | + *[ListInputTestCase([1, 2, 3], [1, 2, 3], strict) for strict in (True, False, None)], |
| 508 | + *[ |
| 509 | + ListInputTestCase(inp, Err('Input should be a valid list [type=list_type,'), True) |
| 510 | + for inp in [*LAX_MODE_INPUTS, '123', b'123'] |
| 511 | + ], |
| 512 | + *[ListInputTestCase(inp, [1, 2, 3], False) for inp in LAX_MODE_INPUTS], |
| 513 | + *[ |
| 514 | + ListInputTestCase(inp, Err('Input should be a valid list [type=list_type,'), False) |
| 515 | + for inp in ['123', b'123', MyMapping({1: 'a', 2: 'b', 3: 'c'}), {1: 'a', 2: 'b', 3: 'c'}] |
| 516 | + ], |
| 517 | + ], |
| 518 | + ids=repr, |
| 519 | +) |
| 520 | +def test_list_allowed_inputs_python(testcase: ListInputTestCase): |
| 521 | + v = SchemaValidator(core_schema.list_schema(core_schema.int_schema(), strict=testcase.strict)) |
| 522 | + if isinstance(testcase.output, Err): |
| 523 | + with pytest.raises(ValidationError, match=re.escape(testcase.output.message)): |
| 524 | + v.validate_python(testcase.input) |
| 525 | + else: |
| 526 | + output = v.validate_python(testcase.input) |
| 527 | + assert output == testcase.output |
| 528 | + assert output is not testcase.input |
| 529 | + |
| 530 | + |
| 531 | +@pytest.mark.parametrize( |
| 532 | + 'testcase', |
| 533 | + [ |
| 534 | + ListInputTestCase({1: 1, 2: 2, 3: 3}.items(), Err('Input should be a valid list [type=list_type,'), True), |
| 535 | + ListInputTestCase( |
| 536 | + MyMapping({1: 1, 2: 2, 3: 3}).items(), Err('Input should be a valid list [type=list_type,'), True |
| 537 | + ), |
| 538 | + ListInputTestCase({1: 1, 2: 2, 3: 3}.items(), [(1, 1), (2, 2), (3, 3)], False), |
| 539 | + ListInputTestCase(MyMapping({1: 1, 2: 2, 3: 3}).items(), [(1, 1), (2, 2), (3, 3)], False), |
| 540 | + ], |
| 541 | + ids=repr, |
| 542 | +) |
| 543 | +def test_list_dict_items_input(testcase: ListInputTestCase) -> None: |
| 544 | + v = SchemaValidator( |
| 545 | + core_schema.list_schema( |
| 546 | + core_schema.tuple_positional_schema([core_schema.int_schema(), core_schema.int_schema()]), |
| 547 | + strict=testcase.strict, |
| 548 | + ) |
| 549 | + ) |
| 550 | + if isinstance(testcase.output, Err): |
| 551 | + with pytest.raises(ValidationError, match=re.escape(testcase.output.message)): |
| 552 | + v.validate_python(testcase.input) |
| 553 | + else: |
| 554 | + output = v.validate_python(testcase.input) |
| 555 | + assert output == testcase.output |
| 556 | + assert output is not testcase.input |
0 commit comments