Skip to content

Commit c9a83c8

Browse files
authored
cleanup string validation a little (#624)
1 parent 8d12f96 commit c9a83c8

File tree

2 files changed

+11
-18
lines changed

2 files changed

+11
-18
lines changed

src/input/input_python.rs

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -187,26 +187,22 @@ impl<'a> Input<'a> for PyAny {
187187
}
188188

189189
fn strict_str(&'a self) -> ValResult<EitherString<'a>> {
190-
if let Ok(py_str) = self.downcast::<PyString>() {
191-
if is_builtin_str(py_str) {
192-
Ok(py_str.into())
193-
} else {
194-
Err(ValError::new(ErrorType::StringSubType, self))
195-
}
190+
if let Ok(py_str) = <PyString as PyTryFrom>::try_from_exact(self) {
191+
Ok(py_str.into())
192+
} else if PyString::is_type_of(self) {
193+
Err(ValError::new(ErrorType::StringSubType, self))
196194
} else {
197195
Err(ValError::new(ErrorType::StringType, self))
198196
}
199197
}
200198

201199
fn lax_str(&'a self) -> ValResult<EitherString<'a>> {
202-
if let Ok(py_str) = self.downcast::<PyString>() {
203-
if is_builtin_str(py_str) {
204-
Ok(py_str.into())
205-
} else {
206-
// force to a rust string to make sure behaviour is consistent whether or not we go via a
207-
// rust string in StrConstrainedValidator - e.g. to_lower
208-
Ok(py_string_str(py_str)?.into())
209-
}
200+
if let Ok(py_str) = <PyString as PyTryFrom>::try_from_exact(self) {
201+
Ok(py_str.into())
202+
} else if let Ok(py_str) = self.downcast::<PyString>() {
203+
// force to a rust string to make sure behaviour is consistent whether or not we go via a
204+
// rust string in StrConstrainedValidator - e.g. to_lower
205+
Ok(py_string_str(py_str)?.into())
210206
} else if let Ok(bytes) = self.downcast::<PyBytes>() {
211207
let str = match from_utf8(bytes.as_bytes()) {
212208
Ok(s) => s,
@@ -647,10 +643,6 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult<Option<Cow<
647643
}
648644
}
649645

650-
fn is_builtin_str(py_str: &PyString) -> bool {
651-
py_str.get_type().is(PyString::type_object(py_str.py()))
652-
}
653-
654646
#[cfg(PyPy)]
655647
static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell<Py<PyType>> = pyo3::once_cell::GILOnceCell::new();
656648

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,7 @@ class SomeStrEnum(str, Enum):
13141314
LARGE_STR_PREFIX = 'a' * 50
13151315

13161316

1317+
@pytest.mark.benchmark(group='validate_literal')
13171318
@pytest.mark.parametrize(
13181319
'allowed_values,input,expected_val_res',
13191320
[

0 commit comments

Comments
 (0)