Skip to content

Commit f2368a4

Browse files
committed
Add custom_init logic
This reverts commit 4c4fe8a.
1 parent 4f1ce33 commit f2368a4

File tree

6 files changed

+172
-3
lines changed

6 files changed

+172
-3
lines changed

pydantic_core/core_schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2737,6 +2737,7 @@ class ModelSchema(TypedDict, total=False):
27372737
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
27382738
strict: bool
27392739
frozen: bool
2740+
custom_init: bool
27402741
config: CoreConfig
27412742
ref: str
27422743
metadata: Any
@@ -2751,6 +2752,7 @@ def model_schema(
27512752
revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None,
27522753
strict: bool | None = None,
27532754
frozen: bool | None = None,
2755+
custom_init: bool | None = None,
27542756
config: CoreConfig | None = None,
27552757
ref: str | None = None,
27562758
metadata: Any = None,
@@ -2791,6 +2793,7 @@ class MyModel:
27912793
should re-validate defaults to config.revalidate_instances, else 'never'
27922794
strict: Whether the model is strict
27932795
frozen: Whether the model is frozen
2796+
custom_init: Whether the model has a custom init method
27942797
config: The config to use for the model
27952798
ref: optional unique identifier of the schema, used to reference the schema in other places
27962799
metadata: Any other information you want to include with the schema, not used by pydantic-core
@@ -2804,6 +2807,7 @@ class MyModel:
28042807
revalidate_instances=revalidate_instances,
28052808
strict=strict,
28062809
frozen=frozen,
2810+
custom_init=custom_init,
28072811
config=config,
28082812
ref=ref,
28092813
metadata=metadata,

src/argument_markers.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,30 @@ impl ArgsKwargs {
6060
}
6161
}
6262
}
63+
64+
pub(crate) const VALIDATED_DATA_KEY: &str = "validated_data";
65+
66+
#[pyclass(module = "pydantic_core._pydantic_core", frozen, get_all, freelist = 100)]
67+
#[derive(Debug, Clone)]
68+
pub struct ValidatedData {
69+
pub model_dict: PyObject,
70+
pub fields_set: PyObject,
71+
}
72+
73+
impl ValidatedData {
74+
pub(crate) fn new(model_dict: &PyAny, fields_set: &PyAny) -> Self {
75+
Self {
76+
model_dict: model_dict.to_object(model_dict.py()),
77+
fields_set: fields_set.to_object(model_dict.py()),
78+
}
79+
}
80+
}
81+
82+
#[pymethods]
83+
impl ValidatedData {
84+
fn __repr__(&self, py: Python) -> String {
85+
let model_dict = safe_repr(self.model_dict.as_ref(py));
86+
let fields_set = safe_repr(self.fields_set.as_ref(py));
87+
format!("ValidatedData(model_dict={model_dict}, fields_set={fields_set})")
88+
}
89+
}

src/input/input_abstract.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::fmt;
33
use pyo3::prelude::*;
44
use pyo3::types::{PyString, PyType};
55

6+
use crate::argument_markers::ValidatedData;
67
use crate::errors::{InputValue, LocItem, ValResult};
78
use crate::{PyMultiHostUrl, PyUrl};
89

@@ -44,6 +45,11 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
4445
None
4546
}
4647

48+
#[cfg_attr(has_no_coverage, no_coverage)]
49+
fn validated_data(&self) -> Option<ValidatedData> {
50+
None
51+
}
52+
4753
// input_ prefix to differentiate from the function on PyAny
4854
fn input_is_instance(&self, class: &PyAny, json_mask: u8) -> PyResult<bool>;
4955

src/input/input_python.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use pyo3::types::{
1111
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
1212
use pyo3::{ffi, intern, AsPyPointer, PyTypeInfo};
1313

14+
use crate::argument_markers::{ValidatedData, VALIDATED_DATA_KEY};
1415
use crate::build_tools::safe_repr;
1516
use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult};
1617
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};
@@ -102,6 +103,16 @@ impl<'a> Input<'a> for PyAny {
102103
Some(self.getattr(name))
103104
}
104105

106+
#[cfg_attr(has_no_coverage, no_coverage)]
107+
fn validated_data(&self) -> Option<ValidatedData> {
108+
if let Ok(v) = self.get_item(intern!(self.py(), VALIDATED_DATA_KEY)) {
109+
if let Ok(validated_data) = v.extract::<ValidatedData>() {
110+
return Some(validated_data);
111+
}
112+
}
113+
None
114+
}
115+
105116
fn input_is_instance(&self, class: &PyAny, _json_mask: u8) -> PyResult<bool> {
106117
// See PyO3/pyo3#2694 - we can't use `is_instance` here since it requires PyType,
107118
// and some check objects are not types, this logic is lifted from `is_instance` in PyO3

src/validators/model.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use pyo3::prelude::*;
77
use pyo3::types::{PyDict, PySet, PyString, PyTuple, PyType};
88
use pyo3::{ffi, intern};
99

10+
use crate::argument_markers::{ValidatedData, VALIDATED_DATA_KEY};
1011
use crate::build_tools::{py_err, schema_or_config_same, SchemaDict};
1112
use crate::errors::{ErrorType, ValError, ValResult};
1213
use crate::input::{py_error_on_minusone, Input};
@@ -51,6 +52,7 @@ pub struct ModelValidator {
5152
post_init: Option<Py<PyString>>,
5253
name: String,
5354
frozen: bool,
55+
custom_init: bool,
5456
}
5557

5658
impl BuildValidator for ModelValidator {
@@ -87,6 +89,7 @@ impl BuildValidator for ModelValidator {
8789
// which is not what we want here
8890
name: class.getattr(intern!(py, "__name__"))?.extract()?,
8991
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
92+
custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false),
9093
}
9194
.into())
9295
}
@@ -227,6 +230,18 @@ impl ModelValidator {
227230
..*extra
228231
};
229232

233+
if self.custom_init {
234+
if let Some(validated_data) = input.validated_data() {
235+
set_model_attrs(
236+
self_instance,
237+
validated_data.model_dict.as_ref(py),
238+
validated_data.fields_set.as_ref(py),
239+
)?;
240+
// we don't call post_init here, it'll be called by the original validator
241+
return Ok(self_instance.into_py(py));
242+
}
243+
}
244+
230245
let output = self.validator.validate(py, input, &new_extra, slots, recursion_guard)?;
231246
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;
232247
set_model_attrs(self_instance, model_dict, fields_set)?;
@@ -250,9 +265,16 @@ impl ModelValidator {
250265

251266
fn create_class(&self, model_dict: &PyAny, fields_set: &PyAny) -> PyResult<PyObject> {
252267
let py = model_dict.py();
253-
let instance = create_class(self.class.as_ref(py))?;
254-
set_model_attrs(instance.as_ref(py), model_dict, fields_set)?;
255-
Ok(instance)
268+
if self.custom_init {
269+
let kwargs = PyDict::new(py);
270+
let vd = ValidatedData::new(model_dict, fields_set);
271+
kwargs.set_item(intern!(py, VALIDATED_DATA_KEY), vd.into_py(py))?;
272+
self.class.call(py, (), Some(kwargs))
273+
} else {
274+
let instance = create_class(self.class.as_ref(py))?;
275+
set_model_attrs(instance.as_ref(py), model_dict, fields_set)?;
276+
Ok(instance)
277+
}
256278
}
257279
}
258280

tests/validators/test_model.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,3 +1161,102 @@ def f(values_or_values_and_fields_set: Any, *args: Any) -> Any:
11611161
assert m.b == 2
11621162
assert m.__pydantic_fields_set__ == {'a', 'b'}
11631163
assert calls == [call1, call2]
1164+
1165+
1166+
def test_custom_init():
1167+
calls = []
1168+
1169+
class Model:
1170+
__slots__ = '__dict__', '__pydantic_fields_set__'
1171+
1172+
def __init__(self, **kwargs):
1173+
validated_data = kwargs['validated_data']
1174+
self.a = validated_data.model_dict['a']
1175+
self.b = validated_data.model_dict['b']
1176+
self.__pydantic_fields_set__ = validated_data.fields_set
1177+
calls.append(repr(kwargs))
1178+
1179+
v = SchemaValidator(
1180+
core_schema.model_schema(
1181+
Model,
1182+
core_schema.typed_dict_schema(
1183+
{
1184+
'a': core_schema.typed_dict_field(
1185+
core_schema.with_default_schema(core_schema.int_schema(), default=1)
1186+
),
1187+
'b': core_schema.typed_dict_field(core_schema.int_schema()),
1188+
},
1189+
return_fields_set=True,
1190+
),
1191+
custom_init=True,
1192+
)
1193+
)
1194+
1195+
m = v.validate_python({'b': 2})
1196+
assert m.a == 1
1197+
assert m.b == 2
1198+
assert m.__pydantic_fields_set__ == {'b'}
1199+
assert calls == ["{'validated_data': ValidatedData(model_dict={'a': 1, 'b': 2}, fields_set={'b'})}"]
1200+
1201+
1202+
def test_custom_init_nested():
1203+
calls = []
1204+
1205+
class ModelInner:
1206+
__slots__ = '__dict__', '__pydantic_fields_set__'
1207+
a: int
1208+
b: int
1209+
1210+
def __init__(self, **data):
1211+
calls.append(f'inner: {data!r}')
1212+
self.__pydantic_validator__.validate_python(data, self_instance=self)
1213+
1214+
inner_schema = core_schema.model_schema(
1215+
ModelInner,
1216+
core_schema.typed_dict_schema(
1217+
{
1218+
'a': core_schema.typed_dict_field(core_schema.with_default_schema(core_schema.int_schema(), default=1)),
1219+
'b': core_schema.typed_dict_field(core_schema.int_schema()),
1220+
},
1221+
return_fields_set=True,
1222+
),
1223+
custom_init=True,
1224+
)
1225+
ModelInner.__pydantic_validator__ = SchemaValidator(inner_schema)
1226+
1227+
class ModelOuter:
1228+
__slots__ = '__dict__', '__pydantic_fields_set__'
1229+
a: int
1230+
b: ModelInner
1231+
1232+
def __init__(self, **data):
1233+
calls.append(f'outer: {data!r}')
1234+
self.__pydantic_validator__.validate_python(data, self_instance=self)
1235+
1236+
ModelOuter.__pydantic_validator__ = SchemaValidator(
1237+
core_schema.model_schema(
1238+
ModelOuter,
1239+
core_schema.typed_dict_schema(
1240+
{
1241+
'a': core_schema.typed_dict_field(
1242+
core_schema.with_default_schema(core_schema.int_schema(), default=1)
1243+
),
1244+
'b': core_schema.typed_dict_field(inner_schema),
1245+
},
1246+
return_fields_set=True,
1247+
),
1248+
custom_init=True,
1249+
)
1250+
)
1251+
1252+
m = ModelOuter(a=2, b={'b': 3})
1253+
assert m.__pydantic_fields_set__ == {'a', 'b'}
1254+
assert m.a == 2
1255+
assert isinstance(m.b, ModelInner)
1256+
assert m.b.a == 1
1257+
assert m.b.b == 3
1258+
# insert_assert(calls)
1259+
assert calls == [
1260+
"outer: {'a': 2, 'b': {'b': 3}}",
1261+
"inner: {'validated_data': ValidatedData(model_dict={'a': 1, 'b': 3}, fields_set={'b'})}",
1262+
]

0 commit comments

Comments
 (0)