Skip to content

Commit 288dd1c

Browse files
Support revalidation of parametrized generics (#1489)
1 parent 1ced3e6 commit 288dd1c

File tree

3 files changed

+72
-8
lines changed

3 files changed

+72
-8
lines changed

python/pydantic_core/core_schema.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3056,6 +3056,7 @@ def model_fields_schema(
30563056
class ModelSchema(TypedDict, total=False):
30573057
type: Required[Literal['model']]
30583058
cls: Required[Type[Any]]
3059+
generic_origin: Type[Any]
30593060
schema: Required[CoreSchema]
30603061
custom_init: bool
30613062
root_model: bool
@@ -3074,6 +3075,7 @@ def model_schema(
30743075
cls: Type[Any],
30753076
schema: CoreSchema,
30763077
*,
3078+
generic_origin: Type[Any] | None = None,
30773079
custom_init: bool | None = None,
30783080
root_model: bool | None = None,
30793081
post_init: str | None = None,
@@ -3120,6 +3122,8 @@ class MyModel:
31203122
Args:
31213123
cls: The class to use for the model
31223124
schema: The schema to use for the model
3125+
generic_origin: The origin type used for this model, if it's a parametrized generic, e.g.
3126+
if this model schema represents `SomeModel[int]`, generic_origin is `SomeModel`
31233127
custom_init: Whether the model has a custom init method
31243128
root_model: Whether the model is a `RootModel`
31253129
post_init: The call after init to use for the model
@@ -3136,6 +3140,7 @@ class MyModel:
31363140
return _dict_not_none(
31373141
type='model',
31383142
cls=cls,
3143+
generic_origin=generic_origin,
31393144
schema=schema,
31403145
custom_init=custom_init,
31413146
root_model=root_model,
@@ -3289,6 +3294,7 @@ def dataclass_args_schema(
32893294
class DataclassSchema(TypedDict, total=False):
32903295
type: Required[Literal['dataclass']]
32913296
cls: Required[Type[Any]]
3297+
generic_origin: Type[Any]
32923298
schema: Required[CoreSchema]
32933299
fields: Required[List[str]]
32943300
cls_name: str
@@ -3308,6 +3314,7 @@ def dataclass_schema(
33083314
schema: CoreSchema,
33093315
fields: List[str],
33103316
*,
3317+
generic_origin: Type[Any] | None = None,
33113318
cls_name: str | None = None,
33123319
post_init: bool | None = None,
33133320
revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None,
@@ -3328,6 +3335,8 @@ def dataclass_schema(
33283335
schema: The schema to use for the dataclass fields
33293336
fields: Fields of the dataclass, this is used in serialization and in validation during re-validation
33303337
and while validating assignment
3338+
generic_origin: The origin type used for this dataclass, if it's a parametrized generic, e.g.
3339+
if this dataclass schema represents `SomeDataclass[int]`, generic_origin is `SomeDataclass`
33313340
cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
33323341
post_init: Whether to call `__post_init__` after validation
33333342
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
@@ -3343,6 +3352,7 @@ def dataclass_schema(
33433352
return _dict_not_none(
33443353
type='dataclass',
33453354
cls=cls,
3355+
generic_origin=generic_origin,
33463356
fields=fields,
33473357
cls_name=cls_name,
33483358
schema=schema,

src/validators/dataclass.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ pub struct DataclassValidator {
438438
strict: bool,
439439
validator: Box<CombinedValidator>,
440440
class: Py<PyType>,
441+
generic_origin: Option<Py<PyType>>,
441442
fields: Vec<Py<PyString>>,
442443
post_init: Option<Py<PyString>>,
443444
revalidate: Revalidate,
@@ -461,6 +462,7 @@ impl BuildValidator for DataclassValidator {
461462
let config = config.as_ref();
462463

463464
let class = schema.get_as_req::<Bound<'_, PyType>>(intern!(py, "cls"))?;
465+
let generic_origin: Option<Bound<'_, PyType>> = schema.get_as(intern!(py, "generic_origin"))?;
464466
let name = match schema.get_as_req::<String>(intern!(py, "cls_name")) {
465467
Ok(name) => name,
466468
Err(_) => class.getattr(intern!(py, "__name__"))?.extract()?,
@@ -480,6 +482,7 @@ impl BuildValidator for DataclassValidator {
480482
strict: is_strict(schema, config)?,
481483
validator: Box::new(validator),
482484
class: class.into(),
485+
generic_origin: generic_origin.map(std::convert::Into::into),
483486
fields,
484487
post_init,
485488
revalidate: Revalidate::from_str(
@@ -496,7 +499,11 @@ impl BuildValidator for DataclassValidator {
496499
}
497500
}
498501

499-
impl_py_gc_traverse!(DataclassValidator { class, validator });
502+
impl_py_gc_traverse!(DataclassValidator {
503+
class,
504+
generic_origin,
505+
validator
506+
});
500507

501508
impl Validator for DataclassValidator {
502509
fn validate<'py>(
@@ -510,10 +517,30 @@ impl Validator for DataclassValidator {
510517
return self.validate_init(py, self_instance, input, state);
511518
}
512519

513-
// same logic as on models
520+
// same logic as on models, see more explicit comment in model.rs
514521
let class = self.class.bind(py);
515-
if let Some(py_input) = input_as_python_instance(input, class) {
516-
if self.revalidate.should_revalidate(py_input, class) {
522+
let generic_origin_class = self.generic_origin.as_ref().map(|go| go.bind(py));
523+
524+
let (py_instance_input, force_revalidate): (Option<&Bound<'_, PyAny>>, bool) =
525+
match input_as_python_instance(input, class) {
526+
Some(x) => (Some(x), false),
527+
None => {
528+
// if the dataclass has a generic origin, we also accept inputs that are instances of the generic origin rather than of the class itself,
529+
// because checks like isinstance(SomeDataclass[int](...), SomeDataclass[Any]) fail even though such inputs are valid; in that case we must
530+
// make sure the data is revalidated, hence we set force_revalidate to true
531+
if generic_origin_class.is_some() {
532+
match input_as_python_instance(input, generic_origin_class.unwrap()) {
533+
Some(x) => (Some(x), true),
534+
None => (None, false),
535+
}
536+
} else {
537+
(None, false)
538+
}
539+
}
540+
};
541+
542+
if let Some(py_input) = py_instance_input {
543+
if self.revalidate.should_revalidate(py_input, class) || force_revalidate {
517544
let input_dict = self.dataclass_to_dict(py_input)?;
518545
let val_output = self.validator.validate(py, input_dict.as_any(), state)?;
519546
let dc = create_class(self.class.bind(py))?;

src/validators/model.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pub struct ModelValidator {
5555
revalidate: Revalidate,
5656
validator: Box<CombinedValidator>,
5757
class: Py<PyType>,
58+
generic_origin: Option<Py<PyType>>,
5859
post_init: Option<Py<PyString>>,
5960
frozen: bool,
6061
custom_init: bool,
@@ -76,6 +77,7 @@ impl BuildValidator for ModelValidator {
7677
let config = schema.get_as(intern!(py, "config"))?;
7778

7879
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
80+
let generic_origin: Option<Bound<'_, PyType>> = schema.get_as(intern!(py, "generic_origin"))?;
7981
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
8082
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
8183
let name = class.getattr(intern!(py, "__name__"))?.extract()?;
@@ -93,6 +95,7 @@ impl BuildValidator for ModelValidator {
9395
)?,
9496
validator: Box::new(validator),
9597
class: class.into(),
98+
generic_origin: generic_origin.map(std::convert::Into::into),
9699
post_init: schema.get_as(intern!(py, "post_init"))?,
97100
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
98101
custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false),
@@ -105,7 +108,11 @@ impl BuildValidator for ModelValidator {
105108
}
106109
}
107110

108-
impl_py_gc_traverse!(ModelValidator { class, validator });
111+
impl_py_gc_traverse!(ModelValidator {
112+
class,
113+
generic_origin,
114+
validator
115+
});
109116

110117
impl Validator for ModelValidator {
111118
fn validate<'py>(
@@ -119,13 +126,33 @@ impl Validator for ModelValidator {
119126
return self.validate_init(py, self_instance, input, state);
120127
}
121128

129+
let class = self.class.bind(py);
130+
let generic_origin_class = self.generic_origin.as_ref().map(|go| go.bind(py));
131+
122132
// if we're in strict mode, we require an exact instance of the class (from python, with JSON an object is ok)
123133
// if we're not in strict mode, instances of subclasses are okay, as well as dicts, mappings, from attributes etc.
124134
// if the input is an instance of the class, we "revalidate" it - e.g. we extract and reuse `__pydantic_fields_set__`
125135
// but use from attributes to create a new instance of the model field type
126-
let class = self.class.bind(py);
127-
if let Some(py_input) = input_as_python_instance(input, class) {
128-
if self.revalidate.should_revalidate(py_input, class) {
136+
let (py_instance_input, force_revalidate): (Option<&Bound<'_, PyAny>>, bool) =
137+
match input_as_python_instance(input, class) {
138+
Some(x) => (Some(x), false),
139+
None => {
140+
// if the model has a generic origin, we also accept inputs that are instances of the generic origin rather than of the class itself,
141+
// because checks like isinstance(SomeModel[int](...), SomeModel[Any]) fail even though such inputs are valid; in that case we must
142+
// make sure the data is revalidated, hence we set force_revalidate to true
143+
if generic_origin_class.is_some() {
144+
match input_as_python_instance(input, generic_origin_class.unwrap()) {
145+
Some(x) => (Some(x), true),
146+
None => (None, false),
147+
}
148+
} else {
149+
(None, false)
150+
}
151+
}
152+
};
153+
154+
if let Some(py_input) = py_instance_input {
155+
if self.revalidate.should_revalidate(py_input, class) || force_revalidate {
129156
let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?;
130157
if self.root_model {
131158
let inner_input = py_input.getattr(intern!(py, ROOT_FIELD))?;

0 commit comments

Comments
 (0)