Skip to content

Allow serializers to accept a model instance #437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 120 additions & 19 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def __repr__(self) -> str:
...


class FieldSerializationInfo(SerializationInfo, Protocol):
@property
def field_name(self) -> str:
...


class ValidationInfo(Protocol):
"""
Argument passed to validation functions.
Expand All @@ -109,7 +115,7 @@ def config(self) -> CoreConfig | None:
...


class ModelFieldValidationInfo(ValidationInfo, Protocol):
class FieldValidationInfo(ValidationInfo, Protocol):
"""
Argument passed to model field validation functions.
"""
Expand Down Expand Up @@ -166,11 +172,26 @@ def simple_ser_schema(type: ExpectedSerializationTypes) -> SimpleSerSchema:
return SimpleSerSchema(type=type)


class SerializePlainFunction(Protocol): # pragma: no cover
class GeneralSerializePlainFunction(Protocol): # pragma: no cover
def __call__(self, __input_value: Any, __info: SerializationInfo) -> Any:
...


class FieldSerializePlainFunction(Protocol): # pragma: no cover
def __call__(self, __model: Any, __input_value: Any, __info: FieldSerializationInfo) -> Any:
...


class GeneralSerializePlainFunctionSchema(TypedDict):
type: Literal['general']
function: GeneralSerializePlainFunction


class FieldSerializePlainFunctionSchema(TypedDict):
type: Literal['field']
function: FieldSerializePlainFunction


# must match `src/serializers/ob_type.rs::ObType`
JsonReturnTypes = Literal[
'int',
Expand Down Expand Up @@ -212,13 +233,16 @@ def __call__(self, __input_value: Any, __info: SerializationInfo) -> Any:

class FunctionPlainSerSchema(TypedDict, total=False):
type: Required[Literal['function-plain']]
function: Required[SerializePlainFunction]
function: Required[Union[GeneralSerializePlainFunctionSchema, FieldSerializePlainFunctionSchema]]
json_return_type: JsonReturnTypes
when_used: WhenUsed # default: 'always'


def function_plain_ser_schema(
function: SerializePlainFunction, *, json_return_type: JsonReturnTypes | None = None, when_used: WhenUsed = 'always'
def general_function_plain_ser_schema(
function: GeneralSerializePlainFunction,
*,
json_return_type: JsonReturnTypes | None = None,
when_used: WhenUsed = 'always',
) -> FunctionPlainSerSchema:
"""
Returns a schema for serialization with a function.
Expand All @@ -232,7 +256,35 @@ def function_plain_ser_schema(
# just to avoid extra elements in schema, and to use the actual default defined in rust
when_used = None # type: ignore
return dict_not_none(
type='function-plain', function=function, json_return_type=json_return_type, when_used=when_used
type='function-plain',
function={'type': 'general', 'function': function},
json_return_type=json_return_type,
when_used=when_used,
)


def field_function_plain_ser_schema(
function: FieldSerializePlainFunction,
*,
json_return_type: JsonReturnTypes | None = None,
when_used: WhenUsed = 'always',
) -> FunctionPlainSerSchema:
"""
Returns a schema to serialize a field from a model, TypedDict or dataclass.

Args:
function: The function to use for serialization
json_return_type: The type that the function returns if `mode='json'`
when_used: When the function should be called
"""
if when_used == 'always':
# just to avoid extra elements in schema, and to use the actual default defined in rust
when_used = None # type: ignore
return dict_not_none(
type='function-plain',
function={'type': 'field', 'function': function},
json_return_type=json_return_type,
when_used=when_used,
)


Expand All @@ -241,21 +293,38 @@ def __call__(self, __input_value: Any, __index_key: int | str | None = None) ->
...


class SerializeWrapFunction(Protocol): # pragma: no cover
class GeneralSerializeWrapFunction(Protocol): # pragma: no cover
def __call__(self, __input_value: Any, __serializer: SerializeWrapHandler, __info: SerializationInfo) -> Any:
...


class FieldSerializeWrapFunction(Protocol): # pragma: no cover
def __call__(
self, __model: Any, __input_value: Any, __serializer: SerializeWrapHandler, __info: FieldSerializationInfo
) -> Any:
...


class GeneralSerializeWrapFunctionSchema(TypedDict):
type: Literal['general']
function: GeneralSerializeWrapFunction


class FieldSerializeWrapFunctionSchema(TypedDict):
type: Literal['field']
function: FieldSerializeWrapFunction


class FunctionWrapSerSchema(TypedDict, total=False):
type: Required[Literal['function-wrap']]
function: Required[SerializeWrapFunction]
function: Required[Union[GeneralSerializeWrapFunctionSchema, FieldSerializeWrapFunctionSchema]]
schema: Required[CoreSchema]
json_return_type: JsonReturnTypes
when_used: WhenUsed # default: 'always'


def function_wrap_ser_schema(
function: SerializeWrapFunction,
def general_function_wrap_ser_schema(
function: GeneralSerializeWrapFunction,
schema: CoreSchema,
*,
json_return_type: JsonReturnTypes | None = None,
Expand All @@ -274,7 +343,39 @@ def function_wrap_ser_schema(
# just to avoid extra elements in schema, and to use the actual default defined in rust
when_used = None # type: ignore
return dict_not_none(
type='function-wrap', schema=schema, function=function, json_return_type=json_return_type, when_used=when_used
type='function-wrap',
schema=schema,
function={'type': 'general', 'function': function},
json_return_type=json_return_type,
when_used=when_used,
)


def field_function_wrap_ser_schema(
function: FieldSerializeWrapFunction,
schema: CoreSchema,
*,
json_return_type: JsonReturnTypes | None = None,
when_used: WhenUsed = 'always',
) -> FunctionWrapSerSchema:
"""
Returns a schema to serialize a field from a model, TypedDict or dataclass.

Args:
function: The function to use for serialization
schema: The schema to use for the inner serialization
json_return_type: The type that the function returns if `mode='json'`
when_used: When the function should be called
"""
if when_used == 'always':
# just to avoid extra elements in schema, and to use the actual default defined in rust
when_used = None # type: ignore
return dict_not_none(
type='function-wrap',
schema=schema,
function={'type': 'field', 'function': function},
json_return_type=json_return_type,
when_used=when_used,
)


Expand All @@ -290,7 +391,7 @@ def format_ser_schema(formatting_string: str, *, when_used: WhenUsed = 'json-unl

Args:
formatting_string: String defining the format to use
when_used: Same meaning as for [function_plain_ser_schema], but with a different default
when_used: Same meaning as for [general_function_plain_ser_schema], but with a different default
"""
if when_used == 'json-unless-none':
# just to avoid extra elements in schema, and to use the actual default defined in rust
Expand All @@ -308,7 +409,7 @@ def to_string_ser_schema(*, when_used: WhenUsed = 'json-unless-none') -> ToStrin
Returns a schema for serialization using python's `str()` / `__str__` method.

Args:
when_used: Same meaning as for [function_plain_ser_schema], but with a different default
when_used: Same meaning as for [general_function_plain_ser_schema], but with a different default
"""
s = dict(type='to-string')
if when_used != 'json-unless-none':
Expand Down Expand Up @@ -1491,7 +1592,7 @@ def __call__(self, __input_value: Any, __info: ValidationInfo) -> Any: # pragma


class FieldValidatorFunction(Protocol):
def __call__(self, __input_value: Any, __info: ModelFieldValidationInfo) -> Any: # pragma: no cover
def __call__(self, __input_value: Any, __info: FieldValidationInfo) -> Any: # pragma: no cover
...


Expand Down Expand Up @@ -1532,7 +1633,7 @@ def field_before_validation_function(
```py
from pydantic_core import SchemaValidator, core_schema

def fn(v: bytes, info: core_schema.ModelFieldValidationInfo) -> str:
def fn(v: bytes, info: core_schema.FieldValidationInfo) -> str:
assert info.data is not None
assert info.field_name is not None
return v.decode() + 'world'
Expand Down Expand Up @@ -1624,7 +1725,7 @@ def field_after_validation_function(
```py
from pydantic_core import SchemaValidator, core_schema

def fn(v: str, info: core_schema.ModelFieldValidationInfo) -> str:
def fn(v: str, info: core_schema.FieldValidationInfo) -> str:
assert info.data is not None
assert info.field_name is not None
return v + 'world'
Expand Down Expand Up @@ -1709,7 +1810,7 @@ def __call__(

class FieldWrapValidatorFunction(Protocol):
def __call__(
self, __input_value: Any, __validator: CallableValidator, __info: ModelFieldValidationInfo
self, __input_value: Any, __validator: CallableValidator, __info: FieldValidationInfo
) -> Any: # pragma: no cover
...

Expand Down Expand Up @@ -1791,7 +1892,7 @@ def field_wrap_validation_function(
```py
from pydantic_core import SchemaValidator, core_schema

def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.ModelFieldValidationInfo) -> str:
def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.FieldValidationInfo) -> str:
assert info.data is not None
assert info.field_name is not None
return validator(v) + 'world'
Expand Down Expand Up @@ -1881,7 +1982,7 @@ def field_plain_validation_function(
from typing import Any
from pydantic_core import SchemaValidator, core_schema

def fn(v: Any, info: core_schema.ModelFieldValidationInfo) -> str:
def fn(v: Any, info: core_schema.FieldValidationInfo) -> str:
assert info.data is not None
assert info.field_name is not None
return str(v) + 'world'
Expand Down
13 changes: 13 additions & 0 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ pub(crate) struct Extra<'a> {
pub rec_guard: &'a SerRecursionGuard,
// the next two are used for union logic
pub check: SerCheck,
// data representing the current model field
// that is being serialized, if this is a model serializer
// it will be None otherwise
pub model: Option<&'a PyAny>,
pub field_name: Option<&'a str>,
}

impl<'a> Extra<'a> {
Expand Down Expand Up @@ -62,6 +67,8 @@ impl<'a> Extra<'a> {
config,
rec_guard,
check: SerCheck::None,
model: None,
field_name: None,
}
}
}
Expand Down Expand Up @@ -97,6 +104,8 @@ pub(crate) struct ExtraOwned {
config: SerializationConfig,
rec_guard: SerRecursionGuard,
check: SerCheck,
model: Option<Py<PyAny>>,
field_name: Option<String>,
}

impl ExtraOwned {
Expand All @@ -113,6 +122,8 @@ impl ExtraOwned {
config: extra.config.clone(),
rec_guard: extra.rec_guard.clone(),
check: extra.check,
model: extra.model.map(|v| v.into()),
field_name: extra.field_name.map(|v| v.to_string()),
}
}

Expand All @@ -130,6 +141,8 @@ impl ExtraOwned {
config: &self.config,
rec_guard: &self.rec_guard,
check: self.check,
model: self.model.as_ref().map(|m| m.as_ref(py)),
field_name: self.field_name.as_ref().map(|n| n.as_ref()),
}
}
}
Expand Down
Loading