Skip to content

Commit 3f96b09

Browse files
authored
Set info.data and info.field_name for dataclass field validators (#454)
1 parent 653308f commit 3f96b09

File tree

2 files changed

+176
-2
lines changed

2 files changed

+176
-2
lines changed

src/validators/dataclass.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ impl Validator for DataclassArgsValidator {
130130
let mut errors: Vec<ValLineError> = Vec::new();
131131
let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len());
132132

133+
let extra = Extra {
134+
data: Some(output_dict),
135+
..*extra
136+
};
137+
133138
macro_rules! set_item {
134139
($field:ident, $value:expr) => {{
135140
let py_name = $field.py_name.as_ref(py);
@@ -147,6 +152,10 @@ impl Validator for DataclassArgsValidator {
147152
($args:ident, $get_method:ident, $get_macro:ident, $slice_macro:ident) => {{
148153
// go through fields getting the value from args or kwargs and validating it
149154
for (index, field) in self.fields.iter().enumerate() {
155+
let extra = Extra {
156+
field_name: Some(&field.name),
157+
..extra
158+
};
150159
let mut pos_value = None;
151160
if let Some(args) = $args.args {
152161
if !field.kw_only {
@@ -175,7 +184,7 @@ impl Validator for DataclassArgsValidator {
175184
(Some(pos_value), None) => {
176185
match field
177186
.validator
178-
.validate(py, pos_value, extra, slots, recursion_guard)
187+
.validate(py, pos_value, &extra, slots, recursion_guard)
179188
{
180189
Ok(value) => set_item!(field, value),
181190
Err(ValError::LineErrors(line_errors)) => {
@@ -192,7 +201,7 @@ impl Validator for DataclassArgsValidator {
192201
(None, Some(kw_value)) => {
193202
match field
194203
.validator
195-
.validate(py, kw_value, extra, slots, recursion_guard)
204+
.validate(py, kw_value, &extra, slots, recursion_guard)
196205
{
197206
Ok(value) => set_item!(field, value),
198207
Err(ValError::LineErrors(line_errors)) => {

tests/validators/test_dataclasses.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,168 @@ def test_dataclass_exact_validation(input_value, expected):
421421
v = SchemaValidator(schema)
422422
foo = v.validate_python(input_value)
423423
assert dataclasses.asdict(foo) == expected
424+
425+
426+
def test_dataclass_field_after_validator():
427+
@dataclasses.dataclass
428+
class Foo:
429+
a: int
430+
b: str
431+
432+
@classmethod
433+
def validate_b(cls, v: str, info: core_schema.ModelFieldValidationInfo) -> str:
434+
assert v == 'hello'
435+
assert info.field_name == 'b'
436+
assert info.data == {'a': 1}
437+
return 'hello world!'
438+
439+
schema = core_schema.dataclass_schema(
440+
Foo,
441+
core_schema.dataclass_args_schema(
442+
'Foo',
443+
[
444+
core_schema.dataclass_field(name='a', schema=core_schema.int_schema()),
445+
core_schema.dataclass_field(
446+
name='b',
447+
schema=core_schema.field_after_validation_function(Foo.validate_b, core_schema.str_schema()),
448+
),
449+
],
450+
),
451+
)
452+
453+
v = SchemaValidator(schema)
454+
foo = v.validate_python({'a': 1, 'b': b'hello'})
455+
assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'}
456+
457+
458+
def test_dataclass_field_plain_validator():
459+
@dataclasses.dataclass
460+
class Foo:
461+
a: int
462+
b: str
463+
464+
@classmethod
465+
def validate_b(cls, v: bytes, info: core_schema.ModelFieldValidationInfo) -> str:
466+
assert v == b'hello'
467+
assert info.field_name == 'b'
468+
assert info.data == {'a': 1}
469+
return 'hello world!'
470+
471+
schema = core_schema.dataclass_schema(
472+
Foo,
473+
core_schema.dataclass_args_schema(
474+
'Foo',
475+
[
476+
core_schema.dataclass_field(name='a', schema=core_schema.int_schema()),
477+
core_schema.dataclass_field(
478+
name='b', schema=core_schema.field_plain_validation_function(Foo.validate_b)
479+
),
480+
],
481+
),
482+
)
483+
484+
v = SchemaValidator(schema)
485+
foo = v.validate_python({'a': 1, 'b': b'hello'})
486+
assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'}
487+
488+
489+
def test_dataclass_field_before_validator():
490+
@dataclasses.dataclass
491+
class Foo:
492+
a: int
493+
b: str
494+
495+
@classmethod
496+
def validate_b(cls, v: bytes, info: core_schema.ModelFieldValidationInfo) -> bytes:
497+
assert v == b'hello'
498+
assert info.field_name == 'b'
499+
assert info.data == {'a': 1}
500+
return b'hello world!'
501+
502+
schema = core_schema.dataclass_schema(
503+
Foo,
504+
core_schema.dataclass_args_schema(
505+
'Foo',
506+
[
507+
core_schema.dataclass_field(name='a', schema=core_schema.int_schema()),
508+
core_schema.dataclass_field(
509+
name='b',
510+
schema=core_schema.field_before_validation_function(Foo.validate_b, core_schema.str_schema()),
511+
),
512+
],
513+
),
514+
)
515+
516+
v = SchemaValidator(schema)
517+
foo = v.validate_python({'a': 1, 'b': b'hello'})
518+
assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'}
519+
520+
521+
def test_dataclass_field_wrap_validator1():
522+
@dataclasses.dataclass
523+
class Foo:
524+
a: int
525+
b: str
526+
527+
@classmethod
528+
def validate_b(
529+
cls, v: bytes, nxt: core_schema.CallableValidator, info: core_schema.ModelFieldValidationInfo
530+
) -> str:
531+
assert v == b'hello'
532+
v = nxt(v)
533+
assert v == 'hello'
534+
assert info.field_name == 'b'
535+
assert info.data == {'a': 1}
536+
return 'hello world!'
537+
538+
schema = core_schema.dataclass_schema(
539+
Foo,
540+
core_schema.dataclass_args_schema(
541+
'Foo',
542+
[
543+
core_schema.dataclass_field(name='a', schema=core_schema.int_schema()),
544+
core_schema.dataclass_field(
545+
name='b',
546+
schema=core_schema.field_wrap_validation_function(Foo.validate_b, core_schema.str_schema()),
547+
),
548+
],
549+
),
550+
)
551+
552+
v = SchemaValidator(schema)
553+
foo = v.validate_python({'a': 1, 'b': b'hello'})
554+
assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'}
555+
556+
557+
def test_dataclass_field_wrap_validator2():
558+
@dataclasses.dataclass
559+
class Foo:
560+
a: int
561+
b: str
562+
563+
@classmethod
564+
def validate_b(
565+
cls, v: bytes, nxt: core_schema.CallableValidator, info: core_schema.ModelFieldValidationInfo
566+
) -> bytes:
567+
assert v == b'hello'
568+
assert info.field_name == 'b'
569+
assert info.data == {'a': 1}
570+
return nxt(b'hello world!')
571+
572+
schema = core_schema.dataclass_schema(
573+
Foo,
574+
core_schema.dataclass_args_schema(
575+
'Foo',
576+
[
577+
core_schema.dataclass_field(name='a', schema=core_schema.int_schema()),
578+
core_schema.dataclass_field(
579+
name='b',
580+
schema=core_schema.field_wrap_validation_function(Foo.validate_b, core_schema.str_schema()),
581+
),
582+
],
583+
),
584+
)
585+
586+
v = SchemaValidator(schema)
587+
foo = v.validate_python({'a': 1, 'b': b'hello'})
588+
assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'}

0 commit comments

Comments
 (0)