Skip to content

Commit 1081443

Browse files
committed
Fix for validation and revalidation
1 parent d2623da commit 1081443

File tree

2 files changed

+118
-8
lines changed

2 files changed

+118
-8
lines changed

src/validators/dataclass.rs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ impl BuildValidator for DataclassValidator {
454454
)?)?,
455455
name,
456456
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
457-
slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)),
457+
slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)) | schema.get_as(intern!(py, "slots"))?.unwrap_or(false),
458458
}
459459
.into())
460460
}
@@ -478,10 +478,19 @@ impl Validator for DataclassValidator {
478478
let class = self.class.as_ref(py);
479479
if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? {
480480
if self.revalidate.should_revalidate(input, class) {
481-
let input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?;
481+
let mut validator_input = PyDict::new(py);
482+
if self.slots {
483+
let slots = input.input_get_attr(intern!(py, "__slots__")).unwrap()?.downcast::<PyTuple>()?;
484+
for key in slots.iter() {
485+
let key: &PyString = key.downcast()?;
486+
validator_input.set_item(key, input.input_get_attr(key).unwrap()?)?;
487+
}
488+
} else {
489+
validator_input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?.downcast::<PyDict>()?;
490+
}
482491
let val_output = self
483492
.validator
484-
.validate(py, input, extra, definitions, recursion_guard)?;
493+
.validate(py, validator_input.downcast::<PyAny>()?, extra, definitions, recursion_guard)?;
485494
let dc = create_class(self.class.as_ref(py))?;
486495
self.set_dict_call(py, dc.as_ref(py), val_output, input)?;
487496
Ok(dc)
@@ -518,8 +527,19 @@ impl Validator for DataclassValidator {
518527
if self.frozen {
519528
return Err(ValError::new(ErrorType::FrozenInstance, field_value));
520529
}
530+
531+
let mut dict = PyDict::new(py);
521532
let dict_py_str = intern!(py, "__dict__");
522-
let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?;
533+
534+
if self.slots {
535+
let slots = obj.input_get_attr(intern!(py, "__slots__")).unwrap()?.downcast::<PyTuple>()?;
536+
for key in slots.iter() {
537+
let key: &PyString = key.downcast()?;
538+
dict.set_item(key, obj.input_get_attr(key).unwrap()?)?;
539+
}
540+
} else {
541+
dict = obj.getattr(dict_py_str)?.downcast()?;
542+
}
523543

524544
let new_dict = dict.copy()?;
525545
new_dict.set_item(field_name, field_value)?;
@@ -538,7 +558,11 @@ impl Validator for DataclassValidator {
538558

539559
let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?;
540560

541-
force_setattr(py, obj, dict_py_str, dc_dict)?;
561+
if self.slots {
562+
force_setattr(py, obj, field_name, field_value)?;
563+
} else {
564+
force_setattr(py, obj, dict_py_str, dc_dict)?;
565+
}
542566

543567
Ok(obj.to_object(py))
544568
}

tests/validators/test_dataclasses.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,9 +1195,7 @@ def test_custom_dataclass_names():
11951195

11961196
@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10')
11971197
def test_slots() -> None:
1198-
kwargs = {'slots': True}
1199-
1200-
@dataclasses.dataclass(**kwargs)
1198+
@dataclasses.dataclass(slots=True)
12011199
class Model:
12021200
x: int
12031201

@@ -1290,3 +1288,91 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str:
12901288
v = SchemaValidator(schema)
12911289
foo = v.validate_python({'a': 1, 'b': b'hello'})
12921290
assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'}
1291+
1292+
1293+
@dataclasses.dataclass(slots=True)
1294+
class FooDataclassSlots:
1295+
a: str
1296+
b: bool
1297+
1298+
1299+
@dataclasses.dataclass(slots=True)
1300+
class FooDataclassSameSlots(FooDataclassSlots):
1301+
pass
1302+
1303+
1304+
@dataclasses.dataclass(slots=True)
1305+
class FooDataclassMoreSlots(FooDataclassSlots):
1306+
c: str
1307+
1308+
1309+
@dataclasses.dataclass(slots=True)
1310+
class DuplicateDifferentSlots:
1311+
a: str
1312+
b: bool
1313+
1314+
1315+
@pytest.mark.parametrize(
1316+
'revalidate_instances,input_value,expected',
1317+
[
1318+
('always', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}),
1319+
('always', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}),
1320+
('always', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}),
1321+
('always', FooDataclassMoreSlots(a='hello', b=True, c='more'), Err(r'c\s+Unexpected keyword argument')),
1322+
(
1323+
'always',
1324+
DuplicateDifferentSlots(a='hello', b=True),
1325+
Err('should be a dictionary or an instance of FooDataclass'),
1326+
),
1327+
# revalidate_instances='subclass-instances'
1328+
('subclass-instances', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}),
1329+
('subclass-instances', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}),
1330+
('subclass-instances', FooDataclassSlots(a=b'hello', b='true'), {'a': b'hello', 'b': 'true'}),
1331+
('subclass-instances', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}),
1332+
('subclass-instances', FooDataclassSameSlots(a=b'hello', b='true'), {'a': 'hello', 'b': True}),
1333+
('subclass-instances', FooDataclassMoreSlots(a='hello', b=True, c='more'), Err('Unexpected keyword argument')),
1334+
(
1335+
'subclass-instances',
1336+
DuplicateDifferentSlots(a='hello', b=True),
1337+
Err('dictionary or an instance of FooDataclass'),
1338+
),
1339+
# revalidate_instances='never'
1340+
('never', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}),
1341+
('never', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}),
1342+
('never', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}),
1343+
('never', FooDataclassMoreSlots(a='hello', b=True, c='more'), {'a': 'hello', 'b': True, 'c': 'more'}),
1344+
('never', FooDataclassMoreSlots(a='hello', b='wrong', c='more'), {'a': 'hello', 'b': 'wrong', 'c': 'more'}),
1345+
(
1346+
'never',
1347+
DuplicateDifferentSlots(a='hello', b=True),
1348+
Err('should be a dictionary or an instance of FooDataclass'),
1349+
),
1350+
],
1351+
)
1352+
def test_slots_dataclass_subclass(revalidate_instances, input_value, expected):
1353+
schema = core_schema.dataclass_schema(
1354+
FooDataclassSlots,
1355+
core_schema.dataclass_args_schema(
1356+
'FooDataclass',
1357+
[
1358+
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
1359+
core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()),
1360+
],
1361+
extra_behavior='forbid',
1362+
),
1363+
revalidate_instances=revalidate_instances,
1364+
slots=True,
1365+
)
1366+
v = SchemaValidator(schema)
1367+
1368+
if isinstance(expected, Err):
1369+
with pytest.raises(ValidationError, match=expected.message) as exc_info:
1370+
print(v.validate_python(input_value))
1371+
1372+
# debug(exc_info.value.errors(include_url=False))
1373+
if expected.errors is not None:
1374+
assert exc_info.value.errors(include_url=False) == expected.errors
1375+
else:
1376+
dc = v.validate_python(input_value)
1377+
assert dataclasses.is_dataclass(dc)
1378+
assert dataclasses.asdict(dc) == expected

0 commit comments

Comments
 (0)