Skip to content

Commit 3f7c010

Browse files
authored
Make validating assignment work properly with allowed extra (#766)
1 parent f5b804b commit 3f7c010

File tree

6 files changed

+62
-22
lines changed

6 files changed

+62
-22
lines changed

python/pydantic_core/_pydantic_core.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ class SchemaValidator:
102102
strict: bool | None = None,
103103
from_attributes: bool | None = None,
104104
context: 'dict[str, Any] | None' = None,
105-
) -> dict[str, Any]: ...
105+
) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]:
106+
"""
107+
ModelValidator and ModelFieldsValidator will return a tuple of (fields data, extra data, fields set)
108+
"""
106109
def get_default_value(self, *, strict: bool | None = None, context: Any = None) -> Some | None: ...
107110

108111
_IncEx: TypeAlias = set[int] | set[str] | dict[int, _IncEx] | dict[str, _IncEx] | None

src/validators/model.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,32 +183,41 @@ impl Validator for ModelValidator {
183183
Ok(model.into_py(py))
184184
};
185185
}
186-
let dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?;
186+
let old_dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?;
187187

188-
let new_dict = dict.copy()?;
189-
new_dict.set_item(field_name, field_value)?;
188+
let input_dict = old_dict.copy()?;
189+
let old_extra: Option<&PyDict> = model.getattr(intern!(py, DUNDER_MODEL_EXTRA_KEY))?.downcast().ok();
190+
if let Some(old_extra) = old_extra {
191+
input_dict.update(old_extra.as_mapping())?;
192+
}
193+
input_dict.set_item(field_name, field_value)?;
190194

191195
let output = self.validator.validate_assignment(
192196
py,
193-
new_dict,
197+
input_dict,
194198
field_name,
195199
field_value,
196200
extra,
197201
definitions,
198202
recursion_guard,
199203
)?;
200204

201-
let (output, _, updated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?;
205+
let (validated_dict, validated_extra, validated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?;
202206

203207
if let Ok(fields_set) = model.getattr(intern!(py, DUNDER_FIELDS_SET_KEY)) {
204208
let fields_set: &PySet = fields_set.downcast()?;
205-
for field_name in updated_fields_set {
209+
for field_name in validated_fields_set {
206210
fields_set.add(field_name)?;
207211
}
208212
}
209-
let output = output.to_object(py);
210213

211-
force_setattr(py, model, intern!(py, DUNDER_DICT), output)?;
214+
force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
215+
force_setattr(
216+
py,
217+
model,
218+
intern!(py, DUNDER_MODEL_EXTRA_KEY),
219+
validated_extra.to_object(py),
220+
)?;
212221
Ok(model.into_py(py))
213222
}
214223

src/validators/model_fields.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,13 @@ impl Validator for ModelFieldsValidator {
302302
) -> ValResult<'data, PyObject> {
303303
let dict: &PyDict = obj.downcast()?;
304304

305-
let ok = |output: PyObject| {
305+
let get_updated_dict = |output: PyObject| {
306306
dict.set_item(field_name, output)?;
307-
Ok(dict.to_object(py))
307+
Ok(dict)
308308
};
309309

310310
let prepare_result = |result: ValResult<'data, PyObject>| match result {
311-
Ok(output) => ok(output),
311+
Ok(output) => get_updated_dict(output),
312312
Err(ValError::LineErrors(line_errors)) => {
313313
let errors = line_errors
314314
.into_iter()
@@ -358,7 +358,7 @@ impl Validator for ModelFieldsValidator {
358358
Some(ref validator) => {
359359
prepare_result(validator.validate(py, field_value, &extra, definitions, recursion_guard))
360360
}
361-
None => ok(field_value.to_object(py)),
361+
None => get_updated_dict(field_value.to_object(py)),
362362
},
363363
ExtraBehavior::Forbid | ExtraBehavior::Ignore => {
364364
return Err(ValError::new_with_loc(
@@ -372,8 +372,24 @@ impl Validator for ModelFieldsValidator {
372372
}
373373
}?;
374374

375+
let new_extra = match &self.extra_behavior {
376+
ExtraBehavior::Allow => {
377+
let non_extra_data = PyDict::new(py);
378+
self.fields.iter().for_each(|f| {
379+
let popped_value = PyAny::get_item(new_data, &f.name).unwrap();
380+
new_data.del_item(&f.name).unwrap();
381+
non_extra_data.set_item(&f.name, popped_value).unwrap();
382+
});
383+
let new_extra = new_data.copy()?;
384+
new_data.clear();
385+
new_data.update(non_extra_data.as_mapping())?;
386+
new_extra.to_object(py)
387+
}
388+
_ => py.None(),
389+
};
390+
375391
let fields_set: &PySet = PySet::new(py, &[field_name.to_string()])?;
376-
Ok((new_data, py.None(), fields_set.to_object(py)).to_object(py))
392+
Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py))
377393
}
378394

379395
fn different_strict_behavior(

tests/validators/test_function.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,9 @@ class Model:
461461
__slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__'
462462
field_a: str
463463

464+
def __init__(self):
465+
self.__pydantic_extra__ = None # this attribute must be present for validate_assignment
466+
464467
v = SchemaValidator(
465468
core_schema.no_info_after_validator_function(
466469
f,
@@ -474,6 +477,7 @@ class Model:
474477
assert m.field_a == 'test'
475478
assert m.__pydantic_fields_set__ == {'field_a'}
476479
assert m.__dict__ == {'field_a': 'test', 'more': 'foobar'}
480+
assert m.__pydantic_extra__ is None
477481

478482
m2 = Model()
479483
m2.field_a = 'test'

tests/validators/test_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,9 @@ class MyModel:
943943
field_a: str
944944
field_b: int
945945

946+
def __init__(self):
947+
self.__pydantic_extra__ = None
948+
946949
v = SchemaValidator(
947950
{
948951
'type': 'model',
@@ -1019,7 +1022,10 @@ def func(x, info):
10191022

10201023
def test_validate_assignment_no_fields_set():
10211024
class MyModel:
1022-
__slots__ = ('__dict__',)
1025+
__slots__ = ('__dict__', '__pydantic_extra__')
1026+
1027+
def __init__(self):
1028+
self.__pydantic_extra__ = None
10231029

10241030
v = SchemaValidator(
10251031
{

tests/validators/test_model_fields.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ def test_validate_assignment_allow_extra():
347347
assert v.validate_python({'field_a': 'test'}) == ({'field_a': 'test'}, {}, {'field_a'})
348348

349349
assert v.validate_assignment({'field_a': 'test'}, 'other_field', 456) == (
350-
{'field_a': 'test', 'other_field': 456},
351-
None,
350+
{'field_a': 'test'},
351+
{'other_field': 456},
352352
{'other_field'},
353353
)
354354

@@ -364,8 +364,8 @@ def test_validate_assignment_allow_extra_validate():
364364
)
365365

366366
assert v.validate_assignment({'field_a': 'test'}, 'other_field', '456') == (
367-
{'field_a': 'test', 'other_field': 456},
368-
None,
367+
{'field_a': 'test'},
368+
{'other_field': 456},
369369
{'other_field'},
370370
)
371371

@@ -1682,10 +1682,12 @@ def test_extra_behavior_allow(
16821682
assert fields_set == {'f', 'extra_field'}
16831683

16841684
v.validate_assignment(m, 'f', 'y')
1685-
assert m['f'] == 'y'
1685+
assert m == {'f': 'y'}
16861686

1687-
v.validate_assignment(m, 'not_f', '123')
1688-
assert m['not_f'] == expected_extra_value
1687+
new_m, new_model_extra, new_fields_set = v.validate_assignment({**m, **model_extra}, 'not_f', '123')
1688+
assert new_m == {'f': 'y'}
1689+
assert new_model_extra == {'extra_field': expected_extra_value, 'not_f': expected_extra_value}
1690+
assert new_fields_set == {'not_f'}
16891691

16901692

16911693
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)