Skip to content

Commit 6a0effd

Browse files
committed
Pass field name and model into field serializers
1 parent 5efeaf9 commit 6a0effd

File tree

14 files changed

+342
-83
lines changed

14 files changed

+342
-83
lines changed

pydantic_core/core_schema.py

Lines changed: 120 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ def __repr__(self) -> str:
9393
...
9494

9595

96+
class FieldSerializationInfo(SerializationInfo, Protocol):
97+
@property
98+
def field_name(self) -> str:
99+
...
100+
101+
96102
class ValidationInfo(Protocol):
97103
"""
98104
Argument passed to validation functions.
@@ -109,7 +115,7 @@ def config(self) -> CoreConfig | None:
109115
...
110116

111117

112-
class ModelFieldValidationInfo(ValidationInfo, Protocol):
118+
class FieldValidationInfo(ValidationInfo, Protocol):
113119
"""
114120
Argument passed to model field validation functions.
115121
"""
@@ -166,11 +172,26 @@ def simple_ser_schema(type: ExpectedSerializationTypes) -> SimpleSerSchema:
166172
return SimpleSerSchema(type=type)
167173

168174

169-
class SerializePlainFunction(Protocol): # pragma: no cover
175+
class GeneralSerializePlainFunction(Protocol): # pragma: no cover
170176
def __call__(self, __input_value: Any, __info: SerializationInfo) -> Any:
171177
...
172178

173179

180+
class FieldSerializePlainFunction(Protocol): # pragma: no cover
181+
def __call__(self, __model: Any, __input_value: Any, __info: FieldSerializationInfo) -> Any:
182+
...
183+
184+
185+
class GeneralSerializePlainFunctionSchema(TypedDict):
186+
type: Literal['general']
187+
function: GeneralSerializePlainFunction
188+
189+
190+
class FieldSerializePlainFunctionSchema(TypedDict):
191+
type: Literal['field']
192+
function: FieldSerializePlainFunction
193+
194+
174195
# must match `src/serializers/ob_type.rs::ObType`
175196
JsonReturnTypes = Literal[
176197
'int',
@@ -212,13 +233,41 @@ def __call__(self, __input_value: Any, __info: SerializationInfo) -> Any:
212233

213234
class FunctionPlainSerSchema(TypedDict, total=False):
214235
type: Required[Literal['function-plain']]
215-
function: Required[SerializePlainFunction]
236+
function: Required[Union[GeneralSerializePlainFunctionSchema, FieldSerializePlainFunctionSchema]]
216237
json_return_type: JsonReturnTypes
217238
when_used: WhenUsed # default: 'always'
218239

219240

220-
def function_plain_ser_schema(
221-
function: SerializePlainFunction, *, json_return_type: JsonReturnTypes | None = None, when_used: WhenUsed = 'always'
241+
def general_function_plain_ser_schema(
242+
function: GeneralSerializePlainFunction,
243+
*,
244+
json_return_type: JsonReturnTypes | None = None,
245+
when_used: WhenUsed = 'always',
246+
) -> FunctionPlainSerSchema:
247+
"""
248+
Returns a schema for serialization with a function.
249+
250+
Args:
251+
function: The function to use for serialization
252+
json_return_type: The type that the function returns if `mode='json'`
253+
when_used: When the function should be called
254+
"""
255+
if when_used == 'always':
256+
# just to avoid extra elements in schema, and to use the actual default defined in rust
257+
when_used = None # type: ignore
258+
return dict_not_none(
259+
type='function-plain',
260+
function={'type': 'general', 'function': function},
261+
json_return_type=json_return_type,
262+
when_used=when_used,
263+
)
264+
265+
266+
def field_function_plain_ser_schema(
267+
function: FieldSerializePlainFunction,
268+
*,
269+
json_return_type: JsonReturnTypes | None = None,
270+
when_used: WhenUsed = 'always',
222271
) -> FunctionPlainSerSchema:
223272
"""
224273
Returns a schema for serialization with a function.
@@ -232,7 +281,10 @@ def function_plain_ser_schema(
232281
# just to avoid extra elements in schema, and to use the actual default defined in rust
233282
when_used = None # type: ignore
234283
return dict_not_none(
235-
type='function-plain', function=function, json_return_type=json_return_type, when_used=when_used
284+
type='function-plain',
285+
function={'type': 'field', 'function': function},
286+
json_return_type=json_return_type,
287+
when_used=when_used,
236288
)
237289

238290

@@ -241,21 +293,38 @@ def __call__(self, __input_value: Any, __index_key: int | str | None = None) ->
241293
...
242294

243295

244-
class SerializeWrapFunction(Protocol): # pragma: no cover
296+
class GeneralSerializeWrapFunction(Protocol): # pragma: no cover
245297
def __call__(self, __input_value: Any, __serializer: SerializeWrapHandler, __info: SerializationInfo) -> Any:
246298
...
247299

248300

301+
class FieldSerializeWrapFunction(Protocol): # pragma: no cover
302+
def __call__(
303+
self, __model: Any, __input_value: Any, __serializer: SerializeWrapHandler, __info: FieldSerializationInfo
304+
) -> Any:
305+
...
306+
307+
308+
class GeneralSerializeWrapFunctionSchema(TypedDict):
309+
type: Literal['general']
310+
function: GeneralSerializeWrapFunction
311+
312+
313+
class FieldSerializeWrapFunctionSchema(TypedDict):
314+
type: Literal['field']
315+
function: FieldSerializeWrapFunction
316+
317+
249318
class FunctionWrapSerSchema(TypedDict, total=False):
250319
type: Required[Literal['function-wrap']]
251-
function: Required[SerializeWrapFunction]
320+
function: Required[Union[GeneralSerializeWrapFunctionSchema, FieldSerializeWrapFunctionSchema]]
252321
schema: Required[CoreSchema]
253322
json_return_type: JsonReturnTypes
254323
when_used: WhenUsed # default: 'always'
255324

256325

257-
def function_wrap_ser_schema(
258-
function: SerializeWrapFunction,
326+
def general_function_wrap_ser_schema(
327+
function: GeneralSerializeWrapFunction,
259328
schema: CoreSchema,
260329
*,
261330
json_return_type: JsonReturnTypes | None = None,
@@ -274,7 +343,39 @@ def function_wrap_ser_schema(
274343
# just to avoid extra elements in schema, and to use the actual default defined in rust
275344
when_used = None # type: ignore
276345
return dict_not_none(
277-
type='function-wrap', schema=schema, function=function, json_return_type=json_return_type, when_used=when_used
346+
type='function-wrap',
347+
schema=schema,
348+
function={'type': 'general', 'function': function},
349+
json_return_type=json_return_type,
350+
when_used=when_used,
351+
)
352+
353+
354+
def field_function_wrap_ser_schema(
355+
function: FieldSerializeWrapFunction,
356+
schema: CoreSchema,
357+
*,
358+
json_return_type: JsonReturnTypes | None = None,
359+
when_used: WhenUsed = 'always',
360+
) -> FunctionWrapSerSchema:
361+
"""
362+
Returns a schema for serialization with a function for a model field.
363+
364+
Args:
365+
function: The function to use for serialization
366+
schema: The schema to use for the inner serialization
367+
json_return_type: The type that the function returns if `mode='json'`
368+
when_used: When the function should be called
369+
"""
370+
if when_used == 'always':
371+
# just to avoid extra elements in schema, and to use the actual default defined in rust
372+
when_used = None # type: ignore
373+
return dict_not_none(
374+
type='function-wrap',
375+
schema=schema,
376+
function={'type': 'field', 'function': function},
377+
json_return_type=json_return_type,
378+
when_used=when_used,
278379
)
279380

280381

@@ -290,7 +391,7 @@ def format_ser_schema(formatting_string: str, *, when_used: WhenUsed = 'json-unl
290391
291392
Args:
292393
formatting_string: String defining the format to use
293-
when_used: Same meaning as for [function_plain_ser_schema], but with a different default
394+
when_used: Same meaning as for [general_function_plain_ser_schema], but with a different default
294395
"""
295396
if when_used == 'json-unless-none':
296397
# just to avoid extra elements in schema, and to use the actual default defined in rust
@@ -308,7 +409,7 @@ def to_string_ser_schema(*, when_used: WhenUsed = 'json-unless-none') -> ToStrin
308409
Returns a schema for serialization using python's `str()` / `__str__` method.
309410
310411
Args:
311-
when_used: Same meaning as for [function_plain_ser_schema], but with a different default
412+
when_used: Same meaning as for [general_function_plain_ser_schema], but with a different default
312413
"""
313414
s = dict(type='to-string')
314415
if when_used != 'json-unless-none':
@@ -1491,7 +1592,7 @@ def __call__(self, __input_value: Any, __info: ValidationInfo) -> Any: # pragma
14911592

14921593

14931594
class FieldValidatorFunction(Protocol):
1494-
def __call__(self, __input_value: Any, __info: ModelFieldValidationInfo) -> Any: # pragma: no cover
1595+
def __call__(self, __input_value: Any, __info: FieldValidationInfo) -> Any: # pragma: no cover
14951596
...
14961597

14971598

@@ -1532,7 +1633,7 @@ def field_before_validation_function(
15321633
```py
15331634
from pydantic_core import SchemaValidator, core_schema
15341635
1535-
def fn(v: bytes, info: core_schema.ModelFieldValidationInfo) -> str:
1636+
def fn(v: bytes, info: core_schema.FieldValidationInfo) -> str:
15361637
assert info.data is not None
15371638
assert info.field_name is not None
15381639
return v.decode() + 'world'
@@ -1624,7 +1725,7 @@ def field_after_validation_function(
16241725
```py
16251726
from pydantic_core import SchemaValidator, core_schema
16261727
1627-
def fn(v: str, info: core_schema.ModelFieldValidationInfo) -> str:
1728+
def fn(v: str, info: core_schema.FieldValidationInfo) -> str:
16281729
assert info.data is not None
16291730
assert info.field_name is not None
16301731
return v + 'world'
@@ -1709,7 +1810,7 @@ def __call__(
17091810

17101811
class FieldWrapValidatorFunction(Protocol):
17111812
def __call__(
1712-
self, __input_value: Any, __validator: CallableValidator, __info: ModelFieldValidationInfo
1813+
self, __input_value: Any, __validator: CallableValidator, __info: FieldValidationInfo
17131814
) -> Any: # pragma: no cover
17141815
...
17151816

@@ -1791,7 +1892,7 @@ def field_wrap_validation_function(
17911892
```py
17921893
from pydantic_core import SchemaValidator, core_schema
17931894
1794-
def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.ModelFieldValidationInfo) -> str:
1895+
def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.FieldValidationInfo) -> str:
17951896
assert info.data is not None
17961897
assert info.field_name is not None
17971898
return validator(v) + 'world'
@@ -1881,7 +1982,7 @@ def field_plain_validation_function(
18811982
from typing import Any
18821983
from pydantic_core import SchemaValidator, core_schema
18831984
1884-
def fn(v: Any, info: core_schema.ModelFieldValidationInfo) -> str:
1985+
def fn(v: Any, info: core_schema.FieldValidationInfo) -> str:
18851986
assert info.data is not None
18861987
assert info.field_name is not None
18871988
return str(v) + 'world'

src/serializers/extra.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ pub(crate) struct Extra<'a> {
3232
pub rec_guard: &'a SerRecursionGuard,
3333
// the next two are used for union logic
3434
pub check: SerCheck,
35+
// data representing the current model field
36+
// that is being serialized, if this is a model serializer
37+
// it will be None otherwise
38+
pub model: Option<&'a PyAny>,
39+
pub field_name: Option<&'a str>,
3540
}
3641

3742
impl<'a> Extra<'a> {
@@ -62,6 +67,8 @@ impl<'a> Extra<'a> {
6267
config,
6368
rec_guard,
6469
check: SerCheck::None,
70+
model: None,
71+
field_name: None,
6572
}
6673
}
6774
}
@@ -97,6 +104,8 @@ pub(crate) struct ExtraOwned {
97104
config: SerializationConfig,
98105
rec_guard: SerRecursionGuard,
99106
check: SerCheck,
107+
model: Option<Py<PyAny>>,
108+
field_name: Option<String>,
100109
}
101110

102111
impl ExtraOwned {
@@ -113,6 +122,8 @@ impl ExtraOwned {
113122
config: extra.config.clone(),
114123
rec_guard: extra.rec_guard.clone(),
115124
check: extra.check,
125+
model: extra.model.map(|v| v.into()),
126+
field_name: extra.field_name.map(|v| v.to_string()),
116127
}
117128
}
118129

@@ -130,6 +141,8 @@ impl ExtraOwned {
130141
config: &self.config,
131142
rec_guard: &self.rec_guard,
132143
check: self.check,
144+
model: self.model.as_ref().map(|m| m.as_ref(py)),
145+
field_name: self.field_name.as_ref().map(|n| n.as_ref()),
133146
}
134147
}
135148
}

0 commit comments

Comments
 (0)