Skip to content

Commit ebf55d6

Browse files
authored
refactor Input trait to have single as_python cast for python inputs (#1241)
1 parent 6399592 commit ebf55d6

File tree

15 files changed

+178
-142
lines changed

15 files changed

+178
-142
lines changed

src/input/input_abstract.rs

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use std::fmt;
22

33
use pyo3::exceptions::PyValueError;
4-
use pyo3::types::{PyDict, PyList, PyType};
4+
use pyo3::types::{PyDict, PyList};
55
use pyo3::{intern, prelude::*};
66

77
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
88
use crate::lookup_key::{LookupKey, LookupPath};
99
use crate::tools::py_err;
10-
use crate::{PyMultiHostUrl, PyUrl};
1110

1211
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1312
use super::return_enums::{EitherBytes, EitherInt, EitherString};
@@ -52,44 +51,16 @@ pub type ValMatch<T> = ValResult<ValidationMatch<T>>;
5251
pub trait Input<'py>: fmt::Debug + ToPyObject {
5352
fn as_error_value(&self) -> InputValue;
5453

55-
fn identity(&self) -> Option<usize> {
56-
None
57-
}
58-
5954
fn is_none(&self) -> bool {
6055
false
6156
}
6257

63-
fn input_is_instance(&self, _class: &Bound<'py, PyType>) -> Option<&Bound<'py, PyAny>> {
58+
fn as_python(&self) -> Option<&Bound<'py, PyAny>> {
6459
None
6560
}
6661

67-
fn input_is_exact_instance(&self, _class: &Bound<'py, PyType>) -> bool {
68-
false
69-
}
70-
71-
fn is_python(&self) -> bool {
72-
false
73-
}
74-
7562
fn as_kwargs(&self, py: Python<'py>) -> Option<Bound<'py, PyDict>>;
7663

77-
fn input_is_subclass(&self, _class: &Bound<'_, PyType>) -> PyResult<bool> {
78-
Ok(false)
79-
}
80-
81-
fn input_as_url(&self) -> Option<PyUrl> {
82-
None
83-
}
84-
85-
fn input_as_multi_host_url(&self) -> Option<PyMultiHostUrl> {
86-
None
87-
}
88-
89-
fn callable(&self) -> bool {
90-
false
91-
}
92-
9364
type Arguments<'a>: Arguments<'py>
9465
where
9566
Self: 'a;

src/input/input_python.rs

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@ use std::str::from_utf8;
44
use pyo3::intern;
55
use pyo3::prelude::*;
66

7+
use pyo3::types::PyType;
78
use pyo3::types::{
89
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
9-
PyMapping, PySet, PyString, PyTime, PyTuple, PyType,
10+
PyMapping, PySet, PyString, PyTime, PyTuple,
1011
};
1112

13+
use pyo3::PyTypeCheck;
1214
use speedate::MicrosecondsPrecisionOverflowBehavior;
1315

1416
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
1517
use crate::tools::{extract_i64, safe_repr};
1618
use crate::validators::decimal::{create_decimal, get_decimal_type};
1719
use crate::validators::Exactness;
18-
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};
20+
use crate::ArgsKwargs;
1921

2022
use super::datetime::{
2123
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
@@ -40,6 +42,17 @@ use super::{
4042
Input,
4143
};
4244

45+
pub(crate) fn downcast_python_input<'py, T: PyTypeCheck>(input: &(impl Input<'py> + ?Sized)) -> Option<&Bound<'py, T>> {
46+
input.as_python().and_then(|any| any.downcast::<T>().ok())
47+
}
48+
49+
pub(crate) fn input_as_python_instance<'a, 'py>(
50+
input: &'a (impl Input<'py> + ?Sized),
51+
class: &Bound<'py, PyType>,
52+
) -> Option<&'a Bound<'py, PyAny>> {
53+
input.as_python().filter(|any| any.is_instance(class).unwrap_or(false))
54+
}
55+
4356
impl From<&Bound<'_, PyAny>> for LocItem {
4457
fn from(py_any: &Bound<'_, PyAny>) -> Self {
4558
if let Ok(py_str) = py_any.downcast::<PyString>() {
@@ -63,28 +76,12 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
6376
InputValue::Python(self.clone().into())
6477
}
6578

66-
fn identity(&self) -> Option<usize> {
67-
Some(self.as_ptr() as usize)
68-
}
69-
7079
fn is_none(&self) -> bool {
7180
PyAnyMethods::is_none(self)
7281
}
7382

74-
fn input_is_instance(&self, class: &Bound<'py, PyType>) -> Option<&Bound<'py, PyAny>> {
75-
if self.is_instance(class).unwrap_or(false) {
76-
Some(self)
77-
} else {
78-
None
79-
}
80-
}
81-
82-
fn input_is_exact_instance(&self, class: &Bound<'py, PyType>) -> bool {
83-
self.is_exact_instance(class)
84-
}
85-
86-
fn is_python(&self) -> bool {
87-
true
83+
fn as_python(&self) -> Option<&Bound<'py, PyAny>> {
84+
Some(self)
8885
}
8986

9087
fn as_kwargs(&self, py: Python<'py>) -> Option<Bound<'py, PyDict>> {
@@ -93,25 +90,6 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
9390
.map(|dict| dict.to_owned().unbind().into_bound(py))
9491
}
9592

96-
fn input_is_subclass(&self, class: &Bound<'_, PyType>) -> PyResult<bool> {
97-
match self.downcast::<PyType>() {
98-
Ok(py_type) => py_type.is_subclass(class),
99-
Err(_) => Ok(false),
100-
}
101-
}
102-
103-
fn input_as_url(&self) -> Option<PyUrl> {
104-
self.extract::<PyUrl>().ok()
105-
}
106-
107-
fn input_as_multi_host_url(&self) -> Option<PyMultiHostUrl> {
108-
self.extract::<PyMultiHostUrl>().ok()
109-
}
110-
111-
fn callable(&self) -> bool {
112-
self.is_callable()
113-
}
114-
11593
type Arguments<'a> = PyArgs<'py> where Self: 'a;
11694

11795
fn validate_args(&self) -> ValResult<PyArgs<'py>> {

src/input/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub(crate) use input_abstract::{
1919
Arguments, BorrowInput, ConsumeIterator, Input, InputType, KeywordArgs, PositionalArgs, ValidatedDict,
2020
ValidatedList, ValidatedSet, ValidatedTuple,
2121
};
22+
pub(crate) use input_python::{downcast_python_input, input_as_python_instance};
2223
pub(crate) use input_string::StringMapping;
2324
pub(crate) use return_enums::{
2425
no_validator_iter_to_vec, py_string_str, validate_iter_to_set, validate_iter_to_vec, EitherBytes, EitherFloat,

src/url.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ use crate::SchemaValidator;
1616

1717
static SCHEMA_DEFINITION_URL: GILOnceCell<SchemaValidator> = GILOnceCell::new();
1818

19-
#[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass)]
20-
#[derive(Clone)]
19+
#[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass, frozen)]
20+
#[derive(Clone, Hash)]
2121
#[cfg_attr(debug_assertions, derive(Debug))]
2222
pub struct PyUrl {
2323
lib_url: Url,
@@ -28,8 +28,8 @@ impl PyUrl {
2828
Self { lib_url }
2929
}
3030

31-
pub fn into_url(self) -> Url {
32-
self.lib_url
31+
pub fn url(&self) -> &Url {
32+
&self.lib_url
3333
}
3434
}
3535

@@ -138,7 +138,7 @@ impl PyUrl {
138138

139139
fn __hash__(&self) -> u64 {
140140
let mut s = DefaultHasher::new();
141-
self.lib_url.to_string().hash(&mut s);
141+
self.hash(&mut s);
142142
s.finish()
143143
}
144144

@@ -192,8 +192,8 @@ impl PyUrl {
192192
}
193193
}
194194

195-
#[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core", subclass)]
196-
#[derive(Clone)]
195+
#[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core", subclass, frozen)]
196+
#[derive(Clone, Hash)]
197197
#[cfg_attr(debug_assertions, derive(Debug))]
198198
pub struct PyMultiHostUrl {
199199
ref_url: PyUrl,
@@ -208,6 +208,10 @@ impl PyMultiHostUrl {
208208
}
209209
}
210210

211+
pub fn lib_url(&self) -> &Url {
212+
&self.ref_url.lib_url
213+
}
214+
211215
pub fn mut_lib_url(&mut self) -> &mut Url {
212216
&mut self.ref_url.lib_url
213217
}
@@ -338,8 +342,7 @@ impl PyMultiHostUrl {
338342

339343
fn __hash__(&self) -> u64 {
340344
let mut s = DefaultHasher::new();
341-
self.ref_url.clone().into_url().to_string().hash(&mut s);
342-
self.extra_urls.hash(&mut s);
345+
self.hash(&mut s);
343346
s.finish()
344347
}
345348

src/validators/callable.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl Validator for CallableValidator {
3232
state: &mut ValidationState<'_, 'py>,
3333
) -> ValResult<PyObject> {
3434
state.floor_exactness(Exactness::Lax);
35-
match input.callable() {
35+
match input.as_python().is_some_and(PyAnyMethods::is_callable) {
3636
true => Ok(input.to_object(py)),
3737
false => Err(ValError::new(ErrorTypeDefaults::CallableType, input)),
3838
}

src/validators/dataclass.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ use ahash::AHashSet;
88
use crate::build_tools::py_schema_err;
99
use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior};
1010
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult};
11-
use crate::input::InputType;
12-
use crate::input::{Arguments, BorrowInput, Input, KeywordArgs, PositionalArgs, ValidationMatch};
11+
use crate::input::{
12+
input_as_python_instance, Arguments, BorrowInput, Input, InputType, KeywordArgs, PositionalArgs, ValidationMatch,
13+
};
1314
use crate::lookup_key::LookupKey;
1415
use crate::tools::SchemaDict;
1516
use crate::validators::function::convert_err;
@@ -501,7 +502,7 @@ impl Validator for DataclassValidator {
501502

502503
// same logic as on models
503504
let class = self.class.bind(py);
504-
if let Some(py_input) = input.input_is_instance(class) {
505+
if let Some(py_input) = input_as_python_instance(input, class) {
505506
if self.revalidate.should_revalidate(py_input, class) {
506507
let input_dict = self.dataclass_to_dict(py_input)?;
507508
let val_output = self.validator.validate(py, input_dict.as_any(), state)?;

src/validators/definitions.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ impl Validator for DefinitionRefValidator {
7777
) -> ValResult<PyObject> {
7878
self.definition.read(|validator| {
7979
let validator = validator.unwrap();
80-
if let Some(id) = input.identity() {
80+
if let Some(id) = input.as_python().map(py_identity) {
81+
// Python objects can be cyclic, so need recursion guard
8182
let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
8283
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
8384
};
@@ -98,18 +99,18 @@ impl Validator for DefinitionRefValidator {
9899
) -> ValResult<PyObject> {
99100
self.definition.read(|validator| {
100101
let validator = validator.unwrap();
101-
if let Some(id) = obj.identity() {
102-
let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else {
103-
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
104-
};
105-
validator.validate_assignment(py, obj, field_name, field_value, guard.state())
106-
} else {
107-
validator.validate_assignment(py, obj, field_name, field_value, state)
108-
}
102+
let Ok(mut guard) = RecursionGuard::new(state, py_identity(obj), self.definition.id()) else {
103+
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
104+
};
105+
validator.validate_assignment(py, obj, field_name, field_value, guard.state())
109106
})
110107
}
111108

112109
fn get_name(&self) -> &str {
113110
self.definition.get_or_init_name(|v| v.get_name().into())
114111
}
115112
}
113+
114+
fn py_identity(obj: &Bound<'_, PyAny>) -> usize {
115+
obj.as_ptr() as usize
116+
}

src/validators/enum_.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
103103
state: &mut ValidationState<'_, 'py>,
104104
) -> ValResult<PyObject> {
105105
let class = self.class.bind(py);
106-
if input.input_is_exact_instance(class) {
106+
if input.as_python().is_some_and(|any| any.is_exact_instance(class)) {
107107
return Ok(input.to_object(py));
108108
}
109109
let strict = state.strict_or(self.strict);
110-
if strict && input.is_python() {
110+
if strict && input.as_python().is_some() {
111111
// TODO what about instances of subclasses?
112112
return Err(ValError::new(
113113
ErrorType::IsInstanceOf {

src/validators/is_instance.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,14 @@ impl Validator for IsInstanceValidator {
5555
input: &(impl Input<'py> + ?Sized),
5656
_state: &mut ValidationState<'_, 'py>,
5757
) -> ValResult<PyObject> {
58-
if !input.is_python() {
58+
let Some(obj) = input.as_python() else {
5959
return Err(ValError::InternalErr(PyNotImplementedError::new_err(
6060
"Cannot check isinstance when validating from json, \
6161
use a JsonOrPython validator instead.",
6262
)));
63-
}
64-
65-
let ob: Py<PyAny> = input.to_object(py);
66-
match ob.bind(py).is_instance(self.class.bind(py))? {
67-
true => Ok(ob),
63+
};
64+
match obj.is_instance(self.class.bind(py))? {
65+
true => Ok(obj.clone().unbind()),
6866
false => Err(ValError::new(
6967
ErrorType::IsInstanceOf {
7068
class: self.class_repr.clone(),

src/validators/is_subclass.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use pyo3::exceptions::PyNotImplementedError;
12
use pyo3::intern;
23
use pyo3::prelude::*;
34
use pyo3::types::{PyDict, PyType};
@@ -49,9 +50,15 @@ impl Validator for IsSubclassValidator {
4950
input: &(impl Input<'py> + ?Sized),
5051
_state: &mut ValidationState<'_, 'py>,
5152
) -> ValResult<PyObject> {
52-
match input.input_is_subclass(self.class.bind(py))? {
53-
true => Ok(input.to_object(py)),
54-
false => Err(ValError::new(
53+
let Some(obj) = input.as_python() else {
54+
return Err(ValError::InternalErr(PyNotImplementedError::new_err(
55+
"Cannot check issubclass when validating from json, \
56+
use a JsonOrPython validator instead.",
57+
)));
58+
};
59+
match obj.downcast::<PyType>() {
60+
Ok(py_type) if py_type.is_subclass(self.class.bind(py))? => Ok(obj.clone().unbind()),
61+
_ => Err(ValError::new(
5562
ErrorType::IsSubclassOf {
5663
class: self.class_repr.clone(),
5764
context: None,

src/validators/literal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl<T: Debug> LiteralLookup<T> {
123123
}
124124
}
125125
if let Some(expected_strings) = &self.expected_str {
126-
let validation_result = if input.is_python() {
126+
let validation_result = if input.as_python().is_some() {
127127
input.exact_str()
128128
} else {
129129
// Strings coming from JSON are treated as "strict" but not "exact" for reasons

src/validators/model.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use super::{
1313
use crate::build_tools::py_schema_err;
1414
use crate::build_tools::schema_or_config_same;
1515
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
16-
use crate::input::{py_error_on_minusone, Input};
16+
use crate::input::{input_as_python_instance, py_error_on_minusone, Input};
1717
use crate::tools::{py_err, SchemaDict};
1818
use crate::PydanticUndefinedType;
1919

@@ -124,7 +124,7 @@ impl Validator for ModelValidator {
124124
// if the input is an instance of the class, we "revalidate" it - e.g. we extract and reuse `__pydantic_fields_set__`
125125
// but use from attributes to create a new instance of the model field type
126126
let class = self.class.bind(py);
127-
if let Some(py_input) = input.input_is_instance(class) {
127+
if let Some(py_input) = input_as_python_instance(input, class) {
128128
if self.revalidate.should_revalidate(py_input, class) {
129129
let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?;
130130
if self.root_model {

0 commit comments

Comments
 (0)