Skip to content

✨ Implement root model serialization #613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/serializers/type_serializers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl FunctionPlainSerializer {
self.func.call1(py, (model, value))
}
} else {
Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found"))
Err(PyRuntimeError::new_err("Function plain serializer expected to be run inside the context of a model field but no model was found"))
}
} else if self.info_arg {
let info = SerializationInfo::new(py, include, exclude, extra, self.is_field_serializer)?;
Expand Down Expand Up @@ -368,7 +368,7 @@ impl FunctionWrapSerializer {
self.func.call1(py, (model, value, serialize))
}
} else {
Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found"))
Err(PyRuntimeError::new_err("Function wrap serializer expected to be run inside the context of a model field but no model was found"))
}
} else if self.info_arg {
let info = SerializationInfo::new(py, include, exclude, extra, self.is_field_serializer)?;
Expand Down Expand Up @@ -492,20 +492,20 @@ impl SerializationInfo {
) -> PyResult<Self> {
if is_field_serializer {
match extra.field_name {
Some(field_name) => Ok(
Self {
include: include.map(|i| i.into_py(py)),
exclude: exclude.map(|e| e.into_py(py)),
_mode: extra.mode.clone(),
by_alias: extra.by_alias,
exclude_unset: extra.exclude_unset,
exclude_defaults: extra.exclude_defaults,
exclude_none: extra.exclude_none,
round_trip: extra.round_trip,
field_name: Some(field_name.to_string()),
}
),
_ => Err(PyRuntimeError::new_err("This serializer expected to be run inside the context of a model field but no model field was found")),
Some(field_name) => Ok(Self {
include: include.map(|i| i.into_py(py)),
exclude: exclude.map(|e| e.into_py(py)),
_mode: extra.mode.clone(),
by_alias: extra.by_alias,
exclude_unset: extra.exclude_unset,
exclude_defaults: extra.exclude_defaults,
exclude_none: extra.exclude_none,
round_trip: extra.round_trip,
field_name: Some(field_name.to_string()),
}),
_ => Err(PyRuntimeError::new_err(
"Model field context expected for field serialization info but no model field was found",
)),
}
} else {
Ok(Self {
Expand Down
24 changes: 20 additions & 4 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use super::{
SerField, TypeSerializer,
};

const ROOT_FIELD: &str = "root";

pub struct ModelFieldsBuilder;

impl BuildSerializer for ModelFieldsBuilder {
Expand Down Expand Up @@ -66,6 +68,7 @@ pub struct ModelSerializer {
class: Py<PyType>,
serializer: Box<CombinedSerializer>,
has_extra: bool,
root_model: bool,
name: String,
}

Expand All @@ -85,11 +88,13 @@ impl BuildSerializer for ModelSerializer {
let class: &PyType = schema.get_as_req(intern!(py, "cls"))?;
let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?);
let root_model = schema.get_as(intern!(py, "root_model"))?.unwrap_or(false);

Ok(Self {
class: class.into(),
serializer,
has_extra: has_extra(schema, config)?,
root_model,
name: class.getattr(intern!(py, "__name__"))?.extract()?,
}
.into())
Expand Down Expand Up @@ -139,11 +144,16 @@ impl TypeSerializer for ModelSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let extra = Extra {
let mut extra = Extra {
model: Some(value),
..*extra
};
if self.allow_value(value, &extra)? {
if self.root_model {
extra.field_name = Some(ROOT_FIELD);
let py = value.py();
let root = value.getattr(intern!(py, ROOT_FIELD))?;
self.serializer.to_python(root, include, exclude, &extra)
} else if self.allow_value(value, &extra)? {
let inner_value = self.get_inner_value(value, &extra)?;
self.serializer.to_python(inner_value, include, exclude, &extra)
} else {
Expand All @@ -169,11 +179,17 @@ impl TypeSerializer for ModelSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let extra = Extra {
let mut extra = Extra {
model: Some(value),
..*extra
};
if self.allow_value(value, &extra).map_err(py_err_se_err)? {
if self.root_model {
extra.field_name = Some(ROOT_FIELD);
let py = value.py();
let root = value.getattr(intern!(py, ROOT_FIELD)).map_err(py_err_se_err)?;
self.serializer
.serde_serialize(root, serializer, include, exclude, &extra)
} else if self.allow_value(value, &extra).map_err(py_err_se_err)? {
let inner_value = self.get_inner_value(value, &extra).map_err(py_err_se_err)?;
self.serializer
.serde_serialize(inner_value, serializer, include, exclude, &extra)
Expand Down
4 changes: 2 additions & 2 deletions tests/serializers/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,5 +857,5 @@ class InnerModel:
s = SchemaSerializer(schema)
# debug(s)
s_repr = plain_repr(s)
assert 'has_extra:true,name:"InnerModel"' in s_repr
assert 'has_extra:false,name:"OuterModel"' in s_repr
assert 'has_extra:true,root_model:false,name:"InnerModel"' in s_repr
assert 'has_extra:false,root_model:false,name:"OuterModel"' in s_repr
131 changes: 131 additions & 0 deletions tests/serializers/test_model_root.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import json
import platform
from typing import Any

try:
from functools import cached_property
except ImportError:
cached_property = None


from pydantic_core import SchemaSerializer, core_schema

from ..conftest import plain_repr

on_pypy = platform.python_implementation() == 'PyPy'
# pypy doesn't seem to maintain order of `__dict__`
if on_pypy:
IsStrictDict = dict
else:
pass


class RootModel:
__slots__ = 'root'
root: str

def __init__(self, data):
self.root = data


class RootSubModel(RootModel):
pass


def test_model_root():
s = SchemaSerializer(core_schema.model_schema(RootModel, core_schema.int_schema(), root_model=True))
print(plain_repr(s))
# TODO: assert 'mode:RootModel' in plain_repr(s)
assert 'has_extra:false' in plain_repr(s)
assert s.to_python(RootModel(1)) == 1
assert s.to_python(RootSubModel(1)) == 1

j = s.to_json(RootModel(1))
if on_pypy:
assert json.loads(j) == 1
else:
assert j == b'1'

assert json.loads(s.to_json(RootSubModel(1))) == 1


def test_function_plain_field_serializer_to_python():
class Model(RootModel):
def ser_root(self, v: Any, _) -> str:
assert self.root == 1_000
return f'{v:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.plain_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True
)
),
root_model=True,
)
)
assert s.to_python(Model(1000)) == '1_000'


def test_function_wrap_field_serializer_to_python():
class Model(RootModel):
def ser_root(self, v: Any, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
root = serializer(v)
assert self.root == 1_000
return f'{root:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
)
),
root_model=True,
)
)
assert s.to_python(Model(1000)) == '1_000'


def test_function_plain_field_serializer_to_json():
class Model(RootModel):
def ser_root(self, v: Any, _) -> str:
assert self.root == 1_000
return f'{v:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.plain_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True
)
),
root_model=True,
)
)
assert json.loads(s.to_json(Model(1000))) == '1_000'


def test_function_wrap_field_serializer_to_json():
class Model(RootModel):
def ser_root(self, v: Any, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
assert self.root == 1_000
root = serializer(v)
return f'{root:_}'

s = SchemaSerializer(
core_schema.model_schema(
Model,
core_schema.int_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
)
),
root_model=True,
)
)
assert json.loads(s.to_json(Model(1000))) == '1_000'
13 changes: 13 additions & 0 deletions tests/validators/test_model_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,16 @@ def f(input_value: str, info):
"ValidationInfo(config=None, context='call 1', field_name='root')",
"ValidationInfo(config=None, context='assignment call', field_name='root')",
]


def test_extra():
class RootModel:
__slots__ = 'root'
root: int

v = SchemaValidator(core_schema.model_schema(RootModel, core_schema.int_schema(), root_model=True))

m = v.validate_python(1)

with pytest.raises(AttributeError):
m.__pydantic_extra__