Skip to content

Commit 9381862

Browse files
authored
Use __pydantic_serializer__ for serializing to JSON (#557)
* Use __pydantic_serializer__ for serializing to JSON * Make to_json(by_alias=False) work * Use presence of __pydantic_serializer__ instead of anything else to go down model serialization path
1 parent 4f1ce33 commit 9381862

File tree

7 files changed

+106
-30
lines changed

7 files changed

+106
-30
lines changed

pydantic_core/core_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ def simple_ser_schema(type: ExpectedSerializationTypes) -> SimpleSerSchema:
223223
'timedelta',
224224
'url',
225225
'multi_host_url',
226+
'pydantic_serializable',
226227
'dataclass',
227-
'model',
228228
'enum',
229229
'path',
230230
]

src/serializers/infer.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
1313
use crate::build_tools::{py_err, safe_repr};
1414
use crate::serializers::errors::SERIALIZATION_ERR_MARKER;
1515
use crate::serializers::filter::SchemaFilter;
16+
use crate::serializers::shared::PydanticSerializer;
1617
use crate::serializers::{shared::TypeSerializer, SchemaSerializer};
1718
use crate::url::{PyMultiHostUrl, PyUrl};
1819

@@ -98,8 +99,8 @@ pub(crate) fn infer_to_python_known(
9899
return serializer.serializer.to_python(value, include, exclude, extra);
99100
}
100101
}
101-
// Fallback to dict serialization if `__pydantic_serializer__` is not set.else
102-
// This is currently only relevant to non-pydantic dataclasses.
102+
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
103+
// This currently only affects non-pydantic dataclasses.
103104
let dict = object_to_dict(value, is_model, extra)?;
104105
serialize_dict(dict)
105106
};
@@ -173,8 +174,8 @@ pub(crate) fn infer_to_python_known(
173174
let py_url: PyMultiHostUrl = value.extract()?;
174175
py_url.__str__().into_py(py)
175176
}
177+
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
176178
ObType::Dataclass => serialize_with_serializer(value, false)?,
177-
ObType::Model => serialize_with_serializer(value, true)?,
178179
ObType::Enum => {
179180
let v = value.getattr(intern!(py, "value"))?;
180181
infer_to_python(v, include, exclude, extra)?.into_py(py)
@@ -239,8 +240,8 @@ pub(crate) fn infer_to_python_known(
239240
}
240241
new_dict.into_py(py)
241242
}
242-
ObType::Dataclass => serialize_dict(object_to_dict(value, false, extra)?)?,
243-
ObType::Model => serialize_dict(object_to_dict(value, true, extra)?)?,
243+
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
244+
ObType::Dataclass => serialize_with_serializer(value, false)?,
244245
ObType::Generator => {
245246
let iter = super::type_serializers::generator::SerializationIterator::new(
246247
value.downcast()?,
@@ -388,6 +389,22 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
388389
}};
389390
}
390391

392+
macro_rules! serialize_with_serializer {
393+
($py_serializable:expr, $is_model:expr) => {{
394+
if let Ok(py_serializer) = value.getattr(intern!($py_serializable.py(), "__pydantic_serializer__")) {
395+
if let Ok(extracted_serializer) = py_serializer.extract::<SchemaSerializer>() {
396+
let pydantic_serializer =
397+
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, extra);
398+
return pydantic_serializer.serialize(serializer);
399+
}
400+
}
401+
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
402+
// This currently only affects non-pydantic dataclasses.
403+
let dict = object_to_dict(value, $is_model, extra).map_err(py_err_se_err)?;
404+
serialize_dict!(dict)
405+
}};
406+
}
407+
391408
let ser_result = match ob_type {
392409
ObType::None => serializer.serialize_none(),
393410
ObType::Int | ObType::IntSubclass => serialize!(i64),
@@ -442,8 +459,8 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
442459
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
443460
serializer.serialize_str(&py_url.__str__())
444461
}
445-
ObType::Dataclass => serialize_dict!(object_to_dict(value, false, extra).map_err(py_err_se_err)?),
446-
ObType::Model => serialize_dict!(object_to_dict(value, true, extra).map_err(py_err_se_err)?),
462+
ObType::Dataclass => serialize_with_serializer!(value, false),
463+
ObType::PydanticSerializable => serialize_with_serializer!(value, true),
447464
ObType::Enum => {
448465
let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
449466
infer_serialize(v, serializer, include, exclude, extra)
@@ -565,7 +582,7 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: &ObType, key: &'py PyAny, extra
565582
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
566583
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
567584
}
568-
ObType::Dataclass | ObType::Model => {
585+
ObType::Dataclass | ObType::PydanticSerializable => {
569586
// check that the instance is hashable
570587
key.hash()?;
571588
let key = key.str()?.to_string();

src/serializers/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,16 @@ impl SchemaSerializer {
170170

171171
#[allow(clippy::too_many_arguments)]
172172
#[pyfunction]
173-
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, exclude_none = false, round_trip = false,
174-
timedelta_mode = None, bytes_mode = None, serialize_unknown = false, fallback = None))]
173+
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
174+
exclude_none = false, round_trip = false, timedelta_mode = None, bytes_mode = None,
175+
serialize_unknown = false, fallback = None))]
175176
pub fn to_json(
176177
py: Python,
177178
value: &PyAny,
178179
indent: Option<usize>,
179180
include: Option<&PyAny>,
180181
exclude: Option<&PyAny>,
182+
by_alias: bool,
181183
exclude_none: bool,
182184
round_trip: bool,
183185
timedelta_mode: Option<&str>,
@@ -189,7 +191,7 @@ pub fn to_json(
189191
let extra = state.extra(
190192
py,
191193
&SerMode::Json,
192-
true,
194+
by_alias,
193195
exclude_none,
194196
round_trip,
195197
serialize_unknown,

src/serializers/ob_type.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ impl ObTypeLookup {
126126
ObType::Url => self.url == ob_type,
127127
ObType::MultiHostUrl => self.multi_host_url == ob_type,
128128
ObType::Dataclass => is_dataclass(op_value),
129-
ObType::Model => is_pydantic_model(op_value),
129+
ObType::PydanticSerializable => is_pydantic_serializable(op_value),
130130
ObType::Enum => self.enum_type == ob_type,
131131
ObType::Generator => self.generator == ob_type,
132132
ObType::Path => self.path == ob_type,
@@ -208,10 +208,10 @@ impl ObTypeLookup {
208208
ObType::Url
209209
} else if ob_type == self.multi_host_url {
210210
ObType::MultiHostUrl
211+
} else if is_pydantic_serializable(op_value) {
212+
ObType::PydanticSerializable
211213
} else if is_dataclass(op_value) {
212214
ObType::Dataclass
213-
} else if is_pydantic_model(op_value) {
214-
ObType::Model
215215
} else if self.is_enum(op_value, type_ptr) {
216216
ObType::Enum
217217
} else if ob_type == self.generator || is_generator(op_value) {
@@ -256,10 +256,10 @@ fn is_dataclass(op_value: Option<&PyAny>) -> bool {
256256
}
257257
}
258258

259-
fn is_pydantic_model(op_value: Option<&PyAny>) -> bool {
259+
fn is_pydantic_serializable(op_value: Option<&PyAny>) -> bool {
260260
if let Some(value) = op_value {
261261
value
262-
.hasattr(intern!(value.py(), "__pydantic_validator__"))
262+
.hasattr(intern!(value.py(), "__pydantic_serializer__"))
263263
.unwrap_or(false)
264264
} else {
265265
false
@@ -305,9 +305,10 @@ pub enum ObType {
305305
// types from this package
306306
Url,
307307
MultiHostUrl,
308-
// dataclasses and pydantic models
308+
// anything with __pydantic_serializer__, including BaseModel and pydantic dataclasses
309+
PydanticSerializable,
310+
// vanilla dataclasses
309311
Dataclass,
310-
Model,
311312
// enum type
312313
Enum,
313314
// generator type

src/serializers/type_serializers/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl TypeSerializer for ModelSerializer {
8383

8484
fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
8585
if self.allow_value(key, extra)? {
86-
infer_json_key_known(&ObType::Model, key, extra)
86+
infer_json_key_known(&ObType::PydanticSerializable, key, extra)
8787
} else {
8888
extra.warnings.on_fallback_py(&self.name, key, extra)?;
8989
infer_json_key(key, extra)

tests/serializers/test_any.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
import pytest
1212
from dirty_equals import HasRepr, IsList
1313

14-
from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema, to_json
14+
from pydantic_core import PydanticSerializationError, SchemaSerializer, SchemaValidator, core_schema, to_json
1515

1616
from ..conftest import plain_repr
17+
from .test_dataclasses import IsStrictDict, on_pypy
1718
from .test_list_tuple import as_list, as_tuple
1819

1920

@@ -35,7 +36,7 @@ class MyDataclass:
3536

3637

3738
class MyModel:
38-
__pydantic_validator__ = 42
39+
__pydantic_serializer__ = 42
3940

4041
def __init__(self, **kwargs):
4142
for key, value in kwargs.items():
@@ -242,7 +243,7 @@ def test_include_dict(any_serializer):
242243

243244

244245
class FieldsSetModel:
245-
__pydantic_validator__ = 42
246+
__pydantic_serializer__ = 42
246247
__slots__ = '__dict__', '__pydantic_fields_set__'
247248

248249
def __init__(self, **kwargs):
@@ -412,3 +413,63 @@ def test_encoding(any_serializer, gen_input, kwargs, expected_json):
412413
assert to_json(gen_input(), **kwargs) == expected_json
413414
if not kwargs:
414415
assert any_serializer.to_python(gen_input(), mode='json') == json.loads(expected_json)
416+
417+
418+
def test_any_dataclass():
419+
@dataclasses.dataclass
420+
class Foo:
421+
a: str
422+
b: bytes
423+
424+
# Build a schema that does not include the field 'b', to test that it is not serialized
425+
schema = core_schema.dataclass_schema(
426+
Foo,
427+
core_schema.dataclass_args_schema(
428+
'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())]
429+
),
430+
)
431+
Foo.__pydantic_serializer__ = SchemaSerializer(schema)
432+
433+
s = SchemaSerializer(core_schema.any_schema())
434+
assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello')
435+
assert s.to_python(Foo(a='hello', b=b'more'), mode='json') == IsStrictDict(a='hello')
436+
j = s.to_json(Foo(a='hello', b=b'more'))
437+
438+
if on_pypy:
439+
assert json.loads(j) == {'a': 'hello'}
440+
else:
441+
assert j == b'{"a":"hello"}'
442+
443+
assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict()
444+
445+
446+
def test_any_model():
447+
class Foo:
448+
a: str
449+
b: bytes
450+
451+
def __init__(self, a: str, b: bytes):
452+
self.a = a
453+
self.b = b
454+
455+
# Build a schema that does not include the field 'b', to test that it is not serialized
456+
schema = core_schema.dataclass_schema(
457+
Foo,
458+
core_schema.dataclass_args_schema(
459+
'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())]
460+
),
461+
)
462+
Foo.__pydantic_validator__ = SchemaValidator(schema)
463+
Foo.__pydantic_serializer__ = SchemaSerializer(schema)
464+
465+
s = SchemaSerializer(core_schema.any_schema())
466+
assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello')
467+
assert s.to_python(Foo(a='hello', b=b'more'), mode='json') == IsStrictDict(a='hello')
468+
j = s.to_json(Foo(a='hello', b=b'more'))
469+
470+
if on_pypy:
471+
assert json.loads(j) == {'a': 'hello'}
472+
else:
473+
assert j == b'{"a":"hello"}'
474+
475+
assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict()

tests/test_json.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,8 @@ def __init__(self, my_foo: int, my_bar: str):
262262
instance = Foobar(my_foo=1, my_bar='a')
263263
assert to_jsonable_python(instance) == {'myFoo': 1, 'myBar': 'a'}
264264
assert to_jsonable_python(instance, by_alias=False) == {'my_foo': 1, 'my_bar': 'a'}
265-
266-
# Just including this to document the behavior. Note that trying to get `to_json` to respect aliases
267-
# by putting the by_alias field into the `extra` in `to_json` won't work because it always uses an
268-
# AnySchema for serialization. If you want any python-introspection to happen during the json serialization
269-
# process for the purpose of respecting objects' own serialization schemas, you need to call
270-
# to_jsonable_python and serialize the result yourself.
271-
assert to_json(instance) == b'{"my_foo":1,"my_bar":"a"}'
265+
assert to_json(instance) == b'{"myFoo":1,"myBar":"a"}'
266+
assert to_json(instance, by_alias=False) == b'{"my_foo":1,"my_bar":"a"}'
272267

273268

274269
def test_cycle_same():

0 commit comments

Comments
 (0)