Skip to content

Commit 817af72

Browse files
authored
separate model fields validator from typed dict validator (#568)
1 parent 9f69da1 commit 817af72

37 files changed

+2810
-1581
lines changed

pydantic_core/core_schema.py

Lines changed: 124 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,7 +2609,6 @@ class TypedDictField(TypedDict, total=False):
26092609
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
26102610
serialization_alias: str
26112611
serialization_exclude: bool # default: False
2612-
frozen: bool
26132612
metadata: Any
26142613

26152614

@@ -2620,7 +2619,6 @@ def typed_dict_field(
26202619
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
26212620
serialization_alias: str | None = None,
26222621
serialization_exclude: bool | None = None,
2623-
frozen: bool | None = None,
26242622
metadata: Any = None,
26252623
) -> TypedDictField:
26262624
"""
@@ -2638,7 +2636,6 @@ def typed_dict_field(
26382636
validation_alias: The alias(es) to use to find the field in the validation data
26392637
serialization_alias: The alias to use as a key when serializing
26402638
serialization_exclude: Whether to exclude the field when serializing
2641-
frozen: Whether the field is frozen
26422639
metadata: Any other information you want to include with the schema, not used by pydantic-core
26432640
"""
26442641
return dict_not_none(
@@ -2648,7 +2645,6 @@ def typed_dict_field(
26482645
validation_alias=validation_alias,
26492646
serialization_alias=serialization_alias,
26502647
serialization_exclude=serialization_exclude,
2651-
frozen=frozen,
26522648
metadata=metadata,
26532649
)
26542650

@@ -2659,12 +2655,10 @@ class TypedDictSchema(TypedDict, total=False):
26592655
computed_fields: List[ComputedField]
26602656
strict: bool
26612657
extra_validator: CoreSchema
2662-
return_fields_set: bool
26632658
# all these values can be set via config, equivalent fields have `typed_dict_` prefix
26642659
extra_behavior: ExtraBehavior
26652660
total: bool # default: True
26662661
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
2667-
from_attributes: bool
26682662
ref: str
26692663
metadata: Any
26702664
serialization: SerSchema
@@ -2676,11 +2670,9 @@ def typed_dict_schema(
26762670
computed_fields: list[ComputedField] | None = None,
26772671
strict: bool | None = None,
26782672
extra_validator: CoreSchema | None = None,
2679-
return_fields_set: bool | None = None,
26802673
extra_behavior: ExtraBehavior | None = None,
26812674
total: bool | None = None,
26822675
populate_by_name: bool | None = None,
2683-
from_attributes: bool | None = None,
26842676
ref: str | None = None,
26852677
metadata: Any = None,
26862678
serialization: SerSchema | None = None,
@@ -2703,13 +2695,11 @@ def typed_dict_schema(
27032695
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model
27042696
strict: Whether the typed dict is strict
27052697
extra_validator: The extra validator to use for the typed dict
2706-
return_fields_set: Whether the typed dict should return a fields set
27072698
ref: optional unique identifier of the schema, used to reference the schema in other places
27082699
metadata: Any other information you want to include with the schema, not used by pydantic-core
27092700
extra_behavior: The extra behavior to use for the typed dict
27102701
total: Whether the typed dict is total
27112702
populate_by_name: Whether the typed dict should populate by name
2712-
from_attributes: Whether the typed dict should be populated from attributes
27132703
serialization: Custom serialization schema
27142704
"""
27152705
return dict_not_none(
@@ -2718,10 +2708,124 @@ def typed_dict_schema(
27182708
computed_fields=computed_fields,
27192709
strict=strict,
27202710
extra_validator=extra_validator,
2721-
return_fields_set=return_fields_set,
27222711
extra_behavior=extra_behavior,
27232712
total=total,
27242713
populate_by_name=populate_by_name,
2714+
ref=ref,
2715+
metadata=metadata,
2716+
serialization=serialization,
2717+
)
2718+
2719+
2720+
class ModelField(TypedDict, total=False):
2721+
type: Required[Literal['model-field']]
2722+
schema: Required[CoreSchema]
2723+
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
2724+
serialization_alias: str
2725+
serialization_exclude: bool # default: False
2726+
frozen: bool
2727+
metadata: Any
2728+
2729+
2730+
def model_field(
2731+
schema: CoreSchema,
2732+
*,
2733+
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
2734+
serialization_alias: str | None = None,
2735+
serialization_exclude: bool | None = None,
2736+
frozen: bool | None = None,
2737+
metadata: Any = None,
2738+
) -> ModelField:
2739+
"""
2740+
Returns a schema for a model field, e.g.:
2741+
2742+
```py
2743+
from pydantic_core import core_schema
2744+
2745+
field = core_schema.model_field(schema=core_schema.int_schema())
2746+
```
2747+
2748+
Args:
2749+
schema: The schema to use for the field
2750+
validation_alias: The alias(es) to use to find the field in the validation data
2751+
serialization_alias: The alias to use as a key when serializing
2752+
serialization_exclude: Whether to exclude the field when serializing
2753+
frozen: Whether the field is frozen
2754+
metadata: Any other information you want to include with the schema, not used by pydantic-core
2755+
"""
2756+
return dict_not_none(
2757+
type='model-field',
2758+
schema=schema,
2759+
validation_alias=validation_alias,
2760+
serialization_alias=serialization_alias,
2761+
serialization_exclude=serialization_exclude,
2762+
frozen=frozen,
2763+
metadata=metadata,
2764+
)
2765+
2766+
2767+
class ModelFieldsSchema(TypedDict, total=False):
2768+
type: Required[Literal['model-fields']]
2769+
fields: Required[Dict[str, ModelField]]
2770+
computed_fields: List[ComputedField]
2771+
strict: bool
2772+
extra_validator: CoreSchema
2773+
# all these values can be set via config, equivalent fields have `typed_dict_` prefix
2774+
extra_behavior: ExtraBehavior
2775+
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
2776+
from_attributes: bool
2777+
ref: str
2778+
metadata: Any
2779+
serialization: SerSchema
2780+
2781+
2782+
def model_fields_schema(
2783+
fields: Dict[str, ModelField],
2784+
*,
2785+
computed_fields: list[ComputedField] | None = None,
2786+
strict: bool | None = None,
2787+
extra_validator: CoreSchema | None = None,
2788+
extra_behavior: ExtraBehavior | None = None,
2789+
populate_by_name: bool | None = None,
2790+
from_attributes: bool | None = None,
2791+
ref: str | None = None,
2792+
metadata: Any = None,
2793+
serialization: SerSchema | None = None,
2794+
) -> ModelFieldsSchema:
2795+
"""
2796+
Returns a schema that matches a typed dict, e.g.:
2797+
2798+
```py
2799+
from pydantic_core import SchemaValidator, core_schema
2800+
2801+
wrapper_schema = core_schema.model_fields_schema(
2802+
{'a': core_schema.model_field(core_schema.str_schema())}
2803+
)
2804+
v = SchemaValidator(wrapper_schema)
2805+
print(v.validate_python({'a': 'hello'}))
2806+
#> ({'a': 'hello'}, None, {'a'})
2807+
```
2808+
2809+
Args:
2810+
fields: The fields to use for the typed dict
2811+
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model
2812+
strict: Whether the typed dict is strict
2813+
extra_validator: The extra validator to use for the typed dict
2814+
ref: optional unique identifier of the schema, used to reference the schema in other places
2815+
metadata: Any other information you want to include with the schema, not used by pydantic-core
2816+
extra_behavior: The extra behavior to use for the typed dict
2817+
populate_by_name: Whether the typed dict should populate by name
2818+
from_attributes: Whether the typed dict should be populated from attributes
2819+
serialization: Custom serialization schema
2820+
"""
2821+
return dict_not_none(
2822+
type='model-fields',
2823+
fields=fields,
2824+
computed_fields=computed_fields,
2825+
strict=strict,
2826+
extra_validator=extra_validator,
2827+
extra_behavior=extra_behavior,
2828+
populate_by_name=populate_by_name,
27252829
from_attributes=from_attributes,
27262830
ref=ref,
27272831
metadata=metadata,
@@ -2768,14 +2872,13 @@ def model_schema(
27682872
from pydantic_core import CoreConfig, SchemaValidator, core_schema
27692873
27702874
class MyModel:
2771-
__slots__ = '__dict__', '__pydantic_fields_set__'
2875+
__slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__'
27722876
27732877
schema = core_schema.model_schema(
27742878
cls=MyModel,
27752879
config=CoreConfig(str_max_length=5),
2776-
schema=core_schema.typed_dict_schema(
2777-
fields={'a': core_schema.typed_dict_field(core_schema.str_schema())},
2778-
return_fields_set=True,
2880+
schema=core_schema.model_fields_schema(
2881+
fields={'a': core_schema.model_field(core_schema.str_schema())},
27792882
),
27802883
)
27812884
v = SchemaValidator(schema)
@@ -3236,16 +3339,15 @@ def json_schema(
32363339
```py
32373340
from pydantic_core import SchemaValidator, core_schema
32383341
3239-
dict_schema = core_schema.typed_dict_schema(
3342+
dict_schema = core_schema.model_fields_schema(
32403343
{
3241-
'field_a': core_schema.typed_dict_field(core_schema.str_schema()),
3242-
'field_b': core_schema.typed_dict_field(core_schema.bool_schema()),
3344+
'field_a': core_schema.model_field(core_schema.str_schema()),
3345+
'field_b': core_schema.model_field(core_schema.bool_schema()),
32433346
},
3244-
return_fields_set=True,
32453347
)
32463348
32473349
class MyModel:
3248-
__slots__ = '__dict__', '__pydantic_fields_set__'
3350+
__slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__'
32493351
field_a: str
32503352
field_b: bool
32513353
@@ -3497,6 +3599,7 @@ def definition_reference_schema(
34973599
ChainSchema,
34983600
LaxOrStrictSchema,
34993601
TypedDictSchema,
3602+
ModelFieldsSchema,
35003603
ModelSchema,
35013604
DataclassArgsSchema,
35023605
DataclassSchema,
@@ -3548,6 +3651,7 @@ def definition_reference_schema(
35483651
'chain',
35493652
'lax-or-strict',
35503653
'typed-dict',
3654+
'model-fields',
35513655
'model',
35523656
'dataclass-args',
35533657
'dataclass',

src/input/input_abstract.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
162162
self.strict_dict()
163163
}
164164

165-
fn validate_typed_dict(&'a self, strict: bool, _from_attributes: bool) -> ValResult<GenericMapping<'a>> {
165+
fn validate_model_fields(&'a self, strict: bool, _from_attributes: bool) -> ValResult<GenericMapping<'a>> {
166166
self.validate_dict(strict)
167167
}
168168

src/input/input_python.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ impl<'a> Input<'a> for PyAny {
354354
}
355355
}
356356

357-
fn validate_typed_dict(&'a self, strict: bool, from_attributes: bool) -> ValResult<GenericMapping<'a>> {
357+
fn validate_model_fields(&'a self, strict: bool, from_attributes: bool) -> ValResult<GenericMapping<'a>> {
358358
if from_attributes {
359359
// if from_attributes, first try a dict, then mapping then from_attributes
360360
if let Ok(dict) = self.downcast::<PyDict>() {
@@ -378,7 +378,7 @@ impl<'a> Input<'a> for PyAny {
378378
Err(ValError::new(ErrorType::DictAttributesType, self))
379379
}
380380
} else {
381-
// otherwise we just call back to lax_dict if from_mapping is allowed, not there error in this
381+
// otherwise we just call back to validate_dict if from_mapping is allowed, note that errors in this
382382
// case (correctly) won't hint about from_attributes
383383
self.validate_dict(strict)
384384
}

src/serializers/shared.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ combined_serializer! {
9898
super::type_serializers::function::FunctionAfterSerializerBuilder;
9999
super::type_serializers::function::FunctionPlainSerializerBuilder;
100100
super::type_serializers::function::FunctionWrapSerializerBuilder;
101+
super::type_serializers::model::ModelFieldsBuilder;
101102
}
102103
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
103104
// `find_serializer` so they can be used via a `type` str.

src/serializers/type_serializers/model.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,72 @@ use std::borrow::Cow;
22

33
use pyo3::intern;
44
use pyo3::prelude::*;
5-
use pyo3::types::{PyDict, PyType};
5+
use pyo3::types::{PyDict, PyString, PyType};
6+
7+
use ahash::AHashMap;
68

79
use crate::build_context::BuildContext;
8-
use crate::build_tools::SchemaDict;
10+
use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict};
11+
use crate::serializers::computed_fields::ComputedFields;
912
use crate::serializers::extra::SerCheck;
13+
use crate::serializers::filter::SchemaFilter;
1014
use crate::serializers::infer::{infer_serialize, infer_to_python};
1115
use crate::serializers::ob_type::ObType;
16+
use crate::serializers::type_serializers::typed_dict::{TypedDictField, TypedDictSerializer};
1217

1318
use super::{
1419
infer_json_key, infer_json_key_known, object_to_dict, py_err_se_err, BuildSerializer, CombinedSerializer, Extra,
1520
TypeSerializer,
1621
};
1722

23+
pub struct ModelFieldsBuilder;
24+
25+
impl BuildSerializer for ModelFieldsBuilder {
26+
const EXPECTED_TYPE: &'static str = "model-fields";
27+
28+
fn build(
29+
schema: &PyDict,
30+
config: Option<&PyDict>,
31+
build_context: &mut BuildContext<CombinedSerializer>,
32+
) -> PyResult<CombinedSerializer> {
33+
let py = schema.py();
34+
35+
let include_extra = matches!(
36+
ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?,
37+
ExtraBehavior::Allow
38+
);
39+
40+
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
41+
let mut fields: AHashMap<String, TypedDictField> = AHashMap::with_capacity(fields_dict.len());
42+
let mut exclude: Vec<Py<PyString>> = Vec::with_capacity(fields_dict.len());
43+
44+
for (key, value) in fields_dict.iter() {
45+
let key_py: &PyString = key.downcast()?;
46+
let key: String = key_py.extract()?;
47+
let field_info: &PyDict = value.downcast()?;
48+
49+
let key_py: Py<PyString> = key_py.into_py(py);
50+
51+
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
52+
exclude.push(key_py.clone_ref(py));
53+
} else {
54+
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
55+
56+
let schema = field_info.get_as_req(intern!(py, "schema"))?;
57+
let serializer = CombinedSerializer::build(schema, config, build_context)
58+
.map_err(|e| py_error_type!("Field `{}`:\n {}", key, e))?;
59+
60+
fields.insert(key, TypedDictField::new(py, key_py, alias, serializer, true));
61+
}
62+
}
63+
64+
let filter = SchemaFilter::from_vec_hash(py, exclude)?;
65+
let computed_fields = ComputedFields::new(schema)?;
66+
67+
Ok(TypedDictSerializer::new(fields, include_extra, filter, computed_fields).into())
68+
}
69+
}
70+
1871
#[derive(Debug, Clone)]
1972
pub struct ModelSerializer {
2073
class: Py<PyType>,

src/validators/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ mod lax_or_strict;
3939
mod list;
4040
mod literal;
4141
mod model;
42+
mod model_fields;
4243
mod none;
4344
mod nullable;
4445
mod set;
@@ -378,6 +379,7 @@ pub fn build_validator<'a>(
378379
nullable::NullableValidator,
379380
// model classes
380381
model::ModelValidator,
382+
model_fields::ModelFieldsValidator,
381383
// dataclasses
382384
dataclass::DataclassArgsValidator,
383385
dataclass::DataclassValidator,
@@ -505,6 +507,7 @@ pub enum CombinedValidator {
505507
Nullable(nullable::NullableValidator),
506508
// create new model classes
507509
Model(model::ModelValidator),
510+
ModelFields(model_fields::ModelFieldsValidator),
508511
// dataclasses
509512
DataclassArgs(dataclass::DataclassArgsValidator),
510513
Dataclass(dataclass::DataclassValidator),

0 commit comments

Comments
 (0)