Skip to content

Commit b863be7

Browse files
authored
RootModel (#592)
1 parent 493643e commit b863be7

File tree

5 files changed

+296
-55
lines changed

5 files changed

+296
-55
lines changed

pydantic_core/core_schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2838,6 +2838,7 @@ class ModelSchema(TypedDict, total=False):
28382838
cls: Required[Type[Any]]
28392839
schema: Required[CoreSchema]
28402840
custom_init: bool
2841+
root_model: bool
28412842
post_init: str
28422843
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
28432844
strict: bool
@@ -2854,6 +2855,7 @@ def model_schema(
28542855
schema: CoreSchema,
28552856
*,
28562857
custom_init: bool | None = None,
2858+
root_model: bool | None = None,
28572859
post_init: str | None = None,
28582860
revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None,
28592861
strict: bool | None = None,
@@ -2894,6 +2896,7 @@ class MyModel:
28942896
cls: The class to use for the model
28952897
schema: The schema to use for the model
28962898
custom_init: Whether the model has a custom init method
2899+
root_model: Whether the model is a `RootModel`
28972900
post_init: The call after init to use for the model
28982901
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
28992902
should re-validate defaults to config.revalidate_instances, else 'never'
@@ -2910,6 +2913,7 @@ class MyModel:
29102913
cls=cls,
29112914
schema=schema,
29122915
custom_init=custom_init,
2916+
root_model=root_model,
29132917
post_init=post_init,
29142918
revalidate_instances=revalidate_instances,
29152919
strict=strict,

src/validators/model.rs

Lines changed: 99 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::recursion_guard::RecursionGuard;
1414
use super::function::convert_err;
1515
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
1616

17+
const ROOT_FIELD: &str = "root";
1718
const DUNDER_DICT: &str = "__dict__";
1819
const DUNDER_FIELDS_SET_KEY: &str = "__pydantic_fields_set__";
1920
const DUNDER_MODEL_EXTRA_KEY: &str = "__pydantic_extra__";
@@ -52,9 +53,10 @@ pub struct ModelValidator {
5253
validator: Box<CombinedValidator>,
5354
class: Py<PyType>,
5455
post_init: Option<Py<PyString>>,
55-
name: String,
5656
frozen: bool,
5757
custom_init: bool,
58+
root_model: bool,
59+
name: String,
5860
}
5961

6062
impl BuildValidator for ModelValidator {
@@ -87,11 +89,12 @@ impl BuildValidator for ModelValidator {
8789
post_init: schema
8890
.get_as::<&str>(intern!(py, "post_init"))?
8991
.map(|s| PyString::intern(py, s).into_py(py)),
92+
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
93+
custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false),
94+
root_model: schema.get_as(intern!(py, "root_model"))?.unwrap_or(false),
9095
// Get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
9196
// which is not what we want here
9297
name: class.getattr(intern!(py, "__name__"))?.extract()?,
93-
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
94-
custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false),
9598
}
9699
.into())
97100
}
@@ -125,28 +128,24 @@ impl Validator for ModelValidator {
125128
// mask 0 so JSON is input is never true here
126129
if input.input_is_instance(class, 0)? {
127130
if self.revalidate.should_revalidate(input, class) {
128-
let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?;
129-
130-
// get dict here so from_attributes logic doesn't apply
131-
let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?;
132-
let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?;
133-
134-
let full_model_dict: &PyAny = if model_extra.is_none() {
135-
dict
131+
if self.root_model {
132+
let inner_input: &PyAny = input.input_get_attr(intern!(py, ROOT_FIELD)).unwrap()?;
133+
self.validate_construct(py, inner_input, None, extra, definitions, recursion_guard)
136134
} else {
137-
let full_model_dict = dict.downcast::<PyDict>()?.copy()?;
138-
full_model_dict.update(model_extra.downcast()?)?;
139-
full_model_dict
140-
};
141-
142-
let output = self
143-
.validator
144-
.validate(py, full_model_dict, extra, definitions, recursion_guard)?;
145-
146-
let (model_dict, model_extra, _): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
147-
let instance = self.create_class(model_dict, model_extra, fields_set)?;
148-
149-
self.call_post_init(py, instance, input, extra)
135+
let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?;
136+
// get dict here so from_attributes logic doesn't apply
137+
let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?;
138+
let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?;
139+
140+
let inner_input: &PyAny = if model_extra.is_none() {
141+
dict
142+
} else {
143+
let full_model_dict = dict.downcast::<PyDict>()?.copy()?;
144+
full_model_dict.update(model_extra.downcast()?)?;
145+
full_model_dict
146+
};
147+
self.validate_construct(py, inner_input, Some(fields_set), extra, definitions, recursion_guard)
148+
}
150149
} else {
151150
Ok(input.to_object(py))
152151
}
@@ -158,22 +157,7 @@ impl Validator for ModelValidator {
158157
input,
159158
))
160159
} else {
161-
if self.custom_init {
162-
// If we wanted, we could introspect the __init__ signature, and store the
163-
// keyword arguments and types, and create a validator for them.
164-
// Perhaps something similar to `validate_call`? Could probably make
165-
// this work with from_attributes, and would essentially allow you to
166-
// handle init vars by adding them to the __init__ signature.
167-
if let Some(kwargs) = input.as_kwargs(py) {
168-
return Ok(self.class.call(py, (), Some(kwargs))?);
169-
}
170-
}
171-
let output = self
172-
.validator
173-
.validate(py, input, extra, definitions, recursion_guard)?;
174-
let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
175-
let instance = self.create_class(model_dict, model_extra, fields_set)?;
176-
self.call_post_init(py, instance, input, extra)
160+
self.validate_construct(py, input, None, extra, definitions, recursion_guard)
177161
}
178162
}
179163

@@ -189,9 +173,29 @@ impl Validator for ModelValidator {
189173
) -> ValResult<'data, PyObject> {
190174
if self.frozen {
191175
return Err(ValError::new(ErrorType::FrozenInstance, field_value));
176+
} else if self.root_model {
177+
return if field_name != ROOT_FIELD {
178+
Err(ValError::new_with_loc(
179+
ErrorType::NoSuchAttribute {
180+
attribute: field_name.to_string(),
181+
},
182+
field_value,
183+
field_name.to_string(),
184+
))
185+
} else {
186+
let field_extra = Extra {
187+
field_name: Some(field_name),
188+
..*extra
189+
};
190+
let output = self
191+
.validator
192+
.validate(py, field_value, &field_extra, definitions, recursion_guard)?;
193+
194+
force_setattr(py, model, intern!(py, ROOT_FIELD), output)?;
195+
Ok(model.into_py(py))
196+
};
192197
}
193-
let dict_py_str = intern!(py, DUNDER_DICT);
194-
let dict: &PyDict = model.getattr(dict_py_str)?.downcast()?;
198+
let dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?;
195199

196200
let new_dict = dict.copy()?;
197201
new_dict.set_item(field_name, field_value)?;
@@ -216,7 +220,7 @@ impl Validator for ModelValidator {
216220
}
217221
let output = output.to_object(py);
218222

219-
force_setattr(py, model, dict_py_str, output)?;
223+
force_setattr(py, model, intern!(py, DUNDER_DICT), output)?;
220224
Ok(model.into_py(py))
221225
}
222226

@@ -262,11 +266,61 @@ impl ModelValidator {
262266
let output = self
263267
.validator
264268
.validate(py, input, &new_extra, definitions, recursion_guard)?;
265-
let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
266-
set_model_attrs(self_instance, model_dict, model_extra, fields_set)?;
269+
270+
if self.root_model {
271+
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), output.as_ref(py))?;
272+
} else {
273+
let (model_dict, model_extra, fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
274+
set_model_attrs(self_instance, model_dict, model_extra, fields_set)?;
275+
}
267276
self.call_post_init(py, self_instance.into_py(py), input, extra)
268277
}
269278

279+
fn validate_construct<'s, 'data>(
280+
&'s self,
281+
py: Python<'data>,
282+
input: &'data impl Input<'data>,
283+
existing_fields_set: Option<&'data PyAny>,
284+
extra: &Extra,
285+
definitions: &'data Definitions<CombinedValidator>,
286+
recursion_guard: &'s mut RecursionGuard,
287+
) -> ValResult<'data, PyObject> {
288+
if self.custom_init {
289+
// If we wanted, we could introspect the __init__ signature, and store the
290+
// keyword arguments and types, and create a validator for them.
291+
// Perhaps something similar to `validate_call`? Could probably make
292+
// this work with from_attributes, and would essentially allow you to
293+
// handle init vars by adding them to the __init__ signature.
294+
if let Some(kwargs) = input.as_kwargs(py) {
295+
return Ok(self.class.call(py, (), Some(kwargs))?);
296+
}
297+
}
298+
299+
let output = if self.root_model {
300+
let field_extra = Extra {
301+
field_name: Some(ROOT_FIELD),
302+
..*extra
303+
};
304+
self.validator
305+
.validate(py, input, &field_extra, definitions, recursion_guard)?
306+
} else {
307+
self.validator
308+
.validate(py, input, extra, definitions, recursion_guard)?
309+
};
310+
311+
let instance = create_class(self.class.as_ref(py))?;
312+
let instance_ref = instance.as_ref(py);
313+
314+
if self.root_model {
315+
force_setattr(py, instance_ref, intern!(py, ROOT_FIELD), output)?;
316+
} else {
317+
let (model_dict, model_extra, val_fields_set): (&PyAny, &PyAny, &PyAny) = output.extract(py)?;
318+
let fields_set = existing_fields_set.unwrap_or(val_fields_set);
319+
set_model_attrs(instance_ref, model_dict, model_extra, fields_set)?;
320+
}
321+
self.call_post_init(py, instance, input, extra)
322+
}
323+
270324
fn call_post_init<'s, 'data>(
271325
&'s self,
272326
py: Python<'data>,
@@ -281,13 +335,6 @@ impl ModelValidator {
281335
}
282336
Ok(instance)
283337
}
284-
285-
fn create_class(&self, model_dict: &PyAny, model_extra: &PyAny, fields_set: &PyAny) -> PyResult<PyObject> {
286-
let py = model_dict.py();
287-
let instance = create_class(self.class.as_ref(py))?;
288-
set_model_attrs(instance.as_ref(py), model_dict, model_extra, fields_set)?;
289-
Ok(instance)
290-
}
291338
}
292339

293340
/// based on the following but with the second argument of new_func set to an empty tuple as required

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,3 +1313,29 @@ def test_validate_literal(
13131313
assert res == expected_val_res
13141314

13151315
benchmark(validator.validate_json, input_json)
1316+
1317+
1318+
@pytest.mark.benchmark(group='root_model')
1319+
def test_core_root_model(benchmark):
1320+
class MyModel:
1321+
__slots__ = 'root'
1322+
root: List[int]
1323+
1324+
v = SchemaValidator(
1325+
core_schema.model_schema(MyModel, core_schema.list_schema(core_schema.int_schema()), root_model=True)
1326+
)
1327+
assert v.validate_python([1, 2, '3']).root == [1, 2, 3]
1328+
input_data = list(range(100))
1329+
benchmark(v.validate_python, input_data)
1330+
1331+
1332+
@skip_pydantic
1333+
@pytest.mark.benchmark(group='root_model')
1334+
def test_v1_root_model(benchmark):
1335+
class MyModel(BaseModel):
1336+
__root__: List[int]
1337+
1338+
assert MyModel.parse_obj([1, 2, '3']).__root__ == [1, 2, 3]
1339+
input_data = list(range(100))
1340+
1341+
benchmark(MyModel.parse_obj, input_data)

tests/validators/test_model_init.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def test_model_init():
3333
m2 = MyModel()
3434
ans = v.validate_python({'field_a': 'test', 'field_b': 12}, self_instance=m2)
3535
assert ans == m2
36-
assert m.field_a == 'test'
37-
assert m.field_b == 12
38-
assert m.__pydantic_fields_set__ == {'field_a', 'field_b'}
36+
assert ans.field_a == 'test'
37+
assert ans.field_b == 12
38+
assert ans.__pydantic_fields_set__ == {'field_a', 'field_b'}
3939

4040

4141
def test_model_init_nested():
@@ -381,3 +381,40 @@ def __init__(self, **data):
381381
('inner', {'a': 1, 'b': 3}, {'b', 'z'}, {'z': 1}),
382382
('outer', {'a': 2, 'b': IsInstance(ModelInner)}, {'c', 'a', 'b'}, {'c': 1}),
383383
]
384+
385+
386+
def test_model_custom_init_revalidate():
387+
calls = []
388+
389+
class Model:
390+
__slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__'
391+
392+
def __init__(self, **kwargs):
393+
calls.append(repr(kwargs))
394+
self.__dict__.update(kwargs)
395+
self.__pydantic_fields_set__ = {'custom'}
396+
self.__pydantic_extra__ = None
397+
398+
v = SchemaValidator(
399+
core_schema.model_schema(
400+
Model,
401+
core_schema.model_fields_schema({'a': core_schema.model_field(core_schema.int_schema())}),
402+
custom_init=True,
403+
config=dict(revalidate_instances='always'),
404+
)
405+
)
406+
407+
m = v.validate_python({'a': '1'})
408+
assert isinstance(m, Model)
409+
assert m.a == '1'
410+
assert m.__pydantic_fields_set__ == {'custom'}
411+
assert calls == ["{'a': '1'}"]
412+
m.x = 4
413+
414+
m2 = v.validate_python(m)
415+
assert m2 is not m
416+
assert isinstance(m2, Model)
417+
assert m2.a == '1'
418+
assert m2.__dict__ == {'a': '1', 'x': 4}
419+
assert m2.__pydantic_fields_set__ == {'custom'}
420+
assert calls == ["{'a': '1'}", "{'a': '1', 'x': 4}"]

0 commit comments

Comments
 (0)