Skip to content

Commit 93af8cc

Browse files
authored
Fix arguments to dataclass validation functions (#563)
1 parent b23b38a commit 93af8cc

File tree

2 files changed

+107
-3
lines changed

2 files changed

+107
-3
lines changed

src/validators/dataclass.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,11 @@ impl Validator for DataclassArgsValidator {
323323

324324
let ok = |output: PyObject| {
325325
dict.set_item(field_name, output)?;
326-
Ok(dict.to_object(py))
326+
// The second return value represents `init_only_args`
327+
// which doesn't make much sense in this context but we need to put something there
328+
// so that function validators that sit between DataclassValidator and DataclassArgsValidator
329+
// always get called the same shape of data.
330+
Ok(PyTuple::new(py, vec![dict.to_object(py), py.None()]).into_py(py))
327331
};
328332

329333
if let Some(field) = self.fields.iter().find(|f| f.name == field_name) {
@@ -511,10 +515,14 @@ impl Validator for DataclassValidator {
511515
let new_dict = dict.copy()?;
512516
new_dict.set_item(field_name, field_value)?;
513517

514-
let dc_dict =
518+
// Discard the second return value, which is `init_only_args` but is always
519+
// None anyway for validate_assignment; see validate_assignment in DataclassArgsValidator
520+
let val_assignment_result =
515521
self.validator
516522
.validate_assignment(py, new_dict, field_name, field_value, extra, slots, recursion_guard)?;
517523

524+
let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?;
525+
518526
force_setattr(py, obj, dict_py_str, dc_dict)?;
519527

520528
Ok(obj.to_object(py))

tests/validators/test_dataclasses.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import dataclasses
22
import re
3-
from typing import Any, Dict, Union
3+
from typing import Any, Dict, List, Union
44

55
import pytest
66
from dirty_equals import IsListOrTuple, IsStr
@@ -1039,3 +1039,99 @@ class MyModel:
10391039

10401040
v.validate_assignment(m, 'not_f', '123')
10411041
assert getattr(m, 'not_f') == '123'
1042+
1043+
1044+
def test_function_validator_wrapping_args_schema_after() -> None:
1045+
calls: List[Any] = []
1046+
1047+
def func(*args: Any) -> Any:
1048+
calls.append(args)
1049+
return args[0]
1050+
1051+
@dataclasses.dataclass
1052+
class Model:
1053+
number: int = 1
1054+
1055+
cs = core_schema.dataclass_schema(
1056+
Model,
1057+
core_schema.no_info_after_validator_function(
1058+
func,
1059+
core_schema.dataclass_args_schema(
1060+
'Model', [core_schema.dataclass_field('number', core_schema.int_schema())]
1061+
),
1062+
),
1063+
)
1064+
1065+
v = SchemaValidator(cs)
1066+
1067+
instance: Model = v.validate_python({'number': 1})
1068+
assert instance.number == 1
1069+
assert calls == [(({'number': 1}, None),)]
1070+
v.validate_assignment(instance, 'number', 2)
1071+
assert instance.number == 2
1072+
assert calls == [(({'number': 1}, None),), (({'number': 2}, None),)]
1073+
1074+
1075+
def test_function_validator_wrapping_args_schema_before() -> None:
1076+
calls: List[Any] = []
1077+
1078+
def func(*args: Any) -> Any:
1079+
calls.append(args)
1080+
return args[0]
1081+
1082+
@dataclasses.dataclass
1083+
class Model:
1084+
number: int = 1
1085+
1086+
cs = core_schema.dataclass_schema(
1087+
Model,
1088+
core_schema.no_info_before_validator_function(
1089+
func,
1090+
core_schema.dataclass_args_schema(
1091+
'Model', [core_schema.dataclass_field('number', core_schema.int_schema())]
1092+
),
1093+
),
1094+
)
1095+
1096+
v = SchemaValidator(cs)
1097+
1098+
instance: Model = v.validate_python({'number': 1})
1099+
assert instance.number == 1
1100+
assert calls == [({'number': 1},)]
1101+
v.validate_assignment(instance, 'number', 2)
1102+
assert instance.number == 2
1103+
assert calls == [({'number': 1},), ({'number': 2},)]
1104+
1105+
1106+
def test_function_validator_wrapping_args_schema_wrap() -> None:
1107+
calls: List[Any] = []
1108+
1109+
def func(*args: Any) -> Any:
1110+
assert len(args) == 2
1111+
input, handler = args
1112+
output = handler(input)
1113+
calls.append((input, output))
1114+
return output
1115+
1116+
@dataclasses.dataclass
1117+
class Model:
1118+
number: int = 1
1119+
1120+
cs = core_schema.dataclass_schema(
1121+
Model,
1122+
core_schema.no_info_wrap_validator_function(
1123+
func,
1124+
core_schema.dataclass_args_schema(
1125+
'Model', [core_schema.dataclass_field('number', core_schema.int_schema())]
1126+
),
1127+
),
1128+
)
1129+
1130+
v = SchemaValidator(cs)
1131+
1132+
instance: Model = v.validate_python({'number': 1})
1133+
assert instance.number == 1
1134+
assert calls == [({'number': 1}, ({'number': 1}, None))]
1135+
v.validate_assignment(instance, 'number', 2)
1136+
assert instance.number == 2
1137+
assert calls == [({'number': 1}, ({'number': 1}, None)), ({'number': 2}, ({'number': 2}, None))]

0 commit comments

Comments
 (0)