Skip to content

Commit 1a326fd

Browse files
committed
✨ Implement root model serialization
1 parent 084172a commit 1a326fd

File tree

4 files changed

+169
-22
lines changed

4 files changed

+169
-22
lines changed

src/serializers/type_serializers/function.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ impl FunctionPlainSerializer {
140140
self.func.call1(py, (model, value))
141141
}
142142
} else {
143-
Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found"))
143+
Err(PyRuntimeError::new_err("Function plain serializer expected to be run inside the context of a model field but no model was found"))
144144
}
145145
} else if self.info_arg {
146146
let info = SerializationInfo::new(py, include, exclude, extra, self.is_field_serializer)?;
@@ -368,7 +368,7 @@ impl FunctionWrapSerializer {
368368
self.func.call1(py, (model, value, serialize))
369369
}
370370
} else {
371-
Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found"))
371+
Err(PyRuntimeError::new_err("Function wrap serializer expected to be run inside the context of a model field but no model was found"))
372372
}
373373
} else if self.info_arg {
374374
let info = SerializationInfo::new(py, include, exclude, extra, self.is_field_serializer)?;
@@ -492,20 +492,20 @@ impl SerializationInfo {
492492
) -> PyResult<Self> {
493493
if is_field_serializer {
494494
match extra.field_name {
495-
Some(field_name) => Ok(
496-
Self {
497-
include: include.map(|i| i.into_py(py)),
498-
exclude: exclude.map(|e| e.into_py(py)),
499-
_mode: extra.mode.clone(),
500-
by_alias: extra.by_alias,
501-
exclude_unset: extra.exclude_unset,
502-
exclude_defaults: extra.exclude_defaults,
503-
exclude_none: extra.exclude_none,
504-
round_trip: extra.round_trip,
505-
field_name: Some(field_name.to_string()),
506-
}
507-
),
508-
_ => Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found")),
495+
Some(field_name) => Ok(Self {
496+
include: include.map(|i| i.into_py(py)),
497+
exclude: exclude.map(|e| e.into_py(py)),
498+
_mode: extra.mode.clone(),
499+
by_alias: extra.by_alias,
500+
exclude_unset: extra.exclude_unset,
501+
exclude_defaults: extra.exclude_defaults,
502+
exclude_none: extra.exclude_none,
503+
round_trip: extra.round_trip,
504+
field_name: Some(field_name.to_string()),
505+
}),
506+
_ => Err(PyRuntimeError::new_err(
507+
"Model field context expected for field serialization info but no model field was found",
508+
)),
509509
}
510510
} else {
511511
Ok(Self {

src/serializers/type_serializers/model.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use super::{
1515
SerField, TypeSerializer,
1616
};
1717

18+
const ROOT_FIELD: &str = "root";
19+
1820
pub struct ModelFieldsBuilder;
1921

2022
impl BuildSerializer for ModelFieldsBuilder {
@@ -66,6 +68,7 @@ pub struct ModelSerializer {
6668
class: Py<PyType>,
6769
serializer: Box<CombinedSerializer>,
6870
has_extra: bool,
71+
root_model: bool,
6972
name: String,
7073
}
7174

@@ -85,11 +88,13 @@ impl BuildSerializer for ModelSerializer {
8588
let class: &PyType = schema.get_as_req(intern!(py, "cls"))?;
8689
let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
8790
let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?);
91+
let root_model = schema.get_as(intern!(py, "root_model"))?.unwrap_or(false);
8892

8993
Ok(Self {
9094
class: class.into(),
9195
serializer,
9296
has_extra: has_extra(schema, config)?,
97+
root_model,
9398
name: class.getattr(intern!(py, "__name__"))?.extract()?,
9499
}
95100
.into())
@@ -139,11 +144,16 @@ impl TypeSerializer for ModelSerializer {
139144
exclude: Option<&PyAny>,
140145
extra: &Extra,
141146
) -> PyResult<PyObject> {
142-
let extra = Extra {
147+
let mut extra = Extra {
143148
model: Some(value),
144149
..*extra
145150
};
146-
if self.allow_value(value, &extra)? {
151+
if self.root_model {
152+
extra.field_name = Some(ROOT_FIELD);
153+
let py = value.py();
154+
let root = value.getattr(intern!(py, ROOT_FIELD))?;
155+
self.serializer.to_python(root, include, exclude, &extra)
156+
} else if self.allow_value(value, &extra)? {
147157
let inner_value = self.get_inner_value(value, &extra)?;
148158
self.serializer.to_python(inner_value, include, exclude, &extra)
149159
} else {
@@ -169,11 +179,17 @@ impl TypeSerializer for ModelSerializer {
169179
exclude: Option<&PyAny>,
170180
extra: &Extra,
171181
) -> Result<S::Ok, S::Error> {
172-
let extra = Extra {
182+
let mut extra = Extra {
173183
model: Some(value),
174184
..*extra
175185
};
176-
if self.allow_value(value, &extra).map_err(py_err_se_err)? {
186+
if self.root_model {
187+
extra.field_name = Some(ROOT_FIELD);
188+
let py = value.py();
189+
let root = value.getattr(intern!(py, ROOT_FIELD)).map_err(py_err_se_err)?;
190+
self.serializer
191+
.serde_serialize(root, serializer, include, exclude, &extra)
192+
} else if self.allow_value(value, &extra).map_err(py_err_se_err)? {
177193
let inner_value = self.get_inner_value(value, &extra).map_err(py_err_se_err)?;
178194
self.serializer
179195
.serde_serialize(inner_value, serializer, include, exclude, &extra)

tests/serializers/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -857,5 +857,5 @@ class InnerModel:
857857
s = SchemaSerializer(schema)
858858
# debug(s)
859859
s_repr = plain_repr(s)
860-
assert 'has_extra:true,name:"InnerModel"' in s_repr
861-
assert 'has_extra:false,name:"OuterModel"' in s_repr
860+
assert 'has_extra:true,root_model:false,name:"InnerModel"' in s_repr
861+
assert 'has_extra:false,root_model:false,name:"OuterModel"' in s_repr

tests/serializers/test_model_root.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import json
2+
import platform
3+
from typing import Any
4+
5+
try:
6+
from functools import cached_property
7+
except ImportError:
8+
cached_property = None
9+
10+
11+
from pydantic_core import SchemaSerializer, core_schema
12+
13+
from ..conftest import plain_repr
14+
15+
on_pypy = platform.python_implementation() == 'PyPy'
16+
# pypy doesn't seem to maintain order of `__dict__`
17+
if on_pypy:
18+
IsStrictDict = dict
19+
else:
20+
pass
21+
22+
23+
class RootModel:
24+
__slots__ = 'root'
25+
root: str
26+
27+
def __init__(self, data):
28+
self.root = data
29+
30+
31+
class RootSubModel(RootModel):
32+
pass
33+
34+
35+
def test_model_root():
36+
s = SchemaSerializer(core_schema.model_schema(RootModel, core_schema.int_schema(), root_model=True))
37+
print(plain_repr(s))
38+
# TODO: assert 'mode:RootModel' in plain_repr(s)
39+
assert 'has_extra:false' in plain_repr(s)
40+
assert s.to_python(RootModel(1)) == 1
41+
assert s.to_python(RootSubModel(1)) == 1
42+
43+
j = s.to_json(RootModel(1))
44+
if on_pypy:
45+
assert json.loads(j) == 1
46+
else:
47+
assert j == b'1'
48+
49+
assert json.loads(s.to_json(RootSubModel(1))) == 1
50+
51+
52+
def test_function_plain_field_serializer_to_python():
53+
class Model(RootModel):
54+
def ser_root(self, v: Any, _) -> str:
55+
assert self.root == 1_000
56+
return f'{v:_}'
57+
58+
s = SchemaSerializer(
59+
core_schema.model_schema(
60+
Model,
61+
core_schema.int_schema(
62+
serialization=core_schema.plain_serializer_function_ser_schema(
63+
Model.ser_root, is_field_serializer=True, info_arg=True
64+
)
65+
),
66+
root_model=True,
67+
)
68+
)
69+
assert s.to_python(Model(1000)) == '1_000'
70+
71+
72+
def test_function_wrap_field_serializer_to_python():
73+
class Model(RootModel):
74+
def ser_root(self, v: Any, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
75+
root = serializer(v)
76+
assert self.root == 1_000
77+
return f'{root:_}'
78+
79+
s = SchemaSerializer(
80+
core_schema.model_schema(
81+
Model,
82+
core_schema.int_schema(
83+
serialization=core_schema.wrap_serializer_function_ser_schema(
84+
Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
85+
)
86+
),
87+
root_model=True,
88+
)
89+
)
90+
assert s.to_python(Model(1000)) == '1_000'
91+
92+
93+
def test_function_plain_field_serializer_to_json():
94+
class Model(RootModel):
95+
def ser_root(self, v: Any, _) -> str:
96+
assert self.root == 1_000
97+
return f'{v:_}'
98+
99+
s = SchemaSerializer(
100+
core_schema.model_schema(
101+
Model,
102+
core_schema.int_schema(
103+
serialization=core_schema.plain_serializer_function_ser_schema(
104+
Model.ser_root, is_field_serializer=True, info_arg=True
105+
)
106+
),
107+
root_model=True,
108+
)
109+
)
110+
assert json.loads(s.to_json(Model(1000))) == '1_000'
111+
112+
113+
def test_function_wrap_field_serializer_to_json():
114+
class Model(RootModel):
115+
def ser_root(self, v: Any, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
116+
assert self.root == 1_000
117+
root = serializer(v)
118+
return f'{root:_}'
119+
120+
s = SchemaSerializer(
121+
core_schema.model_schema(
122+
Model,
123+
core_schema.int_schema(
124+
serialization=core_schema.wrap_serializer_function_ser_schema(
125+
Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
126+
)
127+
),
128+
root_model=True,
129+
)
130+
)
131+
assert json.loads(s.to_json(Model(1000))) == '1_000'

0 commit comments

Comments
 (0)