Skip to content

refactor Input trait to have single as_python cast for python inputs #1241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 2 additions & 31 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::fmt;

use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyList, PyType};
use pyo3::types::{PyDict, PyList};
use pyo3::{intern, prelude::*};

use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
Expand Down Expand Up @@ -52,44 +51,16 @@ pub type ValMatch<T> = ValResult<ValidationMatch<T>>;
pub trait Input<'py>: fmt::Debug + ToPyObject {
fn as_error_value(&self) -> InputValue;

fn identity(&self) -> Option<usize> {
None
}

fn is_none(&self) -> bool {
false
}

fn input_is_instance(&self, _class: &Bound<'py, PyType>) -> Option<&Bound<'py, PyAny>> {
fn as_python(&self) -> Option<&Bound<'py, PyAny>> {
None
}

fn input_is_exact_instance(&self, _class: &Bound<'py, PyType>) -> bool {
false
}

fn is_python(&self) -> bool {
false
}

fn as_kwargs(&self, py: Python<'py>) -> Option<Bound<'py, PyDict>>;

fn input_is_subclass(&self, _class: &Bound<'_, PyType>) -> PyResult<bool> {
Ok(false)
}

fn input_as_url(&self) -> Option<PyUrl> {
None
}

fn input_as_multi_host_url(&self) -> Option<PyMultiHostUrl> {
None
}

fn callable(&self) -> bool {
false
}

type Arguments<'a>: Arguments<'py>
where
Self: 'a;
Expand Down
56 changes: 17 additions & 39 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@ use std::str::from_utf8;
use pyo3::intern;
use pyo3::prelude::*;

use pyo3::types::PyType;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySet, PyString, PyTime, PyTuple, PyType,
PyMapping, PySet, PyString, PyTime, PyTuple,
};

use pyo3::PyTypeCheck;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};
use crate::ArgsKwargs;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
Expand All @@ -40,6 +42,17 @@ use super::{
Input,
};

pub(crate) fn downcast_python_input<'py, T: PyTypeCheck>(input: &(impl Input<'py> + ?Sized)) -> Option<&Bound<'py, T>> {
input.as_python().and_then(|any| any.downcast::<T>().ok())
}

pub(crate) fn input_as_python_instance<'a, 'py>(
input: &'a (impl Input<'py> + ?Sized),
class: &Bound<'py, PyType>,
) -> Option<&'a Bound<'py, PyAny>> {
input.as_python().filter(|any| any.is_instance(class).unwrap_or(false))
}

impl From<&Bound<'_, PyAny>> for LocItem {
fn from(py_any: &Bound<'_, PyAny>) -> Self {
if let Ok(py_str) = py_any.downcast::<PyString>() {
Expand All @@ -63,28 +76,12 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
InputValue::Python(self.clone().into())
}

fn identity(&self) -> Option<usize> {
Some(self.as_ptr() as usize)
}

fn is_none(&self) -> bool {
PyAnyMethods::is_none(self)
}

fn input_is_instance(&self, class: &Bound<'py, PyType>) -> Option<&Bound<'py, PyAny>> {
if self.is_instance(class).unwrap_or(false) {
Some(self)
} else {
None
}
}

fn input_is_exact_instance(&self, class: &Bound<'py, PyType>) -> bool {
self.is_exact_instance(class)
}

fn is_python(&self) -> bool {
true
fn as_python(&self) -> Option<&Bound<'py, PyAny>> {
Some(self)
}

fn as_kwargs(&self, py: Python<'py>) -> Option<Bound<'py, PyDict>> {
Expand All @@ -93,25 +90,6 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
.map(|dict| dict.to_owned().unbind().into_bound(py))
}

fn input_is_subclass(&self, class: &Bound<'_, PyType>) -> PyResult<bool> {
match self.downcast::<PyType>() {
Ok(py_type) => py_type.is_subclass(class),
Err(_) => Ok(false),
}
}

fn input_as_url(&self) -> Option<PyUrl> {
self.extract::<PyUrl>().ok()
}

fn input_as_multi_host_url(&self) -> Option<PyMultiHostUrl> {
self.extract::<PyMultiHostUrl>().ok()
}

fn callable(&self) -> bool {
self.is_callable()
}

type Arguments<'a> = PyArgs<'py> where Self: 'a;

fn validate_args(&self) -> ValResult<PyArgs<'py>> {
Expand Down
1 change: 1 addition & 0 deletions src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub(crate) use input_abstract::{
Arguments, BorrowInput, ConsumeIterator, Input, InputType, KeywordArgs, PositionalArgs, ValidatedDict,
ValidatedList, ValidatedSet, ValidatedTuple,
};
pub(crate) use input_python::{downcast_python_input, input_as_python_instance};
pub(crate) use input_string::StringMapping;
pub(crate) use return_enums::{
no_validator_iter_to_vec, py_string_str, validate_iter_to_set, validate_iter_to_vec, EitherBytes, EitherFloat,
Expand Down
21 changes: 12 additions & 9 deletions src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use crate::SchemaValidator;

static SCHEMA_DEFINITION_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();

#[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass)]
#[derive(Clone)]
#[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass, frozen)]
#[derive(Clone, Hash)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct PyUrl {
lib_url: Url,
Expand All @@ -28,8 +28,8 @@ impl PyUrl {
Self { lib_url }
}

pub fn into_url(self) -> Url {
self.lib_url
pub fn url(&self) -> &Url {
&self.lib_url
}
}

Expand Down Expand Up @@ -138,7 +138,7 @@ impl PyUrl {

fn __hash__(&self) -> u64 {
let mut s = DefaultHasher::new();
self.lib_url.to_string().hash(&mut s);
self.hash(&mut s);
s.finish()
}

Expand Down Expand Up @@ -192,8 +192,8 @@ impl PyUrl {
}
}

#[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core", subclass)]
#[derive(Clone)]
#[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core", subclass, frozen)]
#[derive(Clone, Hash)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct PyMultiHostUrl {
ref_url: PyUrl,
Expand All @@ -208,6 +208,10 @@ impl PyMultiHostUrl {
}
}

pub fn lib_url(&self) -> &Url {
&self.ref_url.lib_url
}

pub fn mut_lib_url(&mut self) -> &mut Url {
&mut self.ref_url.lib_url
}
Expand Down Expand Up @@ -338,8 +342,7 @@ impl PyMultiHostUrl {

fn __hash__(&self) -> u64 {
let mut s = DefaultHasher::new();
self.ref_url.clone().into_url().to_string().hash(&mut s);
self.extra_urls.hash(&mut s);
self.hash(&mut s);
s.finish()
}

Expand Down
2 changes: 1 addition & 1 deletion src/validators/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Validator for CallableValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
state.floor_exactness(Exactness::Lax);
match input.callable() {
match input.as_python().is_some_and(PyAnyMethods::is_callable) {
true => Ok(input.to_object(py)),
false => Err(ValError::new(ErrorTypeDefaults::CallableType, input)),
}
Expand Down
7 changes: 4 additions & 3 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use ahash::AHashSet;
use crate::build_tools::py_schema_err;
use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior};
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::InputType;
use crate::input::{Arguments, BorrowInput, Input, KeywordArgs, PositionalArgs, ValidationMatch};
use crate::input::{
input_as_python_instance, Arguments, BorrowInput, Input, InputType, KeywordArgs, PositionalArgs, ValidationMatch,
};
use crate::lookup_key::LookupKey;
use crate::tools::SchemaDict;
use crate::validators::function::convert_err;
Expand Down Expand Up @@ -501,7 +502,7 @@ impl Validator for DataclassValidator {

// same logic as on models
let class = self.class.bind(py);
if let Some(py_input) = input.input_is_instance(class) {
if let Some(py_input) = input_as_python_instance(input, class) {
if self.revalidate.should_revalidate(py_input, class) {
let input_dict = self.dataclass_to_dict(py_input)?;
let val_output = self.validator.validate(py, input_dict.as_any(), state)?;
Expand Down
19 changes: 10 additions & 9 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ impl Validator for DefinitionRefValidator {
) -> ValResult<PyObject> {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = input.identity() {
if let Some(id) = input.as_python().map(py_identity) {
// Python objects can be cyclic, so need recursion guard
let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
};
Expand All @@ -98,18 +99,18 @@ impl Validator for DefinitionRefValidator {
) -> ValResult<PyObject> {
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = obj.identity() {
let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
};
validator.validate_assignment(py, obj, field_name, field_value, guard.state())
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
}
let Ok(mut guard) = RecursionGuard::new(state, py_identity(obj), self.definition.id()) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
};
validator.validate_assignment(py, obj, field_name, field_value, guard.state())
})
}

fn get_name(&self) -> &str {
self.definition.get_or_init_name(|v| v.get_name().into())
}
}

fn py_identity(obj: &Bound<'_, PyAny>) -> usize {
obj.as_ptr() as usize
}
4 changes: 2 additions & 2 deletions src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let class = self.class.bind(py);
if input.input_is_exact_instance(class) {
if input.as_python().is_some_and(|any| any.is_exact_instance(class)) {
return Ok(input.to_object(py));
}
let strict = state.strict_or(self.strict);
if strict && input.is_python() {
if strict && input.as_python().is_some() {
// TODO what about instances of subclasses?
return Err(ValError::new(
ErrorType::IsInstanceOf {
Expand Down
10 changes: 4 additions & 6 deletions src/validators/is_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ impl Validator for IsInstanceValidator {
input: &(impl Input<'py> + ?Sized),
_state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
if !input.is_python() {
let Some(obj) = input.as_python() else {
return Err(ValError::InternalErr(PyNotImplementedError::new_err(
"Cannot check isinstance when validating from json, \
use a JsonOrPython validator instead.",
)));
}

let ob: Py<PyAny> = input.to_object(py);
match ob.bind(py).is_instance(self.class.bind(py))? {
true => Ok(ob),
};
match obj.is_instance(self.class.bind(py))? {
true => Ok(obj.clone().unbind()),
false => Err(ValError::new(
ErrorType::IsInstanceOf {
class: self.class_repr.clone(),
Expand Down
13 changes: 10 additions & 3 deletions src/validators/is_subclass.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use pyo3::exceptions::PyNotImplementedError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyType};
Expand Down Expand Up @@ -49,9 +50,15 @@ impl Validator for IsSubclassValidator {
input: &(impl Input<'py> + ?Sized),
_state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
match input.input_is_subclass(self.class.bind(py))? {
true => Ok(input.to_object(py)),
false => Err(ValError::new(
let Some(obj) = input.as_python() else {
return Err(ValError::InternalErr(PyNotImplementedError::new_err(
"Cannot check issubclass when validating from json, \
use a JsonOrPython validator instead.",
)));
};
match obj.downcast::<PyType>() {
Ok(py_type) if py_type.is_subclass(self.class.bind(py))? => Ok(obj.clone().unbind()),
_ => Err(ValError::new(
ErrorType::IsSubclassOf {
class: self.class_repr.clone(),
context: None,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl<T: Debug> LiteralLookup<T> {
}
}
if let Some(expected_strings) = &self.expected_str {
let validation_result = if input.is_python() {
let validation_result = if input.as_python().is_some() {
input.exact_str()
} else {
// Strings coming from JSON are treated as "strict" but not "exact" for reasons
Expand Down
4 changes: 2 additions & 2 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{
use crate::build_tools::py_schema_err;
use crate::build_tools::schema_or_config_same;
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
use crate::input::{py_error_on_minusone, Input};
use crate::input::{input_as_python_instance, py_error_on_minusone, Input};
use crate::tools::{py_err, SchemaDict};
use crate::PydanticUndefinedType;

Expand Down Expand Up @@ -124,7 +124,7 @@ impl Validator for ModelValidator {
// if the input is an instance of the class, we "revalidate" it - e.g. we extract and reuse `__pydantic_fields_set__`
// but use from attributes to create a new instance of the model field type
let class = self.class.bind(py);
if let Some(py_input) = input.input_is_instance(class) {
if let Some(py_input) = input_as_python_instance(input, class) {
if self.revalidate.should_revalidate(py_input, class) {
let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?;
if self.root_model {
Expand Down
Loading