Skip to content

Support complex numbers #1331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal, complex):
return {'type': obj.__name__.lower()}
elif is_typeddict(obj):
return type_dict_schema(obj, definitions)
Expand Down
46 changes: 46 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,48 @@ def decimal_schema(
)


class ComplexSchema(TypedDict, total=False):
type: Required[Literal['complex']]
strict: bool
ref: str
metadata: Any
serialization: SerSchema


def complex_schema(
*,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> ComplexSchema:
"""
Returns a schema that matches a complex value, e.g.:

```py
from pydantic_core import SchemaValidator, core_schema

schema = core_schema.complex_schema()
v = SchemaValidator(schema)
assert v.validate_python('1+2j') == complex(1, 2)
assert v.validate_python(complex(1, 2)) == complex(1, 2)
```

Args:
strict: Whether the value should be a complex object instance or a value that can be converted to a complex object
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='complex',
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: Union[str, Pattern[str]]
Expand Down Expand Up @@ -3796,6 +3838,7 @@ def definition_reference_schema(
DefinitionsSchema,
DefinitionReferenceSchema,
UuidSchema,
ComplexSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3851,6 +3894,7 @@ def definition_reference_schema(
'definitions',
'definition-ref',
'uuid',
'complex',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down Expand Up @@ -3956,6 +4000,8 @@ def definition_reference_schema(
'decimal_max_digits',
'decimal_max_places',
'decimal_whole_digits',
'complex_type',
'complex_str_parsing',
]


Expand Down
5 changes: 5 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ error_types! {
DecimalWholeDigits {
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
},
// Complex errors
ComplexType {},
ComplexStrParsing {},
}

macro_rules! render {
Expand Down Expand Up @@ -569,6 +572,8 @@ impl ErrorType {
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::tools::py_err;
use crate::validators::ValBytesMode;

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString};
use super::{EitherFloat, GenericIterator, ValidationMatch};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -173,6 +173,8 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValMatch<EitherTimedelta<'py>>;

fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValMatch<EitherComplex<'py>>;
}

/// The problem to solve here is that iterating collections often returns owned
Expand Down
33 changes: 33 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::input::return_enums::EitherComplex;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

Expand Down Expand Up @@ -304,6 +306,30 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
_ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
match self {
JsonValue::Str(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(
&PyString::new_bound(py, s),
self,
)?))),
JsonValue::Float(f) => {
if !strict {
Ok(ValidationMatch::lax(EitherComplex::Complex([*f, 0.0])))
} else {
Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self))
}
}
JsonValue::Int(f) => {
if !strict {
Ok(ValidationMatch::lax(EitherComplex::Complex([(*f) as f64, 0.0])))
} else {
Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self))
}
}
_ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
}
}
}

/// Required for JSON Object keys so the string can behave like an Input
Expand Down Expand Up @@ -440,6 +466,13 @@ impl<'py> Input<'py> for str {
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>> {
bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax)
}

fn validate_complex(&self, _strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(
self.to_object(py).downcast_bound::<PyString>(py)?,
self,
)?)))
}
}

impl BorrowInput<'_> for &'_ String {
Expand Down
46 changes: 44 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ use pyo3::prelude::*;

use pyo3::types::PyType;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySet, PyString, PyTime, PyTuple,
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
PyList, PyMapping, PySet, PyString, PyTime, PyTuple,
};

use pyo3::PyTypeCheck;
use pyo3::PyTypeInfo;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::validators::ValBytesMode;
Expand All @@ -25,6 +27,7 @@ use super::datetime::{
EitherTime,
};
use super::input_abstract::ValMatch;
use super::return_enums::EitherComplex;
use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch};
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
Expand Down Expand Up @@ -598,6 +601,45 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {

Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self))
}

fn validate_complex<'a>(&'a self, strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
if let Ok(complex) = self.downcast::<PyComplex>() {
return Ok(ValidationMatch::strict(EitherComplex::Py(complex.to_owned())));
}
if strict {
return Err(ValError::new(
ErrorType::IsInstanceOf {
class: PyComplex::type_object_bound(py)
.qualname()
.and_then(|name| name.extract())
.unwrap_or_else(|_| "complex".to_owned()),
context: None,
},
self,
));
}

if let Ok(s) = self.downcast::<PyString>() {
// If input is not a valid complex string, instead of telling users to correct
// the string, it makes more sense to tell them to provide any acceptable value
// since they might have just given values of some incorrect types instead
// of actually trying some complex strings.
if let Ok(c) = string_to_complex(s, self) {
return Ok(ValidationMatch::lax(EitherComplex::Py(c)));
}
} else if self.is_exact_instance_of::<PyFloat>() {
return Ok(ValidationMatch::lax(EitherComplex::Complex([
self.extract::<f64>().unwrap(),
0.0,
])));
} else if self.is_exact_instance_of::<PyInt>() {
return Ok(ValidationMatch::lax(EitherComplex::Complex([
self.extract::<i64>().unwrap() as f64,
0.0,
])));
}
Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
}
}

impl<'py> BorrowInput<'py> for Bound<'py, PyAny> {
Expand Down
9 changes: 9 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}
use crate::input::py_string_str;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
};
use super::input_abstract::{Never, ValMatch};
use super::return_enums::EitherComplex;
use super::shared::{str_as_bool, str_as_float, str_as_int};
use super::{
Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input,
Expand Down Expand Up @@ -225,6 +227,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

fn validate_complex(&self, _strict: bool, _py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
match self {
Self::String(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(s, self)?))),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
}
}
}

impl<'py> BorrowInput<'py> for StringMapping<'py> {
Expand Down
29 changes: 28 additions & 1 deletion src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use pyo3::intern;
use pyo3::prelude::*;
#[cfg(not(PyPy))]
use pyo3::types::PyFunction;
use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};
use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};

use serde::{ser::Error, Serialize, Serializer};

Expand Down Expand Up @@ -724,3 +724,30 @@ impl ToPyObject for Int {
}
}
}

#[derive(Clone)]
pub enum EitherComplex<'a> {
Complex([f64; 2]),
Py(Bound<'a, PyComplex>),
}

impl<'a> IntoPy<PyObject> for EitherComplex<'a> {
fn into_py(self, py: Python<'_>) -> PyObject {
match self {
Self::Complex(c) => PyComplex::from_doubles_bound(py, c[0], c[1]).into_py(py),
Self::Py(c) => c.into_py(py),
}
}
}

impl<'a> EitherComplex<'a> {
pub fn as_f64(&self, py: Python<'_>) -> [f64; 2] {
match self {
EitherComplex::Complex(f) => *f,
EitherComplex::Py(f) => [
f.getattr(intern!(py, "real")).unwrap().extract().unwrap(),
f.getattr(intern!(py, "imag")).unwrap().extract().unwrap(),
],
}
}
}
24 changes: 23 additions & 1 deletion src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyComplex;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};

use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
Expand Down Expand Up @@ -226,6 +227,13 @@ pub(crate) fn infer_to_python_known(
}
PyList::new_bound(py, items).into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
}
ObType::Path => value.str()?.into_py(py),
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
ObType::Unknown => {
Expand Down Expand Up @@ -274,6 +282,13 @@ pub(crate) fn infer_to_python_known(
);
iter.into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
Expand Down Expand Up @@ -402,6 +417,13 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
ObType::None => serializer.serialize_none(),
ObType::Int | ObType::IntSubclass => serialize!(Int),
ObType::Bool => serialize!(bool),
ObType::Complex => {
let v = value.downcast::<PyComplex>().map_err(py_err_se_err)?;
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry(&"real", &v.real())?;
map.serialize_entry(&"imag", &v.imag())?;
map.end()
}
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>().map_err(py_err_se_err)?;
type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode)
Expand Down Expand Up @@ -647,7 +669,7 @@ pub(crate) fn infer_json_key_known<'a>(
}
Ok(Cow::Owned(key_build.finish()))
}
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => {
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
}
ObType::Dataclass | ObType::PydanticSerializable => {
Expand Down
Loading
Loading