Skip to content

separate model fields validator from typed dict validator #568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 2, 2023
144 changes: 124 additions & 20 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2609,7 +2609,6 @@ class TypedDictField(TypedDict, total=False):
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
serialization_alias: str
serialization_exclude: bool # default: False
frozen: bool
metadata: Any


Expand All @@ -2620,7 +2619,6 @@ def typed_dict_field(
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
serialization_alias: str | None = None,
serialization_exclude: bool | None = None,
frozen: bool | None = None,
metadata: Any = None,
) -> TypedDictField:
"""
Expand All @@ -2638,7 +2636,6 @@ def typed_dict_field(
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
serialization_exclude: Whether to exclude the field when serializing
frozen: Whether the field is frozen
metadata: Any other information you want to include with the schema, not used by pydantic-core
"""
return dict_not_none(
Expand All @@ -2648,7 +2645,6 @@ def typed_dict_field(
validation_alias=validation_alias,
serialization_alias=serialization_alias,
serialization_exclude=serialization_exclude,
frozen=frozen,
metadata=metadata,
)

Expand All @@ -2659,12 +2655,10 @@ class TypedDictSchema(TypedDict, total=False):
computed_fields: List[ComputedField]
strict: bool
extra_validator: CoreSchema
return_fields_set: bool
# all these values can be set via config, equivalent fields have `typed_dict_` prefix
extra_behavior: ExtraBehavior
total: bool # default: True
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
from_attributes: bool
ref: str
metadata: Any
serialization: SerSchema
Expand All @@ -2676,11 +2670,9 @@ def typed_dict_schema(
computed_fields: list[ComputedField] | None = None,
strict: bool | None = None,
extra_validator: CoreSchema | None = None,
return_fields_set: bool | None = None,
extra_behavior: ExtraBehavior | None = None,
total: bool | None = None,
populate_by_name: bool | None = None,
from_attributes: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
Expand All @@ -2703,13 +2695,11 @@ def typed_dict_schema(
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model
strict: Whether the typed dict is strict
extra_validator: The extra validator to use for the typed dict
return_fields_set: Whether the typed dict should return a fields set
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
extra_behavior: The extra behavior to use for the typed dict
total: Whether the typed dict is total
populate_by_name: Whether the typed dict should populate by name
from_attributes: Whether the typed dict should be populated from attributes
serialization: Custom serialization schema
"""
return dict_not_none(
Expand All @@ -2718,10 +2708,124 @@ def typed_dict_schema(
computed_fields=computed_fields,
strict=strict,
extra_validator=extra_validator,
return_fields_set=return_fields_set,
extra_behavior=extra_behavior,
total=total,
populate_by_name=populate_by_name,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class ModelField(TypedDict, total=False):
type: Required[Literal['model-field']]
schema: Required[CoreSchema]
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
serialization_alias: str
serialization_exclude: bool # default: False
frozen: bool
metadata: Any


def model_field(
schema: CoreSchema,
*,
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
serialization_alias: str | None = None,
serialization_exclude: bool | None = None,
frozen: bool | None = None,
metadata: Any = None,
) -> ModelField:
"""
Returns a schema for a model field, e.g.:

```py
from pydantic_core import core_schema

field = core_schema.model_field(schema=core_schema.int_schema())
```

Args:
schema: The schema to use for the field
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
serialization_exclude: Whether to exclude the field when serializing
frozen: Whether the field is frozen
metadata: Any other information you want to include with the schema, not used by pydantic-core
"""
return dict_not_none(
type='model-field',
schema=schema,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
serialization_exclude=serialization_exclude,
frozen=frozen,
metadata=metadata,
)


class ModelFieldsSchema(TypedDict, total=False):
type: Required[Literal['model-fields']]
fields: Required[Dict[str, ModelField]]
computed_fields: List[ComputedField]
strict: bool
extra_validator: CoreSchema
# all these values can be set via config, equivalent fields have `typed_dict_` prefix
extra_behavior: ExtraBehavior
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
from_attributes: bool
ref: str
metadata: Any
serialization: SerSchema


def model_fields_schema(
fields: Dict[str, ModelField],
*,
computed_fields: list[ComputedField] | None = None,
strict: bool | None = None,
extra_validator: CoreSchema | None = None,
extra_behavior: ExtraBehavior | None = None,
populate_by_name: bool | None = None,
from_attributes: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> ModelFieldsSchema:
"""
Returns a schema that matches a typed dict, e.g.:

```py
from pydantic_core import SchemaValidator, core_schema

wrapper_schema = core_schema.model_fields_schema(
{'a': core_schema.model_field(core_schema.str_schema())}
)
v = SchemaValidator(wrapper_schema)
print(v.validate_python({'a': 'hello'}))
#> ({'a': 'hello'}, None, {'a'})
```

Args:
fields: The fields to use for the typed dict
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model
strict: Whether the typed dict is strict
extra_validator: The extra validator to use for the typed dict
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
extra_behavior: The extra behavior to use for the typed dict
populate_by_name: Whether the typed dict should populate by name
from_attributes: Whether the typed dict should be populated from attributes
serialization: Custom serialization schema
"""
return dict_not_none(
type='model-fields',
fields=fields,
computed_fields=computed_fields,
strict=strict,
extra_validator=extra_validator,
extra_behavior=extra_behavior,
populate_by_name=populate_by_name,
from_attributes=from_attributes,
ref=ref,
metadata=metadata,
Expand Down Expand Up @@ -2768,14 +2872,13 @@ def model_schema(
from pydantic_core import CoreConfig, SchemaValidator, core_schema

class MyModel:
__slots__ = '__dict__', '__pydantic_fields_set__'
__slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__'

schema = core_schema.model_schema(
cls=MyModel,
config=CoreConfig(str_max_length=5),
schema=core_schema.typed_dict_schema(
fields={'a': core_schema.typed_dict_field(core_schema.str_schema())},
return_fields_set=True,
schema=core_schema.model_fields_schema(
fields={'a': core_schema.model_field(core_schema.str_schema())},
),
)
v = SchemaValidator(schema)
Expand Down Expand Up @@ -3236,16 +3339,15 @@ def json_schema(
```py
from pydantic_core import SchemaValidator, core_schema

dict_schema = core_schema.typed_dict_schema(
dict_schema = core_schema.model_fields_schema(
{
'field_a': core_schema.typed_dict_field(core_schema.str_schema()),
'field_b': core_schema.typed_dict_field(core_schema.bool_schema()),
'field_a': core_schema.model_field(core_schema.str_schema()),
'field_b': core_schema.model_field(core_schema.bool_schema()),
},
return_fields_set=True,
)

class MyModel:
__slots__ = '__dict__', '__pydantic_fields_set__'
__slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__'
field_a: str
field_b: bool

Expand Down Expand Up @@ -3497,6 +3599,7 @@ def definition_reference_schema(
ChainSchema,
LaxOrStrictSchema,
TypedDictSchema,
ModelFieldsSchema,
ModelSchema,
DataclassArgsSchema,
DataclassSchema,
Expand Down Expand Up @@ -3548,6 +3651,7 @@ def definition_reference_schema(
'chain',
'lax-or-strict',
'typed-dict',
'model-fields',
'model',
'dataclass-args',
'dataclass',
Expand Down
2 changes: 1 addition & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_dict()
}

fn validate_typed_dict(&'a self, strict: bool, _from_attributes: bool) -> ValResult<GenericMapping<'a>> {
fn validate_model_fields(&'a self, strict: bool, _from_attributes: bool) -> ValResult<GenericMapping<'a>> {
self.validate_dict(strict)
}

Expand Down
4 changes: 2 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ impl<'a> Input<'a> for PyAny {
}
}

fn validate_typed_dict(&'a self, strict: bool, from_attributes: bool) -> ValResult<GenericMapping<'a>> {
fn validate_model_fields(&'a self, strict: bool, from_attributes: bool) -> ValResult<GenericMapping<'a>> {
if from_attributes {
// if from_attributes, first try a dict, then mapping then from_attributes
if let Ok(dict) = self.downcast::<PyDict>() {
Expand All @@ -378,7 +378,7 @@ impl<'a> Input<'a> for PyAny {
Err(ValError::new(ErrorType::DictAttributesType, self))
}
} else {
// otherwise we just call back to lax_dict if from_mapping is allowed, not there error in this
// otherwise we just call back to validate_dict if from_mapping is allowed, note that errors in this
// case (correctly) won't hint about from_attributes
self.validate_dict(strict)
}
Expand Down
1 change: 1 addition & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ combined_serializer! {
super::type_serializers::function::FunctionAfterSerializerBuilder;
super::type_serializers::function::FunctionPlainSerializerBuilder;
super::type_serializers::function::FunctionWrapSerializerBuilder;
super::type_serializers::model::ModelFieldsBuilder;
}
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
// `find_serializer` so they can be used via a `type` str.
Expand Down
57 changes: 55 additions & 2 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,72 @@ use std::borrow::Cow;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyType};
use pyo3::types::{PyDict, PyString, PyType};

use ahash::AHashMap;

use crate::build_context::BuildContext;
use crate::build_tools::SchemaDict;
use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict};
use crate::serializers::computed_fields::ComputedFields;
use crate::serializers::extra::SerCheck;
use crate::serializers::filter::SchemaFilter;
use crate::serializers::infer::{infer_serialize, infer_to_python};
use crate::serializers::ob_type::ObType;
use crate::serializers::type_serializers::typed_dict::{TypedDictField, TypedDictSerializer};

use super::{
infer_json_key, infer_json_key_known, object_to_dict, py_err_se_err, BuildSerializer, CombinedSerializer, Extra,
TypeSerializer,
};

pub struct ModelFieldsBuilder;

impl BuildSerializer for ModelFieldsBuilder {
const EXPECTED_TYPE: &'static str = "model-fields";

fn build(
schema: &PyDict,
config: Option<&PyDict>,
build_context: &mut BuildContext<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();

let include_extra = matches!(
ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?,
ExtraBehavior::Allow
);

let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: AHashMap<String, TypedDictField> = AHashMap::with_capacity(fields_dict.len());
let mut exclude: Vec<Py<PyString>> = Vec::with_capacity(fields_dict.len());

for (key, value) in fields_dict.iter() {
let key_py: &PyString = key.downcast()?;
let key: String = key_py.extract()?;
let field_info: &PyDict = value.downcast()?;

let key_py: Py<PyString> = key_py.into_py(py);

if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
exclude.push(key_py.clone_ref(py));
} else {
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;

let schema = field_info.get_as_req(intern!(py, "schema"))?;
let serializer = CombinedSerializer::build(schema, config, build_context)
.map_err(|e| py_error_type!("Field `{}`:\n {}", key, e))?;

fields.insert(key, TypedDictField::new(py, key_py, alias, serializer, true));
}
}

let filter = SchemaFilter::from_vec_hash(py, exclude)?;
let computed_fields = ComputedFields::new(schema)?;

Ok(TypedDictSerializer::new(fields, include_extra, filter, computed_fields).into())
}
}

#[derive(Debug, Clone)]
pub struct ModelSerializer {
class: Py<PyType>,
Expand Down
3 changes: 3 additions & 0 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ mod lax_or_strict;
mod list;
mod literal;
mod model;
mod model_fields;
mod none;
mod nullable;
mod set;
Expand Down Expand Up @@ -378,6 +379,7 @@ pub fn build_validator<'a>(
nullable::NullableValidator,
// model classes
model::ModelValidator,
model_fields::ModelFieldsValidator,
// dataclasses
dataclass::DataclassArgsValidator,
dataclass::DataclassValidator,
Expand Down Expand Up @@ -505,6 +507,7 @@ pub enum CombinedValidator {
Nullable(nullable::NullableValidator),
// create new model classes
Model(model::ModelValidator),
ModelFields(model_fields::ModelFieldsValidator),
// dataclasses
DataclassArgs(dataclass::DataclassArgsValidator),
Dataclass(dataclass::DataclassValidator),
Expand Down
Loading