|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import dataclasses
|
2 | 4 | import json
|
3 | 5 | import uuid
|
@@ -778,3 +780,223 @@ class ModelB:
|
778 | 780 | model_b = ModelB(field=1)
|
779 | 781 | assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
|
780 | 782 | assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}
|
| 783 | + |
| 784 | + |
| 785 | +class ModelDog: |
| 786 | + def __init__(self, type_: Literal['dog']) -> None: |
| 787 | + self.type_ = 'dog' |
| 788 | + |
| 789 | + |
| 790 | +class ModelCat: |
| 791 | + def __init__(self, type_: Literal['cat']) -> None: |
| 792 | + self.type_ = 'cat' |
| 793 | + |
| 794 | + |
| 795 | +class ModelAlien: |
| 796 | + def __init__(self, type_: Literal['alien']) -> None: |
| 797 | + self.type_ = 'alien' |
| 798 | + |
| 799 | + |
| 800 | +@pytest.fixture |
| 801 | +def model_a_b_union_schema() -> core_schema.UnionSchema: |
| 802 | + return core_schema.union_schema( |
| 803 | + [ |
| 804 | + core_schema.model_schema( |
| 805 | + cls=ModelA, |
| 806 | + schema=core_schema.model_fields_schema( |
| 807 | + fields={ |
| 808 | + 'a': core_schema.model_field(core_schema.str_schema()), |
| 809 | + 'b': core_schema.model_field(core_schema.str_schema()), |
| 810 | + }, |
| 811 | + ), |
| 812 | + ), |
| 813 | + core_schema.model_schema( |
| 814 | + cls=ModelB, |
| 815 | + schema=core_schema.model_fields_schema( |
| 816 | + fields={ |
| 817 | + 'c': core_schema.model_field(core_schema.str_schema()), |
| 818 | + 'd': core_schema.model_field(core_schema.str_schema()), |
| 819 | + }, |
| 820 | + ), |
| 821 | + ), |
| 822 | + ] |
| 823 | + ) |
| 824 | + |
| 825 | + |
| 826 | +@pytest.fixture |
| 827 | +def union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema: |
| 828 | + return core_schema.union_schema( |
| 829 | + [ |
| 830 | + model_a_b_union_schema, |
| 831 | + core_schema.union_schema( |
| 832 | + [ |
| 833 | + core_schema.model_schema( |
| 834 | + cls=ModelCat, |
| 835 | + schema=core_schema.model_fields_schema( |
| 836 | + fields={ |
| 837 | + 'type_': core_schema.model_field(core_schema.literal_schema(['cat'])), |
| 838 | + }, |
| 839 | + ), |
| 840 | + ), |
| 841 | + core_schema.model_schema( |
| 842 | + cls=ModelDog, |
| 843 | + schema=core_schema.model_fields_schema( |
| 844 | + fields={ |
| 845 | + 'type_': core_schema.model_field(core_schema.literal_schema(['dog'])), |
| 846 | + }, |
| 847 | + ), |
| 848 | + ), |
| 849 | + ] |
| 850 | + ), |
| 851 | + ] |
| 852 | + ) |
| 853 | + |
| 854 | + |
| 855 | +@pytest.mark.parametrize( |
| 856 | + 'input,expected', |
| 857 | + [ |
| 858 | + (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}), |
| 859 | + (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}), |
| 860 | + (ModelCat(type_='cat'), {'type_': 'cat'}), |
| 861 | + (ModelDog(type_='dog'), {'type_': 'dog'}), |
| 862 | + ], |
| 863 | +) |
| 864 | +def test_union_of_unions_of_models(union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any) -> None: |
| 865 | + s = SchemaSerializer(union_of_unions_schema) |
| 866 | + assert s.to_python(input, warnings='error') == expected |
| 867 | + |
| 868 | + |
| 869 | +def test_union_of_unions_of_models_invalid_variant(union_of_unions_schema: core_schema.UnionSchema) -> None: |
| 870 | + s = SchemaSerializer(union_of_unions_schema) |
| 871 | + # All warnings should be available |
| 872 | + messages = [ |
| 873 | + 'Expected `ModelA` but got `ModelAlien`', |
| 874 | + 'Expected `ModelB` but got `ModelAlien`', |
| 875 | + 'Expected `ModelCat` but got `ModelAlien`', |
| 876 | + 'Expected `ModelDog` but got `ModelAlien`', |
| 877 | + ] |
| 878 | + |
| 879 | + with warnings.catch_warnings(record=True) as w: |
| 880 | + warnings.simplefilter('always') |
| 881 | + s.to_python(ModelAlien(type_='alien')) |
| 882 | + for m in messages: |
| 883 | + assert m in str(w[0].message) |
| 884 | + |
| 885 | + |
| 886 | +@pytest.fixture |
| 887 | +def tagged_union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema: |
| 888 | + return core_schema.union_schema( |
| 889 | + [ |
| 890 | + model_a_b_union_schema, |
| 891 | + core_schema.tagged_union_schema( |
| 892 | + discriminator='type_', |
| 893 | + choices={ |
| 894 | + 'cat': core_schema.model_schema( |
| 895 | + cls=ModelCat, |
| 896 | + schema=core_schema.model_fields_schema( |
| 897 | + fields={ |
| 898 | + 'type_': core_schema.model_field(core_schema.literal_schema(['cat'])), |
| 899 | + }, |
| 900 | + ), |
| 901 | + ), |
| 902 | + 'dog': core_schema.model_schema( |
| 903 | + cls=ModelDog, |
| 904 | + schema=core_schema.model_fields_schema( |
| 905 | + fields={ |
| 906 | + 'type_': core_schema.model_field(core_schema.literal_schema(['dog'])), |
| 907 | + }, |
| 908 | + ), |
| 909 | + ), |
| 910 | + }, |
| 911 | + ), |
| 912 | + ] |
| 913 | + ) |
| 914 | + |
| 915 | + |
| 916 | +@pytest.mark.parametrize( |
| 917 | + 'input,expected', |
| 918 | + [ |
| 919 | + (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}), |
| 920 | + (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}), |
| 921 | + (ModelCat(type_='cat'), {'type_': 'cat'}), |
| 922 | + (ModelDog(type_='dog'), {'type_': 'dog'}), |
| 923 | + ], |
| 924 | +) |
| 925 | +def test_union_of_unions_of_models_with_tagged_union( |
| 926 | + tagged_union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any |
| 927 | +) -> None: |
| 928 | + s = SchemaSerializer(tagged_union_of_unions_schema) |
| 929 | + assert s.to_python(input, warnings='error') == expected |
| 930 | + |
| 931 | + |
| 932 | +def test_union_of_unions_of_models_with_tagged_union_invalid_variant( |
| 933 | + tagged_union_of_unions_schema: core_schema.UnionSchema, |
| 934 | +) -> None: |
| 935 | + s = SchemaSerializer(tagged_union_of_unions_schema) |
| 936 | + # All warnings should be available |
| 937 | + messages = [ |
| 938 | + 'Expected `ModelA` but got `ModelAlien`', |
| 939 | + 'Expected `ModelB` but got `ModelAlien`', |
| 940 | + 'Expected `ModelCat` but got `ModelAlien`', |
| 941 | + 'Expected `ModelDog` but got `ModelAlien`', |
| 942 | + ] |
| 943 | + |
| 944 | + with warnings.catch_warnings(record=True) as w: |
| 945 | + warnings.simplefilter('always') |
| 946 | + s.to_python(ModelAlien(type_='alien')) |
| 947 | + for m in messages: |
| 948 | + assert m in str(w[0].message) |
| 949 | + |
| 950 | + |
| 951 | +@pytest.mark.parametrize( |
| 952 | + 'input,expected', |
| 953 | + [ |
| 954 | + ({True: '1'}, b'{"true":"1"}'), |
| 955 | + ({1: '1'}, b'{"1":"1"}'), |
| 956 | + ({2.3: '1'}, b'{"2.3":"1"}'), |
| 957 | + ({'a': 'b'}, b'{"a":"b"}'), |
| 958 | + ], |
| 959 | +) |
| 960 | +def test_union_of_unions_of_models_with_tagged_union_json_key_serialization( |
| 961 | + input: dict[bool | int | float | str, str], expected: bytes |
| 962 | +) -> None: |
| 963 | + s = SchemaSerializer( |
| 964 | + core_schema.dict_schema( |
| 965 | + keys_schema=core_schema.union_schema( |
| 966 | + [ |
| 967 | + core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]), |
| 968 | + core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]), |
| 969 | + ] |
| 970 | + ), |
| 971 | + values_schema=core_schema.str_schema(), |
| 972 | + ) |
| 973 | + ) |
| 974 | + |
| 975 | + assert s.to_json(input, warnings='error') == expected |
| 976 | + |
| 977 | + |
| 978 | +@pytest.mark.parametrize( |
| 979 | + 'input,expected', |
| 980 | + [ |
| 981 | + ({'key': True}, b'{"key":true}'), |
| 982 | + ({'key': 1}, b'{"key":1}'), |
| 983 | + ({'key': 2.3}, b'{"key":2.3}'), |
| 984 | + ({'key': 'a'}, b'{"key":"a"}'), |
| 985 | + ], |
| 986 | +) |
| 987 | +def test_union_of_unions_of_models_with_tagged_union_json_serialization( |
| 988 | + input: dict[str, bool | int | float | str], expected: bytes |
| 989 | +) -> None: |
| 990 | + s = SchemaSerializer( |
| 991 | + core_schema.dict_schema( |
| 992 | + keys_schema=core_schema.str_schema(), |
| 993 | + values_schema=core_schema.union_schema( |
| 994 | + [ |
| 995 | + core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]), |
| 996 | + core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]), |
| 997 | + ] |
| 998 | + ), |
| 999 | + ) |
| 1000 | + ) |
| 1001 | + |
| 1002 | + assert s.to_json(input, warnings='error') == expected |
0 commit comments