Skip to content

Commit 5efeaf9

Browse files
authored
Flatten CoreSchema types to get a single discriminant key (#450)
1 parent 28ff34e commit 5efeaf9

34 files changed

+348
-499
lines changed

generate_self_schema.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,26 +197,9 @@ def main() -> None:
197197
assert m, f'Unknown schema type: {type_}'
198198
key = m.group(1)
199199
value = get_schema(s)
200-
if key == 'function':
201-
mode = value['fields']['mode']['schema']['expected']
202-
if mode == ['plain']:
203-
key = 'function-plain'
204-
elif mode == ['wrap']:
205-
key = 'function-wrap'
206-
elif key == 'tuple':
207-
if value['fields']['mode']['schema']['expected'] == ['positional']:
208-
key = 'tuple-positional'
209-
else:
210-
key = 'tuple-variable'
211-
212200
choices[key] = value
213201

214-
schema = {
215-
'type': 'tagged-union',
216-
'ref': 'root-schema',
217-
'discriminator': 'self-schema-discriminator',
218-
'choices': choices,
219-
}
202+
schema = {'type': 'tagged-union', 'ref': 'root-schema', 'discriminator': 'type', 'choices': choices}
220203
python_code = (
221204
f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n'
222205
)

pydantic_core/core_schema.py

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,8 +1139,7 @@ def list_schema(
11391139

11401140

11411141
class TuplePositionalSchema(TypedDict, total=False):
1142-
type: Required[Literal['tuple']]
1143-
mode: Required[Literal['positional']]
1142+
type: Required[Literal['tuple-positional']]
11441143
items_schema: Required[List[CoreSchema]]
11451144
extra_schema: CoreSchema
11461145
strict: bool
@@ -1179,8 +1178,7 @@ def tuple_positional_schema(
11791178
serialization: Custom serialization schema
11801179
"""
11811180
return dict_not_none(
1182-
type='tuple',
1183-
mode='positional',
1181+
type='tuple-positional',
11841182
items_schema=list(items_schema),
11851183
extra_schema=extra_schema,
11861184
strict=strict,
@@ -1191,8 +1189,7 @@ def tuple_positional_schema(
11911189

11921190

11931191
class TupleVariableSchema(TypedDict, total=False):
1194-
type: Required[Literal['tuple']]
1195-
mode: Literal['variable']
1192+
type: Required[Literal['tuple-variable']]
11961193
items_schema: CoreSchema
11971194
min_length: int
11981195
max_length: int
@@ -1232,8 +1229,7 @@ def tuple_variable_schema(
12321229
serialization: Custom serialization schema
12331230
"""
12341231
return dict_not_none(
1235-
type='tuple',
1236-
mode='variable',
1232+
type='tuple-variable',
12371233
items_schema=items_schema,
12381234
min_length=min_length,
12391235
max_length=max_length,
@@ -1509,24 +1505,26 @@ class GeneralValidatorFunctionSchema(TypedDict):
15091505
function: GeneralValidatorFunction
15101506

15111507

1512-
class FunctionSchema(TypedDict, total=False):
1513-
type: Required[Literal['function']]
1508+
class _FunctionSchema(TypedDict, total=False):
15141509
function: Required[Union[FieldValidatorFunctionSchema, GeneralValidatorFunctionSchema]]
1515-
mode: Required[Literal['before', 'after']]
15161510
schema: Required[CoreSchema]
15171511
ref: str
15181512
metadata: Any
15191513
serialization: SerSchema
15201514

15211515

1516+
class FunctionBeforeSchema(_FunctionSchema, total=False):
1517+
type: Required[Literal['function-before']]
1518+
1519+
15221520
def field_before_validation_function(
15231521
function: FieldValidatorFunction,
15241522
schema: CoreSchema,
15251523
*,
15261524
ref: str | None = None,
15271525
metadata: Any = None,
15281526
serialization: SerSchema | None = None,
1529-
) -> FunctionSchema:
1527+
) -> FunctionBeforeSchema:
15301528
"""
15311529
Returns a schema that calls a validator function before validating
15321530
the provided **model field** schema, e.g.:
@@ -1556,8 +1554,7 @@ def fn(v: bytes, info: core_schema.ModelFieldValidationInfo) -> str:
15561554
serialization: Custom serialization schema
15571555
"""
15581556
return dict_not_none(
1559-
type='function',
1560-
mode='before',
1557+
type='function-before',
15611558
function={'type': 'field', 'function': function},
15621559
schema=schema,
15631560
ref=ref,
@@ -1573,7 +1570,7 @@ def general_before_validation_function(
15731570
ref: str | None = None,
15741571
metadata: Any = None,
15751572
serialization: SerSchema | None = None,
1576-
) -> FunctionSchema:
1573+
) -> FunctionBeforeSchema:
15771574
"""
15781575
Returns a schema that calls a validator function before validating the provided schema, e.g.:
15791576
@@ -1599,8 +1596,7 @@ def fn(v: Any, info: core_schema.ValidationInfo) -> str:
15991596
serialization: Custom serialization schema
16001597
"""
16011598
return dict_not_none(
1602-
type='function',
1603-
mode='before',
1599+
type='function-before',
16041600
function={'type': 'general', 'function': function},
16051601
schema=schema,
16061602
ref=ref,
@@ -1609,14 +1605,18 @@ def fn(v: Any, info: core_schema.ValidationInfo) -> str:
16091605
)
16101606

16111607

1608+
class FunctionAfterSchema(_FunctionSchema, total=False):
1609+
type: Required[Literal['function-after']]
1610+
1611+
16121612
def field_after_validation_function(
16131613
function: FieldValidatorFunction,
16141614
schema: CoreSchema,
16151615
*,
16161616
ref: str | None = None,
16171617
metadata: Any = None,
16181618
serialization: SerSchema | None = None,
1619-
) -> FunctionSchema:
1619+
) -> FunctionAfterSchema:
16201620
"""
16211621
Returns a schema that calls a validator function after validating
16221622
the provided **model field** schema, e.g.:
@@ -1646,8 +1646,7 @@ def fn(v: str, info: core_schema.ModelFieldValidationInfo) -> str:
16461646
serialization: Custom serialization schema
16471647
"""
16481648
return dict_not_none(
1649-
type='function',
1650-
mode='after',
1649+
type='function-after',
16511650
function={'type': 'field', 'function': function},
16521651
schema=schema,
16531652
ref=ref,
@@ -1663,7 +1662,7 @@ def general_after_validation_function(
16631662
ref: str | None = None,
16641663
metadata: Any = None,
16651664
serialization: SerSchema | None = None,
1666-
) -> FunctionSchema:
1665+
) -> FunctionAfterSchema:
16671666
"""
16681667
Returns a schema that calls a validator function after validating the provided schema, e.g.:
16691668
@@ -1687,8 +1686,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
16871686
serialization: Custom serialization schema
16881687
"""
16891688
return dict_not_none(
1690-
type='function',
1691-
mode='after',
1689+
type='function-after',
16921690
function={'type': 'general', 'function': function},
16931691
schema=schema,
16941692
ref=ref,
@@ -1727,9 +1725,8 @@ class GeneralWrapValidatorFunctionSchema(TypedDict):
17271725

17281726

17291727
class WrapFunctionSchema(TypedDict, total=False):
1730-
type: Required[Literal['function']]
1728+
type: Required[Literal['function-wrap']]
17311729
function: Required[Union[GeneralWrapValidatorFunctionSchema, FieldWrapValidatorFunctionSchema]]
1732-
mode: Required[Literal['wrap']]
17331730
schema: Required[CoreSchema]
17341731
ref: str
17351732
metadata: Any
@@ -1768,8 +1765,7 @@ def fn(v: str, validator: core_schema.CallableValidator, info: core_schema.Valid
17681765
serialization: Custom serialization schema
17691766
"""
17701767
return dict_not_none(
1771-
type='function',
1772-
mode='wrap',
1768+
type='function-wrap',
17731769
function={'type': 'general', 'function': function},
17741770
schema=schema,
17751771
ref=ref,
@@ -1817,8 +1813,7 @@ def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.Mod
18171813
serialization: Custom serialization schema
18181814
"""
18191815
return dict_not_none(
1820-
type='function',
1821-
mode='wrap',
1816+
type='function-wrap',
18221817
function={'type': 'field', 'function': function},
18231818
schema=schema,
18241819
ref=ref,
@@ -1828,8 +1823,7 @@ def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.Mod
18281823

18291824

18301825
class PlainFunctionSchema(TypedDict, total=False):
1831-
type: Required[Literal['function']]
1832-
mode: Required[Literal['plain']]
1826+
type: Required[Literal['function-plain']]
18331827
function: Required[Union[FieldValidatorFunctionSchema, GeneralValidatorFunctionSchema]]
18341828
ref: str
18351829
metadata: Any
@@ -1865,8 +1859,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
18651859
serialization: Custom serialization schema
18661860
"""
18671861
return dict_not_none(
1868-
type='function',
1869-
mode='plain',
1862+
type='function-plain',
18701863
function={'type': 'general', 'function': function},
18711864
ref=ref,
18721865
metadata=metadata,
@@ -1909,8 +1902,7 @@ def fn(v: Any, info: core_schema.ModelFieldValidationInfo) -> str:
19091902
serialization: Custom serialization schema
19101903
"""
19111904
return dict_not_none(
1912-
type='function',
1913-
mode='plain',
1905+
type='function-plain',
19141906
function={'type': 'field', 'function': function},
19151907
ref=ref,
19161908
metadata=metadata,
@@ -3068,7 +3060,8 @@ def definition_reference_schema(
30683060
FrozenSetSchema,
30693061
GeneratorSchema,
30703062
DictSchema,
3071-
FunctionSchema,
3063+
FunctionAfterSchema,
3064+
FunctionBeforeSchema,
30723065
WrapFunctionSchema,
30733066
PlainFunctionSchema,
30743067
WithDefaultSchema,
@@ -3109,12 +3102,16 @@ def definition_reference_schema(
31093102
'is-subclass',
31103103
'callable',
31113104
'list',
3112-
'tuple',
3105+
'tuple-positional',
3106+
'tuple-variable',
31133107
'set',
31143108
'frozenset',
31153109
'generator',
31163110
'dict',
3117-
'function',
3111+
'function-after',
3112+
'function-before',
3113+
'function-wrap',
3114+
'function-plain',
31183115
'default',
31193116
'nullable',
31203117
'union',

src/serializers/shared.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,12 @@ combined_serializer! {
7878
// hence they're here.
7979
Function: super::type_serializers::function::FunctionPlainSerializer;
8080
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
81-
// `TuplePositionalSerializer` & `TupleVariableSerializer` are created by
82-
// `TupleBuilder` based on the `mode` parameter.
83-
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
84-
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
8581
}
8682
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
8783
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
8884
find_only: {
89-
super::type_serializers::tuple::TupleBuilder;
9085
super::type_serializers::union::TaggedUnionBuilder;
9186
super::type_serializers::other::ChainBuilder;
92-
super::type_serializers::other::FunctionBuilder;
9387
super::type_serializers::other::CustomErrorBuilder;
9488
super::type_serializers::other::CallBuilder;
9589
super::type_serializers::other::LaxOrStrictBuilder;
@@ -100,6 +94,10 @@ combined_serializer! {
10094
super::type_serializers::definitions::DefinitionsBuilder;
10195
super::type_serializers::dataclass::DataclassArgsBuilder;
10296
super::type_serializers::dataclass::DataclassBuilder;
97+
super::type_serializers::function::FunctionBeforeSerializerBuilder;
98+
super::type_serializers::function::FunctionAfterSerializerBuilder;
99+
super::type_serializers::function::FunctionPlainSerializerBuilder;
100+
super::type_serializers::function::FunctionWrapSerializerBuilder;
103101
}
104102
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
105103
// `find_serializer` so they can be used via a `type` str.
@@ -132,6 +130,8 @@ combined_serializer! {
132130
Union: super::type_serializers::union::UnionSerializer;
133131
Literal: super::type_serializers::literal::LiteralSerializer;
134132
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
133+
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
134+
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
135135
}
136136
}
137137

@@ -150,11 +150,15 @@ impl CombinedSerializer {
150150
Some("function-plain") => {
151151
// `function` is a special case, not included in `find_serializer` since it means something
152152
// different in `schema.type`
153-
return super::type_serializers::function::FunctionPlainSerializer::new_combined(ser_schema)
154-
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
153+
return super::type_serializers::function::FunctionPlainSerializer::build(
154+
ser_schema,
155+
config,
156+
build_context,
157+
)
158+
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
155159
}
156160
Some("function-wrap") => {
157-
return super::type_serializers::function::FunctionWrapSerializer::new_combined(
161+
return super::type_serializers::function::FunctionWrapSerializer::build(
158162
ser_schema,
159163
config,
160164
build_context,

0 commit comments

Comments
 (0)