Skip to content

Commit 027e679

Browse files
authored
Use return_schema for computed fields (#595)
1 parent 0d57243 commit 027e679

File tree

8 files changed

+54
-42
lines changed

8 files changed

+54
-42
lines changed

pydantic_core/core_schema.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,23 +415,25 @@ def model_ser_schema(cls: Type[Any], schema: CoreSchema) -> ModelSerSchema:
415415
class ComputedField(TypedDict, total=False):
416416
type: Required[Literal['computed-field']]
417417
property_name: Required[str]
418-
json_return_type: JsonReturnTypes
418+
return_schema: Required[CoreSchema]
419419
alias: str
420+
metadata: Any
420421

421422

422423
def computed_field(
423-
property_name: str, *, json_return_type: JsonReturnTypes | None = None, alias: str | None = None
424+
property_name: str, return_schema: CoreSchema, *, alias: str | None = None, metadata: Any = None
424425
) -> ComputedField:
425426
"""
426427
ComputedFields are properties of a model or dataclass that are included in serialization.
427428
428429
Args:
429430
property_name: The name of the property on the model or dataclass
430-
json_return_type: The type that the property returns if `mode='json'`
431+
return_schema: The schema used for the type returned by the computed field
431432
alias: The name to use in the serialized output
433+
metadata: Any other information you want to include with the schema, not used by pydantic-core
432434
"""
433435
return dict_not_none(
434-
type='computed-field', property_name=property_name, json_return_type=json_return_type, alias=alias
436+
type='computed-field', property_name=property_name, return_schema=return_schema, alias=alias, metadata=metadata
435437
)
436438

437439

@@ -3677,6 +3679,8 @@ def definition_reference_schema(
36773679
'definition-ref',
36783680
]
36793681

3682+
CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
3683+
36803684

36813685
# used in _pydantic_core.pyi::PydanticKnownError
36823686
# to update this, call `pytest -k test_all_errors` and copy the output

src/serializers/computed_fields.rs

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,31 @@
11
use pyo3::intern;
22
use pyo3::prelude::*;
33
use pyo3::types::{PyDict, PyList, PyString};
4-
54
use serde::ser::SerializeMap;
65
use serde::Serialize;
76

8-
use crate::build_tools::SchemaDict;
7+
use crate::build_tools::{py_error_type, SchemaDict};
8+
use crate::definitions::DefinitionsBuilder;
99
use crate::serializers::filter::SchemaFilter;
10+
use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer};
1011

1112
use super::errors::py_err_se_err;
12-
use super::infer::{infer_serialize, infer_serialize_known, infer_to_python, infer_to_python_known};
13-
use super::ob_type::ObType;
1413
use super::Extra;
1514

16-
use super::type_serializers::function::get_json_return_type;
17-
1815
#[derive(Debug, Clone)]
1916
pub(super) struct ComputedFields(Vec<ComputedField>);
2017

2118
impl ComputedFields {
22-
pub fn new(schema: &PyDict) -> PyResult<Option<Self>> {
19+
pub fn new(
20+
schema: &PyDict,
21+
config: Option<&PyDict>,
22+
definitions: &mut DefinitionsBuilder<CombinedSerializer>,
23+
) -> PyResult<Option<Self>> {
2324
let py = schema.py();
2425
if let Some(computed_fields) = schema.get_as::<&PyList>(intern!(py, "computed_fields"))? {
2526
let computed_fields = computed_fields
2627
.iter()
27-
.map(ComputedField::new)
28+
.map(|field| ComputedField::new(field, config, definitions))
2829
.collect::<PyResult<Vec<_>>>()?;
2930
Ok(Some(Self(computed_fields)))
3031
} else {
@@ -88,22 +89,28 @@ impl ComputedFields {
8889
struct ComputedField {
8990
property_name: String,
9091
property_name_py: Py<PyString>,
91-
return_ob_type: Option<ObType>,
92+
serializer: CombinedSerializer,
9293
alias: String,
9394
alias_py: Py<PyString>,
9495
}
9596

9697
impl ComputedField {
97-
pub fn new(schema: &PyAny) -> PyResult<Self> {
98+
pub fn new(
99+
schema: &PyAny,
100+
config: Option<&PyDict>,
101+
definitions: &mut DefinitionsBuilder<CombinedSerializer>,
102+
) -> PyResult<Self> {
98103
let py = schema.py();
99104
let schema: &PyDict = schema.downcast()?;
100105
let property_name: &PyString = schema.get_as_req(intern!(py, "property_name"))?;
101-
let return_ob_type = get_json_return_type(schema)?;
106+
let return_schema = schema.get_as_req(intern!(py, "return_schema"))?;
107+
let serializer = CombinedSerializer::build(return_schema, config, definitions)
108+
.map_err(|e| py_error_type!("Computed field `{}`:\n {}", property_name, e))?;
102109
let alias_py: &PyString = schema.get_as(intern!(py, "alias"))?.unwrap_or(property_name);
103110
Ok(Self {
104111
property_name: property_name.extract()?,
105112
property_name_py: property_name.into_py(py),
106-
return_ob_type,
113+
serializer,
107114
alias: alias_py.extract()?,
108115
alias_py: alias_py.into_py(py),
109116
})
@@ -124,11 +131,9 @@ impl ComputedField {
124131
if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? {
125132
let next_value = model.getattr(property_name_py)?;
126133

127-
// TODO fix include & exclude
128-
let value = match self.return_ob_type {
129-
Some(ref ob_type) => infer_to_python_known(ob_type, next_value, next_include, next_exclude, extra),
130-
None => infer_to_python(next_value, next_include, next_exclude, extra),
131-
}?;
134+
let value = self
135+
.serializer
136+
.to_python(next_value, next_include, next_exclude, extra)?;
132137
let key = match extra.by_alias {
133138
true => self.alias_py.as_ref(py),
134139
false => property_name_py,
@@ -152,12 +157,13 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> {
152157
let py = self.model.py();
153158
let property_name_py = self.computed_field.property_name_py.as_ref(py);
154159
let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?;
155-
156-
match self.computed_field.return_ob_type {
157-
Some(ref ob_type) => {
158-
infer_serialize_known(ob_type, next_value, serializer, self.include, self.exclude, self.extra)
159-
}
160-
None => infer_serialize(next_value, serializer, self.include, self.exclude, self.extra),
161-
}
160+
let s = PydanticSerializer::new(
161+
next_value,
162+
&self.computed_field.serializer,
163+
self.include,
164+
self.exclude,
165+
self.extra,
166+
);
167+
s.serialize(serializer)
162168
}
163169
}

src/serializers/type_serializers/dataclass.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl BuildSerializer for DataclassArgsBuilder {
4848
}
4949
}
5050

51-
let computed_fields = ComputedFields::new(schema)?;
51+
let computed_fields = ComputedFields::new(schema, config, definitions)?;
5252

5353
Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
5454
}

src/serializers/type_serializers/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl BuildSerializer for ModelFieldsBuilder {
5555
}
5656
}
5757

58-
let computed_fields = ComputedFields::new(schema)?;
58+
let computed_fields = ComputedFields::new(schema, config, definitions)?;
5959

6060
Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
6161
}

src/serializers/type_serializers/typed_dict.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl BuildSerializer for TypedDictBuilder {
5353
}
5454
}
5555

56-
let computed_fields = ComputedFields::new(schema)?;
56+
let computed_fields = ComputedFields::new(schema, config, definitions)?;
5757

5858
Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
5959
}

tests/serializers/test_dataclasses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def c(self) -> str:
105105
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
106106
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema()),
107107
],
108-
computed_fields=[core_schema.computed_field('c')],
108+
computed_fields=[core_schema.computed_field('c', core_schema.str_schema())],
109109
),
110110
)
111111
s = SchemaSerializer(schema)

tests/serializers/test_model.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def area(self) -> bytes:
546546
'width': core_schema.model_field(core_schema.int_schema()),
547547
'height': core_schema.model_field(core_schema.int_schema()),
548548
},
549-
computed_fields=[core_schema.computed_field('area', json_return_type='bytes')],
549+
computed_fields=[core_schema.computed_field('area', core_schema.bytes_schema())],
550550
),
551551
)
552552
)
@@ -578,8 +578,8 @@ def volume(self) -> int:
578578
'height': core_schema.model_field(core_schema.int_schema()),
579579
},
580580
computed_fields=[
581-
core_schema.computed_field('area', alias='Area'),
582-
core_schema.computed_field('volume'),
581+
core_schema.computed_field('area', core_schema.int_schema(), alias='Area'),
582+
core_schema.computed_field('volume', core_schema.int_schema()),
583583
],
584584
),
585585
)
@@ -608,7 +608,7 @@ def area(self) -> int:
608608
'width': core_schema.model_field(core_schema.int_schema()),
609609
'height': core_schema.model_field(core_schema.int_schema()),
610610
},
611-
computed_fields=[core_schema.computed_field('area')],
611+
computed_fields=[core_schema.computed_field('area', core_schema.int_schema())],
612612
),
613613
)
614614
)
@@ -627,7 +627,7 @@ class Model:
627627
Model,
628628
core_schema.model_fields_schema(
629629
{'width': core_schema.model_field(core_schema.int_schema())},
630-
computed_fields=[core_schema.computed_field('area', json_return_type='bytes')],
630+
computed_fields=[core_schema.computed_field('area', core_schema.bytes_schema())],
631631
),
632632
)
633633
)
@@ -655,7 +655,7 @@ def area(self) -> int:
655655
Model,
656656
core_schema.model_fields_schema(
657657
{'width': core_schema.model_field(core_schema.int_schema())},
658-
computed_fields=[core_schema.computed_field('area', json_return_type='bytes')],
658+
computed_fields=[core_schema.computed_field('area', core_schema.bytes_schema())],
659659
),
660660
)
661661
)
@@ -684,7 +684,7 @@ def b(self):
684684
Model,
685685
core_schema.model_fields_schema(
686686
{'a': core_schema.model_field(core_schema.int_schema())},
687-
computed_fields=[core_schema.computed_field('b')],
687+
computed_fields=[core_schema.computed_field('b', core_schema.list_schema())],
688688
),
689689
)
690690
)
@@ -734,8 +734,8 @@ def random_n(self) -> int:
734734
core_schema.model_fields_schema(
735735
{'side': core_schema.model_field(core_schema.float_schema())},
736736
computed_fields=[
737-
core_schema.computed_field('area', json_return_type='float'),
738-
core_schema.computed_field('random_n', alias='The random number', json_return_type='int'),
737+
core_schema.computed_field('area', core_schema.float_schema()),
738+
core_schema.computed_field('random_n', core_schema.int_schema(), alias='The random number'),
739739
],
740740
),
741741
)

tests/test.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ schema = {
4242
"cls": A,
4343
"config": {},
4444
"schema": {
45-
"computed_fields": [{"property_name": "b", "type": "computed-field"}],
45+
"computed_fields": [
46+
{"property_name": "b", "return_schema": {"type": "any"}, "type": "computed-field"}
47+
],
4648
"fields": {},
4749
"type": "model-fields",
4850
},

0 commit comments

Comments
 (0)