Skip to content

Flatten schemas and replace macros with plain code #450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,9 @@ def main() -> None:
assert m, f'Unknown schema type: {type_}'
key = m.group(1)
value = get_schema(s)
if key == 'function':
mode = value['fields']['mode']['schema']['expected']
if mode == ['plain']:
key = 'function-plain'
elif mode == ['wrap']:
key = 'function-wrap'
elif key == 'tuple':
if value['fields']['mode']['schema']['expected'] == ['positional']:
key = 'tuple-positional'
else:
key = 'tuple-variable'

choices[key] = value

schema = {
'type': 'tagged-union',
'ref': 'root-schema',
'discriminator': 'self-schema-discriminator',
'choices': choices,
}
schema = {'type': 'tagged-union', 'ref': 'root-schema', 'discriminator': 'type', 'choices': choices}
python_code = (
f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n'
)
Expand Down
73 changes: 35 additions & 38 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,8 +1139,7 @@ def list_schema(


class TuplePositionalSchema(TypedDict, total=False):
type: Required[Literal['tuple']]
mode: Required[Literal['positional']]
type: Required[Literal['tuple-positional']]
items_schema: Required[List[CoreSchema]]
extra_schema: CoreSchema
strict: bool
Expand Down Expand Up @@ -1179,8 +1178,7 @@ def tuple_positional_schema(
serialization: Custom serialization schema
"""
return dict_not_none(
type='tuple',
mode='positional',
type='tuple-positional',
items_schema=list(items_schema),
extra_schema=extra_schema,
strict=strict,
Expand All @@ -1191,8 +1189,7 @@ def tuple_positional_schema(


class TupleVariableSchema(TypedDict, total=False):
type: Required[Literal['tuple']]
mode: Literal['variable']
type: Required[Literal['tuple-variable']]
items_schema: CoreSchema
min_length: int
max_length: int
Expand Down Expand Up @@ -1232,8 +1229,7 @@ def tuple_variable_schema(
serialization: Custom serialization schema
"""
return dict_not_none(
type='tuple',
mode='variable',
type='tuple-variable',
items_schema=items_schema,
min_length=min_length,
max_length=max_length,
Expand Down Expand Up @@ -1509,24 +1505,26 @@ class GeneralValidatorFunctionSchema(TypedDict):
function: GeneralValidatorFunction


class FunctionSchema(TypedDict, total=False):
type: Required[Literal['function']]
class _FunctionSchema(TypedDict, total=False):
function: Required[Union[FieldValidatorFunctionSchema, GeneralValidatorFunctionSchema]]
mode: Required[Literal['before', 'after']]
schema: Required[CoreSchema]
ref: str
metadata: Any
serialization: SerSchema


class FunctionBeforeSchema(_FunctionSchema, total=False):
type: Required[Literal['function-before']]


def field_before_validation_function(
function: FieldValidatorFunction,
schema: CoreSchema,
*,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionBeforeSchema:
"""
Returns a schema that calls a validator function before validating
the provided **model field** schema, e.g.:
Expand Down Expand Up @@ -1556,8 +1554,7 @@ def fn(v: bytes, info: core_schema.ModelFieldValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='before',
type='function-before',
function={'type': 'field', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1573,7 +1570,7 @@ def general_before_validation_function(
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionBeforeSchema:
"""
Returns a schema that calls a validator function before validating the provided schema, e.g.:

Expand All @@ -1599,8 +1596,7 @@ def fn(v: Any, info: core_schema.ValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='before',
type='function-before',
function={'type': 'general', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1609,14 +1605,18 @@ def fn(v: Any, info: core_schema.ValidationInfo) -> str:
)


class FunctionAfterSchema(_FunctionSchema, total=False):
type: Required[Literal['function-after']]


def field_after_validation_function(
function: FieldValidatorFunction,
schema: CoreSchema,
*,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionAfterSchema:
"""
Returns a schema that calls a validator function after validating
the provided **model field** schema, e.g.:
Expand Down Expand Up @@ -1646,8 +1646,7 @@ def fn(v: str, info: core_schema.ModelFieldValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='after',
type='function-after',
function={'type': 'field', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1663,7 +1662,7 @@ def general_after_validation_function(
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionAfterSchema:
"""
Returns a schema that calls a validator function after validating the provided schema, e.g.:

Expand All @@ -1687,8 +1686,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='after',
type='function-after',
function={'type': 'general', 'function': function},
schema=schema,
ref=ref,
Expand Down Expand Up @@ -1727,9 +1725,8 @@ class GeneralWrapValidatorFunctionSchema(TypedDict):


class WrapFunctionSchema(TypedDict, total=False):
type: Required[Literal['function']]
type: Required[Literal['function-wrap']]
function: Required[Union[GeneralWrapValidatorFunctionSchema, FieldWrapValidatorFunctionSchema]]
mode: Required[Literal['wrap']]
schema: Required[CoreSchema]
ref: str
metadata: Any
Expand Down Expand Up @@ -1768,8 +1765,7 @@ def fn(v: str, validator: core_schema.CallableValidator, info: core_schema.Valid
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='wrap',
type='function-wrap',
function={'type': 'general', 'function': function},
schema=schema,
ref=ref,
Expand Down Expand Up @@ -1817,8 +1813,7 @@ def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.Mod
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='wrap',
type='function-wrap',
function={'type': 'field', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1828,8 +1823,7 @@ def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.Mod


class PlainFunctionSchema(TypedDict, total=False):
type: Required[Literal['function']]
mode: Required[Literal['plain']]
type: Required[Literal['function-plain']]
function: Required[Union[FieldValidatorFunctionSchema, GeneralValidatorFunctionSchema]]
ref: str
metadata: Any
Expand Down Expand Up @@ -1865,8 +1859,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='plain',
type='function-plain',
function={'type': 'general', 'function': function},
ref=ref,
metadata=metadata,
Expand Down Expand Up @@ -1909,8 +1902,7 @@ def fn(v: Any, info: core_schema.ModelFieldValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='plain',
type='function-plain',
function={'type': 'field', 'function': function},
ref=ref,
metadata=metadata,
Expand Down Expand Up @@ -3068,7 +3060,8 @@ def definition_reference_schema(
FrozenSetSchema,
GeneratorSchema,
DictSchema,
FunctionSchema,
FunctionAfterSchema,
FunctionBeforeSchema,
WrapFunctionSchema,
PlainFunctionSchema,
WithDefaultSchema,
Expand Down Expand Up @@ -3109,12 +3102,16 @@ def definition_reference_schema(
'is-subclass',
'callable',
'list',
'tuple',
'tuple-positional',
'tuple-variable',
'set',
'frozenset',
'generator',
'dict',
'function',
'function-after',
'function-before',
'function-wrap',
'function-plain',
'default',
'nullable',
'union',
Expand Down
22 changes: 13 additions & 9 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,12 @@ combined_serializer! {
// hence they're here.
Function: super::type_serializers::function::FunctionPlainSerializer;
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
// `TuplePositionalSerializer` & `TupleVariableSerializer` are created by
// `TupleBuilder` based on the `mode` parameter.
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
}
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
find_only: {
super::type_serializers::tuple::TupleBuilder;
super::type_serializers::union::TaggedUnionBuilder;
super::type_serializers::other::ChainBuilder;
super::type_serializers::other::FunctionBuilder;
super::type_serializers::other::CustomErrorBuilder;
super::type_serializers::other::CallBuilder;
super::type_serializers::other::LaxOrStrictBuilder;
Expand All @@ -100,6 +94,10 @@ combined_serializer! {
super::type_serializers::definitions::DefinitionsBuilder;
super::type_serializers::dataclass::DataclassArgsBuilder;
super::type_serializers::dataclass::DataclassBuilder;
super::type_serializers::function::FunctionBeforeSerializerBuilder;
super::type_serializers::function::FunctionAfterSerializerBuilder;
super::type_serializers::function::FunctionPlainSerializerBuilder;
super::type_serializers::function::FunctionWrapSerializerBuilder;
}
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
// `find_serializer` so they can be used via a `type` str.
Expand Down Expand Up @@ -132,6 +130,8 @@ combined_serializer! {
Union: super::type_serializers::union::UnionSerializer;
Literal: super::type_serializers::literal::LiteralSerializer;
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
}
}

Expand All @@ -150,11 +150,15 @@ impl CombinedSerializer {
Some("function-plain") => {
// `function` is a special case, not included in `find_serializer` since it means something
// different in `schema.type`
return super::type_serializers::function::FunctionPlainSerializer::new_combined(ser_schema)
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
return super::type_serializers::function::FunctionPlainSerializer::build(
ser_schema,
config,
build_context,
)
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
}
Some("function-wrap") => {
return super::type_serializers::function::FunctionWrapSerializer::new_combined(
return super::type_serializers::function::FunctionWrapSerializer::build(
ser_schema,
config,
build_context,
Expand Down
Loading