Skip to content

Commit c5a4261

Browse files
authored
clean up some string handling cases (#1381)
1 parent 507ff47 commit c5a4261

File tree

3 files changed: 31 additions (+31) and 25 deletions (-25)

src/input/input_python.rs

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
use std::borrow::Cow;
21
use std::str::from_utf8;
32

43
use pyo3::intern;
@@ -145,12 +144,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
145144
Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)),
146145
}
147146
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
148-
// Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated,
149-
// and we immediately copy the bytes into a new Python string
150-
match from_utf8(unsafe { py_byte_array.as_bytes() }) {
151-
// Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the
152-
// final output needs to be Python anyway.
153-
Ok(s) => Ok(PyString::new_bound(self.py(), s).into()),
147+
match bytearray_to_str(py_byte_array) {
148+
Ok(py_str) => Ok(py_str.into()),
154149
Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)),
155150
}
156151
} else if coerce_numbers_to_str && !self.is_exact_instance_of::<PyBool>() && {
@@ -212,8 +207,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
212207
}
213208

214209
if !strict {
215-
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
216-
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
210+
if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
211+
return str_as_bool(self, s).map(ValidationMatch::lax);
217212
} else if let Some(int) = extract_i64(self) {
218213
return int_as_bool(self, int).map(ValidationMatch::lax);
219214
} else if let Ok(float) = self.extract::<f64>() {
@@ -249,8 +244,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
249244

250245
'lax: {
251246
if !strict {
252-
return if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? {
253-
str_as_int(self, &cow_str)
247+
return if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? {
248+
str_as_int(self, s)
254249
} else if self.is_exact_instance_of::<PyFloat>() {
255250
float_as_int(self, self.extract::<f64>()?)
256251
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
@@ -291,9 +286,9 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
291286
}
292287

293288
if !strict {
294-
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? {
289+
if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? {
295290
// checking for bytes and string is fast, so do this before isinstance(float)
296-
return str_as_float(self, &cow_str).map(ValidationMatch::lax);
291+
return str_as_float(self, s).map(ValidationMatch::lax);
297292
}
298293
}
299294

@@ -638,20 +633,31 @@ fn from_attributes_applicable(obj: &Bound<'_, PyAny>) -> bool {
638633
}
639634

640635
/// Utility for extracting a string from a PyAny, if possible.
641-
fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult<Option<Cow<'a, str>>> {
636+
fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult<Option<&'a str>> {
642637
if let Ok(py_string) = v.downcast::<PyString>() {
643-
let str = py_string_str(py_string)?;
644-
Ok(Some(Cow::Borrowed(str)))
638+
py_string_str(py_string).map(Some)
645639
} else if let Ok(bytes) = v.downcast::<PyBytes>() {
646640
match from_utf8(bytes.as_bytes()) {
647-
Ok(s) => Ok(Some(Cow::Owned(s.to_string()))),
641+
Ok(s) => Ok(Some(s)),
648642
Err(_) => Err(ValError::new(unicode_error, v)),
649643
}
650644
} else {
651645
Ok(None)
652646
}
653647
}
654648

649+
/// Decode a Python bytearray to a Python string.
650+
///
651+
/// Using Python's built-in machinery for this should be efficient and avoids questions around
652+
/// safety of concurrent mutation of the bytearray (by leaving that to the Python interpreter).
653+
fn bytearray_to_str<'py>(bytearray: &Bound<'py, PyByteArray>) -> PyResult<Bound<'py, PyString>> {
654+
let py = bytearray.py();
655+
let py_string = bytearray
656+
.call_method1(intern!(py, "decode"), (intern!(py, "utf-8"),))?
657+
.downcast_into()?;
658+
Ok(py_string)
659+
}
660+
655661
/// Utility for extracting an enum value, if possible.
656662
fn maybe_as_enum<'py>(v: &Bound<'py, PyAny>) -> Option<Bound<'py, PyAny>> {
657663
let py = v.py();

src/serializers/config.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ macro_rules! serialization_mode {
7777
return Ok(Self::default());
7878
};
7979
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), $config_key))?;
80-
raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(&raw.to_cow()?))
80+
raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(raw.to_str()?))
8181
}
8282
}
8383

src/serializers/fields.rs

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ impl GeneralFieldsSerializer {
159159
for result in main_iter {
160160
let (key, value) = result?;
161161
let key_str = key_str(&key)?;
162-
let op_field = self.fields.get(key_str.as_ref());
162+
let op_field = self.fields.get(key_str);
163163
if extra.exclude_none && value.is_none() {
164164
if let Some(field) = op_field {
165165
if field.required {
@@ -169,7 +169,7 @@ impl GeneralFieldsSerializer {
169169
continue;
170170
}
171171
let field_extra = Extra {
172-
field_name: Some(&key_str),
172+
field_name: Some(key_str),
173173
..extra
174174
};
175175
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
@@ -236,13 +236,13 @@ impl GeneralFieldsSerializer {
236236
}
237237
let key_str = key_str(&key).map_err(py_err_se_err)?;
238238
let field_extra = Extra {
239-
field_name: Some(&key_str),
239+
field_name: Some(key_str),
240240
..extra
241241
};
242242

243243
let filter = self.filter.key_filter(&key, include, exclude).map_err(py_err_se_err)?;
244244
if let Some((next_include, next_exclude)) = filter {
245-
if let Some(field) = self.fields.get(key_str.as_ref()) {
245+
if let Some(field) = self.fields.get(key_str) {
246246
if let Some(ref serializer) = field.serializer {
247247
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
248248
let s = PydanticSerializer::new(
@@ -252,7 +252,7 @@ impl GeneralFieldsSerializer {
252252
next_exclude.as_ref(),
253253
&field_extra,
254254
);
255-
let output_key = field.get_key_json(&key_str, &field_extra);
255+
let output_key = field.get_key_json(key_str, &field_extra);
256256
map.serialize_entry(&output_key, &s)?;
257257
}
258258
}
@@ -446,8 +446,8 @@ impl TypeSerializer for GeneralFieldsSerializer {
446446
}
447447
}
448448

449-
fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<Cow<'a, str>> {
450-
key.downcast::<PyString>()?.to_cow()
449+
fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
450+
key.downcast::<PyString>()?.to_str()
451451
}
452452

453453
fn dict_items<'py>(

0 commit comments

Comments
 (0)