@@ -790,51 +790,92 @@ def __init__(self, type_: Literal['cat']) -> None:
790
790
self .type_ = 'cat'
791
791
792
792
793
- def test_union_of_unions_of_models () -> None :
793
+ @pytest .fixture
794
+ def model_a_b_union_schema () -> core_schema .UnionSchema :
795
+ return core_schema .union_schema (
796
+ [
797
+ core_schema .model_schema (
798
+ cls = ModelA ,
799
+ schema = core_schema .model_fields_schema (
800
+ fields = {
801
+ 'a' : core_schema .model_field (core_schema .str_schema ()),
802
+ 'b' : core_schema .model_field (core_schema .str_schema ()),
803
+ },
804
+ ),
805
+ ),
806
+ core_schema .model_schema (
807
+ cls = ModelB ,
808
+ schema = core_schema .model_fields_schema (
809
+ fields = {
810
+ 'c' : core_schema .model_field (core_schema .str_schema ()),
811
+ 'd' : core_schema .model_field (core_schema .str_schema ()),
812
+ },
813
+ ),
814
+ ),
815
+ ]
816
+ )
817
+
818
+
819
+ def test_union_of_unions_of_models (model_a_b_union_schema : core_schema .UnionSchema ) -> None :
794
820
s = SchemaSerializer (
795
821
core_schema .union_schema (
796
822
[
823
+ model_a_b_union_schema ,
797
824
core_schema .union_schema (
798
825
[
799
826
core_schema .model_schema (
800
- cls = ModelA ,
827
+ cls = ModelCat ,
801
828
schema = core_schema .model_fields_schema (
802
829
fields = {
803
- 'a' : core_schema .model_field (core_schema .str_schema ()),
804
- 'b' : core_schema .model_field (core_schema .str_schema ()),
830
+ 'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
805
831
},
806
832
),
807
833
),
808
834
core_schema .model_schema (
809
- cls = ModelB ,
835
+ cls = ModelDog ,
810
836
schema = core_schema .model_fields_schema (
811
837
fields = {
812
- 'c' : core_schema .model_field (core_schema .str_schema ()),
813
- 'd' : core_schema .model_field (core_schema .str_schema ()),
838
+ 'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
814
839
},
815
840
),
816
841
),
817
842
]
818
843
),
819
- core_schema .union_schema (
820
- [
821
- core_schema .model_schema (
844
+ ]
845
+ )
846
+ )
847
+
848
+ assert s .to_python (ModelA (a = 'a' , b = 'b' ), warnings = 'error' ) == {'a' : 'a' , 'b' : 'b' }
849
+ assert s .to_python (ModelB (c = 'c' , d = 'd' ), warnings = 'error' ) == {'c' : 'c' , 'd' : 'd' }
850
+ assert s .to_python (ModelCat (type_ = 'cat' ), warnings = 'error' ) == {'type_' : 'cat' }
851
+ assert s .to_python (ModelDog (type_ = 'dog' ), warnings = 'error' ) == {'type_' : 'dog' }
852
+
853
+
854
+ def test_union_of_unions_of_models_with_tagged_union (model_a_b_union_schema : core_schema .UnionSchema ) -> None :
855
+ s = SchemaSerializer (
856
+ core_schema .union_schema (
857
+ [
858
+ model_a_b_union_schema ,
859
+ core_schema .tagged_union_schema (
860
+ discriminator = 'type_' ,
861
+ choices = {
862
+ 'cat' : core_schema .model_schema (
822
863
cls = ModelCat ,
823
864
schema = core_schema .model_fields_schema (
824
865
fields = {
825
866
'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
826
867
},
827
868
),
828
869
),
829
- core_schema .model_schema (
870
+ 'dog' : core_schema .model_schema (
830
871
cls = ModelDog ,
831
872
schema = core_schema .model_fields_schema (
832
873
fields = {
833
874
'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
834
875
},
835
876
),
836
877
),
837
- ]
878
+ },
838
879
),
839
880
]
840
881
)
0 commit comments