Skip to content

Support complex numbers #1331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal, complex):
return {'type': obj.__name__.lower()}
elif is_typeddict(obj):
return type_dict_schema(obj, definitions)
Expand Down
40 changes: 40 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,43 @@ def decimal_schema(
)


class ComplexSchema(TypedDict, total=False):
type: Required[Literal['complex']]
ref: str
metadata: Any
serialization: SerSchema


def complex_schema(
*,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> ComplexSchema:
"""
Returns a schema that matches a complex value, e.g.:

```py
from pydantic_core import SchemaValidator, core_schema

schema = core_schema.complex_schema()
v = SchemaValidator(schema)
assert v.validate_python({'real': 1, 'imag': 2}) == complex(1, 2)
```

Args:
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='complex',
ref=ref,
metadata=metadata,
serialization=serialization,
)


class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: Union[str, Pattern[str]]
Expand Down Expand Up @@ -3777,6 +3814,7 @@ def definition_reference_schema(
DefinitionsSchema,
DefinitionReferenceSchema,
UuidSchema,
ComplexSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3832,6 +3870,7 @@ def definition_reference_schema(
'definitions',
'definition-ref',
'uuid',
'complex',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down Expand Up @@ -3936,6 +3975,7 @@ def definition_reference_schema(
'decimal_max_digits',
'decimal_max_places',
'decimal_whole_digits',
'complex_type',
]


Expand Down
3 changes: 3 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ error_types! {
DecimalWholeDigits {
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
},
// Complex errors
ComplexType {},
}

macro_rules! render {
Expand Down Expand Up @@ -564,6 +566,7 @@ impl ErrorType {
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
Self::ComplexType { .. } => "Input should be a valid dictionary with exactly two keys, 'real' and 'imag', with float values",
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::py_err;

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString};
use super::{EitherFloat, GenericIterator, ValidationMatch};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -172,6 +172,8 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValMatch<EitherTimedelta<'py>>;

fn validate_complex(&self) -> ValMatch<EitherComplex<'py>>;
}

/// The problem to solve here is that iterating collections often returns owned
Expand Down
36 changes: 35 additions & 1 deletion src/input/input_json.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::collections::HashSet;

use jiter::{JsonArray, JsonObject, JsonValue, LazyIndexMap};
use pyo3::prelude::*;
Expand All @@ -16,7 +17,7 @@ use super::datetime::{
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
};
use super::input_abstract::{ConsumeIterator, Never, ValMatch};
use super::return_enums::ValidationMatch;
use super::return_enums::{EitherComplex, ValidationMatch};
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int};
use super::{
Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input,
Expand Down Expand Up @@ -296,6 +297,35 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
_ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

fn validate_complex(&self) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
let default = JsonValue::Float(0.0);
match self {
JsonValue::Object(object) => {
let mut allowed_keys = HashSet::from(["real".to_owned(), "imag".to_owned()]);
for key in object.keys() {
let k = &key.to_string();
if !allowed_keys.remove(k) {
return Err(ValError::new(ErrorTypeDefaults::ComplexType, self));
}
}
let real = object.get("real").unwrap_or(&default).validate_float(true);
let imag = object.get("imag").unwrap_or(&default).validate_float(true);
if let Ok(re) = real {
if let Ok(im) = imag {
return Ok(ValidationMatch::strict(EitherComplex::Complex([
re.into_inner().as_f64(),
im.into_inner().as_f64(),
])));
}
}
Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
}
JsonValue::Float(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([*f, 0.0]))),
JsonValue::Int(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([(*f) as f64, 0.0]))),
_ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
}
}
}

/// Required for JSON Object keys so the string can behave like an Input
Expand Down Expand Up @@ -425,6 +455,10 @@ impl<'py> Input<'py> for str {
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>> {
bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax)
}

fn validate_complex(&self) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
}
}

impl BorrowInput<'_> for &'_ String {
Expand Down
46 changes: 44 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use pyo3::prelude::*;

use pyo3::types::PyType;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySet, PyString, PyTime, PyTuple,
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
PyList, PyMapping, PySet, PyString, PyTime, PyTuple,
};

use pyo3::PyTypeCheck;
Expand All @@ -25,6 +25,7 @@ use super::datetime::{
EitherTime,
};
use super::input_abstract::ValMatch;
use super::return_enums::EitherComplex;
use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch};
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
Expand Down Expand Up @@ -592,6 +593,47 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {

Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self))
}

fn validate_complex<'a>(&'a self) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
if let Ok(complex) = self.downcast::<PyComplex>() {
return Ok(ValidationMatch::exact(EitherComplex::Py(complex.to_owned())));
} else if let Ok(complex) = self.downcast::<PyDict>() {
let re = complex.get_item("real");
let im = complex.get_item("imag");
if complex.len() > 2 || re.is_err() && im.is_err() {
return Err(ValError::new(ErrorTypeDefaults::ComplexType, self));
}
let mut res = [0.0, 0.0];
if let Some(v) = re.unwrap_or(None) {
if v.is_exact_instance_of::<PyFloat>() || v.is_exact_instance_of::<PyInt>() {
let u = v.extract::<f64>();
res[0] = u.unwrap();
} else {
return Err(ValError::new(ErrorTypeDefaults::ComplexType, self));
}
}
if let Some(v) = im.unwrap_or(None) {
if v.is_exact_instance_of::<PyFloat>() || v.is_exact_instance_of::<PyInt>() {
let u = v.extract::<f64>();
res[1] = u.unwrap();
} else {
return Err(ValError::new(ErrorTypeDefaults::ComplexType, self));
}
}
return Ok(ValidationMatch::exact(EitherComplex::Complex(res)));
} else if self.is_exact_instance_of::<PyFloat>() {
return Ok(ValidationMatch::exact(EitherComplex::Complex([
self.extract::<f64>().unwrap(),
0.0,
])));
} else if self.is_exact_instance_of::<PyInt>() {
return Ok(ValidationMatch::exact(EitherComplex::Complex([
self.extract::<i64>().unwrap() as f64,
0.0,
])));
}
Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
}
}

impl<'py> BorrowInput<'py> for Bound<'py, PyAny> {
Expand Down
10 changes: 9 additions & 1 deletion src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
};
use super::input_abstract::{Never, ValMatch};
use super::shared::{str_as_bool, str_as_float, str_as_int};
use super::return_enums::EitherComplex;
use super::shared::{str_as_bool, str_as_complex, str_as_float, str_as_int};
use super::{
Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input,
KeywordArgs, ValidatedDict, ValidationMatch,
Expand Down Expand Up @@ -217,6 +218,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

fn validate_complex(&self) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
match self {
Self::String(s) => str_as_complex(self, py_string_str(s)?).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
}
}
}

impl<'py> BorrowInput<'py> for StringMapping<'py> {
Expand Down
29 changes: 28 additions & 1 deletion src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use pyo3::intern;
use pyo3::prelude::*;
#[cfg(not(PyPy))]
use pyo3::types::PyFunction;
use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};
use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};

use serde::{ser::Error, Serialize, Serializer};

Expand Down Expand Up @@ -716,3 +716,30 @@ impl ToPyObject for Int {
}
}
}

#[derive(Clone)]
pub enum EitherComplex<'a> {
Complex([f64; 2]),
Py(Bound<'a, PyComplex>),
}

impl<'a> IntoPy<PyObject> for EitherComplex<'a> {
fn into_py(self, py: Python<'_>) -> PyObject {
match self {
Self::Complex(c) => PyComplex::from_doubles_bound(py, c[0], c[1]).into_py(py),
Self::Py(c) => c.into_py(py),
}
}
}

impl<'a> EitherComplex<'a> {
pub fn as_f64(&self, py: Python<'_>) -> [f64; 2] {
match self {
EitherComplex::Complex(f) => *f,
EitherComplex::Py(f) => [
f.getattr(intern!(py, "real")).unwrap().extract().unwrap(),
f.getattr(intern!(py, "imag")).unwrap().extract().unwrap(),
],
}
}
}
7 changes: 7 additions & 0 deletions src/input/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use jiter::{JsonErrorType, NumberInt};

use crate::errors::{ErrorTypeDefaults, ValError, ValResult};

use super::return_enums::EitherComplex;
use super::{EitherFloat, EitherInt, Input};
static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

Expand Down Expand Up @@ -204,3 +205,9 @@ pub fn decimal_as_int<'py>(
}
Ok(EitherInt::Py(numerator))
}

/// parse a complex as a complex
pub fn str_as_complex<'py>(input: &(impl Input<'py> + ?Sized), _str: &str) -> ValResult<EitherComplex<'py>> {
// TODO
Err(ValError::new(ErrorTypeDefaults::ComplexType, input))
}
24 changes: 23 additions & 1 deletion src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyComplex;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};

use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
Expand Down Expand Up @@ -226,6 +227,13 @@ pub(crate) fn infer_to_python_known(
}
PyList::new_bound(py, items).into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
}
ObType::Path => value.str()?.into_py(py),
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
ObType::Unknown => {
Expand Down Expand Up @@ -274,6 +282,13 @@ pub(crate) fn infer_to_python_known(
);
iter.into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
Expand Down Expand Up @@ -402,6 +417,13 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
ObType::None => serializer.serialize_none(),
ObType::Int | ObType::IntSubclass => serialize!(Int),
ObType::Bool => serialize!(bool),
ObType::Complex => {
let v = value.downcast::<PyComplex>().map_err(py_err_se_err)?;
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry(&"real", &v.real())?;
map.serialize_entry(&"imag", &v.imag())?;
map.end()
}
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>().map_err(py_err_se_err)?;
type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode)
Expand Down Expand Up @@ -647,7 +669,7 @@ pub(crate) fn infer_json_key_known<'a>(
}
Ok(Cow::Owned(key_build.finish()))
}
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => {
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
}
ObType::Dataclass | ObType::PydanticSerializable => {
Expand Down
Loading
Loading