Skip to content

Add support for dataclass fields init #1163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2985,6 +2985,7 @@ class DataclassField(TypedDict, total=False):
name: Required[str]
schema: Required[CoreSchema]
kw_only: bool # default: True
init: bool # default: True
init_only: bool # default: False
frozen: bool # default: False
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
Expand All @@ -2998,6 +2999,7 @@ def dataclass_field(
schema: CoreSchema,
*,
kw_only: bool | None = None,
init: bool | None = None,
init_only: bool | None = None,
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
serialization_alias: str | None = None,
Expand All @@ -3023,6 +3025,7 @@ def dataclass_field(
name: The name to use for the argument parameter
schema: The schema to use for the argument parameter
kw_only: Whether the field can be set with a positional argument as well as a keyword argument
init: Whether the field should be validated during initialization
init_only: Whether the field should be omitted from `__dict__` and passed to `__post_init__`
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
Expand All @@ -3035,6 +3038,7 @@ def dataclass_field(
name=name,
schema=schema,
kw_only=kw_only,
init=init,
init_only=init_only,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
Expand Down
4 changes: 4 additions & 0 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct Field {
kw_only: bool,
name: String,
py_name: Py<PyString>,
init: bool,
init_only: bool,
lookup_key: LookupKey,
validator: CombinedValidator,
Expand Down Expand Up @@ -107,6 +108,7 @@ impl BuildValidator for DataclassArgsValidator {
py_name: py_name.into(),
lookup_key,
validator,
init: field.get_as(intern!(py, "init"))?.unwrap_or(true),
init_only: field.get_as(intern!(py, "init_only"))?.unwrap_or(false),
frozen: field.get_as::<bool>(intern!(py, "frozen"))?.unwrap_or(false),
});
Expand Down Expand Up @@ -176,6 +178,8 @@ impl Validator for DataclassArgsValidator {
($args:ident, $get_method:ident, $get_macro:ident, $slice_macro:ident) => {{
// go through fields getting the value from args or kwargs and validating it
for (index, field) in self.fields.iter().enumerate() {
if (!field.init) { continue };

let mut pos_value = None;
if let Some(args) = $args.args {
if !field.kw_only {
Expand Down
61 changes: 61 additions & 0 deletions tests/validators/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,3 +1592,64 @@ def _wrap_validator(cls, v, validator, info):
gc.collect()

assert ref() is None


@pytest.mark.parametrize(
'input_value,extra_behavior,expected',
[
({'a': 'hello', 'b': 'bye'}, 'ignore', {'a': 'hello', 'b': 'HELLO'}),
({'a': 'hello'}, 'ignore', {'a': 'hello', 'b': 'HELLO'}),
({'a': 'hello', 'b': 'bye'}, 'allow', {'a': 'hello', 'b': 'HELLO'}),
({'a': 'hello'}, 'allow', {'a': 'hello', 'b': 'HELLO'}),
(
{'a': 'hello', 'b': 'bye'},
'forbid',
Err(
'Unexpected keyword argument',
errors=[
{
'type': 'unexpected_keyword_argument',
'loc': ('b',),
'msg': 'Unexpected keyword argument',
'input': 'bye',
}
],
),
),
({'a': 'hello'}, 'forbid', {'a': 'hello', 'b': 'HELLO'}),
],
)
def test_dataclass_args_init(input_value, extra_behavior, expected):
@dataclasses.dataclass
class Foo:
a: str
b: str

def __post_init__(self):
self.b = self.a.upper()

schema = core_schema.dataclass_schema(
Foo,
core_schema.dataclass_args_schema(
'Foo',
[
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
core_schema.dataclass_field(name='b', schema=core_schema.str_schema(), init=False),
],
extra_behavior=extra_behavior,
),
['a', 'b'],
post_init=True,
)

v = SchemaValidator(schema)

if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)) as exc_info:
v.validate_python(input_value)

# debug(exc_info.value.errors(include_url=False))
if expected.errors is not None:
assert exc_info.value.errors(include_url=False) == expected.errors
else:
assert dataclasses.asdict(v.validate_python(input_value)) == expected