Skip to content

Commit 5c0d11b

Browse files
committed
Flatten CoreSchema types to get a single discriminant key
1 parent cc79b05 commit 5c0d11b

34 files changed

+344
-403
lines changed

generate_self_schema.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,26 +197,9 @@ def main() -> None:
197197
assert m, f'Unknown schema type: {type_}'
198198
key = m.group(1)
199199
value = get_schema(s)
200-
if key == 'function':
201-
mode = value['fields']['mode']['schema']['expected']
202-
if mode == ['plain']:
203-
key = 'function-plain'
204-
elif mode == ['wrap']:
205-
key = 'function-wrap'
206-
elif key == 'tuple':
207-
if value['fields']['mode']['schema']['expected'] == ['positional']:
208-
key = 'tuple-positional'
209-
else:
210-
key = 'tuple-variable'
211-
212200
choices[key] = value
213201

214-
schema = {
215-
'type': 'tagged-union',
216-
'ref': 'root-schema',
217-
'discriminator': 'self-schema-discriminator',
218-
'choices': choices,
219-
}
202+
schema = {'type': 'tagged-union', 'ref': 'root-schema', 'discriminator': 'type', 'choices': choices}
220203
python_code = (
221204
f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n'
222205
)

pydantic_core/core_schema.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,8 +1094,7 @@ def list_schema(
10941094

10951095

10961096
class TuplePositionalSchema(TypedDict, total=False):
1097-
type: Required[Literal['tuple']]
1098-
mode: Required[Literal['positional']]
1097+
type: Required[Literal['tuple-positional']]
10991098
items_schema: Required[List[CoreSchema]]
11001099
extra_schema: CoreSchema
11011100
strict: bool
@@ -1134,8 +1133,7 @@ def tuple_positional_schema(
11341133
serialization: Custom serialization schema
11351134
"""
11361135
return dict_not_none(
1137-
type='tuple',
1138-
mode='positional',
1136+
type='tuple-positional',
11391137
items_schema=list(items_schema),
11401138
extra_schema=extra_schema,
11411139
strict=strict,
@@ -1146,8 +1144,7 @@ def tuple_positional_schema(
11461144

11471145

11481146
class TupleVariableSchema(TypedDict, total=False):
1149-
type: Required[Literal['tuple']]
1150-
mode: Literal['variable']
1147+
type: Required[Literal['tuple-variable']]
11511148
items_schema: CoreSchema
11521149
min_length: int
11531150
max_length: int
@@ -1187,8 +1184,7 @@ def tuple_variable_schema(
11871184
serialization: Custom serialization schema
11881185
"""
11891186
return dict_not_none(
1190-
type='tuple',
1191-
mode='variable',
1187+
type='tuple-variable',
11921188
items_schema=items_schema,
11931189
min_length=min_length,
11941190
max_length=max_length,
@@ -1449,24 +1445,30 @@ def __call__(self, __input_value: Any, __info: ValidationInfo) -> Any: # pragma
14491445
...
14501446

14511447

1452-
class FunctionSchema(TypedDict, total=False):
1453-
type: Required[Literal['function']]
1454-
mode: Required[Literal['before', 'after']]
1448+
class _FunctionSchema(TypedDict, total=False):
14551449
function: Required[ValidatorFunction]
14561450
schema: Required[CoreSchema]
14571451
ref: str
14581452
metadata: Any
14591453
serialization: SerSchema
14601454

14611455

1456+
class FunctionBeforeSchema(_FunctionSchema, total=False):
1457+
type: Required[Literal['function-before']]
1458+
1459+
1460+
class FunctionAfterSchema(_FunctionSchema, total=False):
1461+
type: Required[Literal['function-after']]
1462+
1463+
14621464
def function_before_schema(
14631465
function: ValidatorFunction,
14641466
schema: CoreSchema,
14651467
*,
14661468
ref: str | None = None,
14671469
metadata: Any = None,
14681470
serialization: SerSchema | None = None,
1469-
) -> FunctionSchema:
1471+
) -> FunctionBeforeSchema:
14701472
"""
14711473
Returns a schema that calls a validator function before validating the provided schema, e.g.:
14721474
@@ -1492,8 +1494,7 @@ def fn(v: Any, info: core_schema.ValidationInfo) -> str:
14921494
serialization: Custom serialization schema
14931495
"""
14941496
return dict_not_none(
1495-
type='function',
1496-
mode='before',
1497+
type='function-before',
14971498
function=function,
14981499
schema=schema,
14991500
ref=ref,
@@ -1509,7 +1510,7 @@ def function_after_schema(
15091510
ref: str | None = None,
15101511
metadata: Any = None,
15111512
serialization: SerSchema | None = None,
1512-
) -> FunctionSchema:
1513+
) -> FunctionAfterSchema:
15131514
"""
15141515
Returns a schema that calls a validator function after validating the provided schema, e.g.:
15151516
@@ -1533,13 +1534,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
15331534
serialization: Custom serialization schema
15341535
"""
15351536
return dict_not_none(
1536-
type='function',
1537-
mode='after',
1538-
function=function,
1539-
schema=schema,
1540-
ref=ref,
1541-
metadata=metadata,
1542-
serialization=serialization,
1537+
type='function-after', function=function, schema=schema, ref=ref, metadata=metadata, serialization=serialization
15431538
)
15441539

15451540

@@ -1556,8 +1551,7 @@ def __call__(
15561551

15571552

15581553
class FunctionWrapSchema(TypedDict, total=False):
1559-
type: Required[Literal['function']]
1560-
mode: Required[Literal['wrap']]
1554+
type: Required[Literal['function-wrap']]
15611555
function: Required[WrapValidatorFunction]
15621556
schema: Required[CoreSchema]
15631557
ref: str
@@ -1597,19 +1591,12 @@ def fn(v: str, validator: core_schema.CallableValidator, info: core_schema.Valid
15971591
serialization: Custom serialization schema
15981592
"""
15991593
return dict_not_none(
1600-
type='function',
1601-
mode='wrap',
1602-
function=function,
1603-
schema=schema,
1604-
ref=ref,
1605-
metadata=metadata,
1606-
serialization=serialization,
1594+
type='function-wrap', function=function, schema=schema, ref=ref, metadata=metadata, serialization=serialization
16071595
)
16081596

16091597

16101598
class FunctionPlainSchema(TypedDict, total=False):
1611-
type: Required[Literal['function']]
1612-
mode: Required[Literal['plain']]
1599+
type: Required[Literal['function-plain']]
16131600
function: Required[ValidatorFunction]
16141601
ref: str
16151602
metadata: Any
@@ -1641,7 +1628,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
16411628
serialization: Custom serialization schema
16421629
"""
16431630
return dict_not_none(
1644-
type='function', mode='plain', function=function, ref=ref, metadata=metadata, serialization=serialization
1631+
type='function-plain', function=function, ref=ref, metadata=metadata, serialization=serialization
16451632
)
16461633

16471634

@@ -2795,7 +2782,8 @@ def definition_reference_schema(
27952782
FrozenSetSchema,
27962783
GeneratorSchema,
27972784
DictSchema,
2798-
FunctionSchema,
2785+
FunctionBeforeSchema,
2786+
FunctionAfterSchema,
27992787
FunctionWrapSchema,
28002788
FunctionPlainSchema,
28012789
WithDefaultSchema,
@@ -2836,12 +2824,16 @@ def definition_reference_schema(
28362824
'is-subclass',
28372825
'callable',
28382826
'list',
2839-
'tuple',
2827+
'tuple-positional',
2828+
'tuple-variable',
28402829
'set',
28412830
'frozenset',
28422831
'generator',
28432832
'dict',
2844-
'function',
2833+
'function-before',
2834+
'function-after',
2835+
'function-wrap',
2836+
'function-plain',
28452837
'default',
28462838
'nullable',
28472839
'union',

src/serializers/shared.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,12 @@ combined_serializer! {
7878
// hence they're here.
7979
Function: super::type_serializers::function::FunctionPlainSerializer;
8080
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
81-
// `TuplePositionalSerializer` & `TupleVariableSerializer` are created by
82-
// `TupleBuilder` based on the `mode` parameter.
83-
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
84-
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
8581
}
8682
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
8783
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
8884
find_only: {
89-
super::type_serializers::tuple::TupleBuilder;
9085
super::type_serializers::union::TaggedUnionBuilder;
9186
super::type_serializers::other::ChainBuilder;
92-
super::type_serializers::other::FunctionBuilder;
9387
super::type_serializers::other::CustomErrorBuilder;
9488
super::type_serializers::other::CallBuilder;
9589
super::type_serializers::other::LaxOrStrictBuilder;
@@ -100,6 +94,10 @@ combined_serializer! {
10094
super::type_serializers::definitions::DefinitionsBuilder;
10195
super::type_serializers::dataclass::DataclassArgsBuilder;
10296
super::type_serializers::dataclass::DataclassBuilder;
97+
super::type_serializers::function::FunctionBeforeSerializerBuilder;
98+
super::type_serializers::function::FunctionAfterSerializerBuilder;
99+
super::type_serializers::function::FunctionPlainSerializerBuilder;
100+
super::type_serializers::function::FunctionWrapSerializerBuilder;
103101
}
104102
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
105103
// `find_serializer` so they can be used via a `type` str.
@@ -132,6 +130,8 @@ combined_serializer! {
132130
Union: super::type_serializers::union::UnionSerializer;
133131
Literal: super::type_serializers::literal::LiteralSerializer;
134132
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
133+
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
134+
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
135135
}
136136
}
137137

@@ -150,11 +150,15 @@ impl CombinedSerializer {
150150
Some("function-plain") => {
151151
// `function` is a special case, not included in `find_serializer` since it means something
152152
// different in `schema.type`
153-
return super::type_serializers::function::FunctionPlainSerializer::new_combined(ser_schema)
154-
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
153+
return super::type_serializers::function::FunctionPlainSerializer::build(
154+
ser_schema,
155+
config,
156+
build_context,
157+
)
158+
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
155159
}
156160
Some("function-wrap") => {
157-
return super::type_serializers::function::FunctionWrapSerializer::new_combined(
161+
return super::type_serializers::function::FunctionWrapSerializer::build(
158162
ser_schema,
159163
config,
160164
build_context,

src/serializers/type_serializers/function.rs

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,54 @@ use super::{
2121
PydanticSerializationError, TypeSerializer,
2222
};
2323

24+
pub struct FunctionBeforeSerializerBuilder;
25+
26+
impl BuildSerializer for FunctionBeforeSerializerBuilder {
27+
const EXPECTED_TYPE: &'static str = "function-before";
28+
29+
fn build(
30+
schema: &PyDict,
31+
config: Option<&PyDict>,
32+
build_context: &mut BuildContext<CombinedSerializer>,
33+
) -> PyResult<CombinedSerializer> {
34+
let py = schema.py();
35+
// `before` schemas will obviously have type from `schema` since the validator is called second
36+
let schema = schema.get_as_req(intern!(py, "schema"))?;
37+
CombinedSerializer::build(schema, config, build_context)
38+
}
39+
}
40+
41+
pub struct FunctionAfterSerializerBuilder;
42+
43+
impl BuildSerializer for FunctionAfterSerializerBuilder {
44+
const EXPECTED_TYPE: &'static str = "function-after";
45+
fn build(
46+
schema: &PyDict,
47+
config: Option<&PyDict>,
48+
build_context: &mut BuildContext<CombinedSerializer>,
49+
) -> PyResult<CombinedSerializer> {
50+
let py = schema.py();
51+
// while `before` schemas have an obvious type, for
52+
// `after` schemas it's less, clear but the default will be the same type, and the user/lib can always
53+
// override the serializer
54+
let schema = schema.get_as_req(intern!(py, "schema"))?;
55+
CombinedSerializer::build(schema, config, build_context)
56+
}
57+
}
58+
59+
pub struct FunctionPlainSerializerBuilder;
60+
61+
impl BuildSerializer for FunctionPlainSerializerBuilder {
62+
const EXPECTED_TYPE: &'static str = "function-plain";
63+
fn build(
64+
schema: &PyDict,
65+
config: Option<&PyDict>,
66+
build_context: &mut BuildContext<CombinedSerializer>,
67+
) -> PyResult<CombinedSerializer> {
68+
super::any::AnySerializer::build(schema, config, build_context)
69+
}
70+
}
71+
2472
#[derive(Debug, Clone)]
2573
pub struct FunctionPlainSerializer {
2674
func: PyObject,
@@ -30,8 +78,13 @@ pub struct FunctionPlainSerializer {
3078
when_used: WhenUsed,
3179
}
3280

33-
impl FunctionPlainSerializer {
34-
pub fn new_combined(schema: &PyDict) -> PyResult<CombinedSerializer> {
81+
impl BuildSerializer for FunctionPlainSerializer {
82+
const EXPECTED_TYPE: &'static str = "function-plain";
83+
fn build(
84+
schema: &PyDict,
85+
_config: Option<&PyDict>,
86+
_build_context: &mut BuildContext<CombinedSerializer>,
87+
) -> PyResult<CombinedSerializer> {
3588
let py = schema.py();
3689
let function = schema.get_as_req::<&PyAny>(intern!(py, "function"))?;
3790
let function_name = function_name(function)?;
@@ -46,7 +99,9 @@ impl FunctionPlainSerializer {
4699
}
47100
.into())
48101
}
102+
}
49103

104+
impl FunctionPlainSerializer {
50105
fn call(
51106
&self,
52107
value: &PyAny,
@@ -177,6 +232,19 @@ macro_rules! function_type_serializer {
177232

178233
function_type_serializer!(FunctionPlainSerializer);
179234

235+
pub struct FunctionWrapSerializerBuilder;
236+
237+
impl BuildSerializer for FunctionWrapSerializerBuilder {
238+
const EXPECTED_TYPE: &'static str = "function-wrap";
239+
fn build(
240+
schema: &PyDict,
241+
config: Option<&PyDict>,
242+
build_context: &mut BuildContext<CombinedSerializer>,
243+
) -> PyResult<CombinedSerializer> {
244+
super::any::AnySerializer::build(schema, config, build_context)
245+
}
246+
}
247+
180248
#[derive(Debug, Clone)]
181249
pub struct FunctionWrapSerializer {
182250
serializer: Box<CombinedSerializer>,
@@ -187,8 +255,9 @@ pub struct FunctionWrapSerializer {
187255
when_used: WhenUsed,
188256
}
189257

190-
impl FunctionWrapSerializer {
191-
pub fn new_combined(
258+
impl BuildSerializer for FunctionWrapSerializer {
259+
const EXPECTED_TYPE: &'static str = "function-wrap";
260+
fn build(
192261
schema: &PyDict,
193262
config: Option<&PyDict>,
194263
build_context: &mut BuildContext<CombinedSerializer>,
@@ -211,7 +280,9 @@ impl FunctionWrapSerializer {
211280
}
212281
.into())
213282
}
283+
}
214284

285+
impl FunctionWrapSerializer {
215286
fn call(
216287
&self,
217288
value: &PyAny,

0 commit comments

Comments
 (0)