Skip to content

Add ability to specify dataclass name in dataclass_schema #603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3111,6 +3111,7 @@ class DataclassSchema(TypedDict, total=False):
type: Required[Literal['dataclass']]
cls: Required[Type[Any]]
schema: Required[CoreSchema]
cls_name: str
post_init: bool # default: False
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
strict: bool # default: False
Expand All @@ -3124,6 +3125,7 @@ def dataclass_schema(
cls: Type[Any],
schema: CoreSchema,
*,
cls_name: str | None = None,
post_init: bool | None = None,
revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None,
strict: bool | None = None,
Expand All @@ -3139,6 +3141,7 @@ def dataclass_schema(
Args:
cls: The dataclass type, used to to perform subclass checks
schema: The schema to use for the dataclass fields
cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
post_init: Whether to call `__post_init__` after validation
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
should re-validate defaults to config.revalidate_instances, else 'never'
Expand All @@ -3151,6 +3154,7 @@ def dataclass_schema(
return dict_not_none(
type='dataclass',
cls=cls,
cls_name=cls_name,
schema=schema,
post_init=post_init,
revalidate_instances=revalidate_instances,
Expand Down
8 changes: 5 additions & 3 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ impl BuildValidator for DataclassValidator {
let py = schema.py();

let class: &PyType = schema.get_as_req(intern!(py, "cls"))?;
let name = match schema.get_as_req::<String>(intern!(py, "cls_name")) {
Ok(name) => name,
Err(_) => class.getattr(intern!(py, "__name__"))?.extract()?,
};
let sub_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?;
let validator = build_validator(sub_schema, config, definitions)?;

Expand All @@ -447,9 +451,7 @@ impl BuildValidator for DataclassValidator {
config,
intern!(py, "revalidate_instances"),
)?)?,
// as with model, get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
// which is not what we want here
name: class.getattr(intern!(py, "__name__"))?.extract()?,
name,
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
}
.into())
Expand Down
55 changes: 54 additions & 1 deletion tests/validators/test_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import re
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import pytest
from dirty_equals import IsListOrTuple, IsStr
Expand Down Expand Up @@ -1137,3 +1137,56 @@ class Model:
v.validate_assignment(instance, 'number', 2)
assert instance.number == 2
assert calls == [({'number': 1}, ({'number': 1}, None)), ({'number': 2}, ({'number': 2}, None))]


@dataclasses.dataclass
class FooParentDataclass:
foo: Optional[FooDataclass]


def test_custom_dataclass_names():
# Note: normally you would use the same values for DataclassArgsSchema.dataclass_name and DataclassSchema.cls_name,
# but I have purposely made them different here to show which parts of the errors are affected by which.
# I have used square brackets in the names to hint that the most likely reason for using a value different from
# cls.__name__ is for use with generic types.
schema = core_schema.dataclass_schema(
FooParentDataclass,
core_schema.dataclass_args_schema(
'FooParentDataclass',
[
core_schema.dataclass_field(
name='foo',
schema=core_schema.union_schema(
[
core_schema.dataclass_schema(
FooDataclass,
core_schema.dataclass_args_schema(
'FooDataclass[dataclass_args_schema]',
[
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()),
],
),
cls_name='FooDataclass[cls_name]',
),
core_schema.none_schema(),
]
),
)
],
),
)

v = SchemaValidator(schema)
with pytest.raises(ValidationError) as exc_info:
v.validate_python({'foo': 123})
assert exc_info.value.errors(include_url=False) == [
{
'ctx': {'dataclass_name': 'FooDataclass[dataclass_args_schema]'},
'input': 123,
'loc': ('foo', 'FooDataclass[cls_name]'),
'msg': 'Input should be a dictionary or an instance of FooDataclass[dataclass_args_schema]',
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the value in ctx and msg comes from the dataclass_args_schema, whereas the value in loc comes from the dataclass_schema.

'type': 'dataclass_type',
},
{'input': 123, 'loc': ('foo', 'none'), 'msg': 'Input should be None', 'type': 'none_required'},
]