Skip to content

Commit b986e6b

Browse files
authored
Merge pull request RustPython#3064 from qingshi163/array-pickle
Implement __reduce__ and __reduce_ex__ for array
2 parents ddf485c + ea69dc5 commit b986e6b

File tree

3 files changed

+128
-48
lines changed

3 files changed

+128
-48
lines changed

Lib/test/test_array.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,17 +247,13 @@ def test_deepcopy(self):
247247
self.assertNotEqual(id(a), id(b))
248248
self.assertEqual(a, b)
249249

250-
# TODO: RUSTPYTHON
251-
@unittest.expectedFailure
252250
def test_reduce_ex(self):
253251
a = array.array(self.typecode, self.example)
254252
for protocol in range(3):
255253
self.assertIs(a.__reduce_ex__(protocol)[0], array.array)
256254
for protocol in range(3, pickle.HIGHEST_PROTOCOL + 1):
257255
self.assertIs(a.__reduce_ex__(protocol)[0], array_reconstructor)
258256

259-
# TODO: RUSTPYTHON
260-
@unittest.expectedFailure
261257
def test_pickle(self):
262258
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
263259
a = array.array(self.typecode, self.example)
@@ -273,8 +269,6 @@ def test_pickle(self):
273269
self.assertEqual(a.x, b.x)
274270
self.assertEqual(type(a), type(b))
275271

276-
# TODO: RUSTPYTHON
277-
@unittest.expectedFailure
278272
def test_pickle_for_empty_array(self):
279273
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
280274
a = array.array(self.typecode)

extra_tests/snippets/stdlib_array.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from testutils import assert_raises
22
from array import array
3+
from pickle import dumps, loads
34

45
a1 = array("b", [0, 1, 2, 3])
56

@@ -96,4 +97,11 @@ def test_array_frombytes():
9697
with assert_raises(IndexError):
9798
a[0] = 42
9899
with assert_raises(IndexError):
99-
del a[42]
100+
del a[42]
101+
102+
test_str = '🌉abc🌐def🌉🌐'
103+
u = array('u', test_str)
104+
# skip as 2 bytes character enviroment with CPython is failing the test
105+
if u.itemsize >= 4:
106+
assert u.__reduce_ex__(1)[1][1] == list(test_str)
107+
assert loads(dumps(u, 1)) == loads(dumps(u, 3))

stdlib/src/array.rs

Lines changed: 119 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ mod array {
1212
};
1313
use crate::vm::{
1414
builtins::{
15-
PyByteArray, PyBytes, PyBytesRef, PyIntRef, PyList, PyListRef, PySliceRef, PyStr,
16-
PyStrRef, PyTypeRef,
15+
PyByteArray, PyBytes, PyBytesRef, PyDictRef, PyFloat, PyInt, PyIntRef, PyList,
16+
PyListRef, PySliceRef, PyStr, PyStrRef, PyTypeRef,
1717
},
1818
class_or_notimplemented,
1919
function::{
@@ -28,8 +28,8 @@ mod array {
2828
AsBuffer, AsMapping, Comparable, Iterable, IteratorIterable, PyComparisonOp,
2929
SlotConstructor, SlotIterator,
3030
},
31-
IdProtocol, PyComparisonValue, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
32-
TypeProtocol, VirtualMachine,
31+
IdProtocol, PyComparisonValue, PyObjectRef, PyObjectWrap, PyRef, PyResult, PyValue,
32+
TryFromObject, TypeProtocol, VirtualMachine,
3333
};
3434
use crossbeam_utils::atomic::AtomicCell;
3535
use itertools::Itertools;
@@ -460,6 +460,14 @@ mod array {
460460
})*
461461
}
462462
}
463+
464+
fn get_objects(&self, vm: &VirtualMachine) -> Vec<PyObjectRef> {
465+
match self {
466+
$(ArrayContentType::$n(v) => {
467+
v.iter().map(|&x| x.to_object(vm)).collect()
468+
})*
469+
}
470+
}
463471
}
464472
};
465473
}
@@ -486,32 +494,41 @@ mod array {
486494
trait ArrayElement: Sized {
487495
fn try_into_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self>;
488496
fn byteswap(self) -> Self;
497+
fn to_object(self, vm: &VirtualMachine) -> PyObjectRef;
489498
}
490499

491500
macro_rules! impl_array_element {
492-
($(($t:ty, $f_into:path, $f_swap:path),)*) => {$(
501+
($(($t:ty, $f_from:path, $f_swap:path, $f_to:path),)*) => {$(
493502
impl ArrayElement for $t {
494503
fn try_into_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
495-
$f_into(vm, obj)
504+
$f_from(vm, obj)
496505
}
497506
fn byteswap(self) -> Self {
498507
$f_swap(self)
499508
}
509+
fn to_object(self, vm: &VirtualMachine) -> PyObjectRef {
510+
$f_to(self).into_object(vm)
511+
}
500512
}
501513
)*};
502514
}
503515

504516
impl_array_element!(
505-
(i8, i8::try_from_object, i8::swap_bytes),
506-
(u8, u8::try_from_object, u8::swap_bytes),
507-
(i16, i16::try_from_object, i16::swap_bytes),
508-
(u16, u16::try_from_object, u16::swap_bytes),
509-
(i32, i32::try_from_object, i32::swap_bytes),
510-
(u32, u32::try_from_object, u32::swap_bytes),
511-
(i64, i64::try_from_object, i64::swap_bytes),
512-
(u64, u64::try_from_object, u64::swap_bytes),
513-
(f32, f32_try_into_from_object, f32_swap_bytes),
514-
(f64, f64_try_into_from_object, f64_swap_bytes),
517+
(i8, i8::try_from_object, i8::swap_bytes, PyInt::from),
518+
(u8, u8::try_from_object, u8::swap_bytes, PyInt::from),
519+
(i16, i16::try_from_object, i16::swap_bytes, PyInt::from),
520+
(u16, u16::try_from_object, u16::swap_bytes, PyInt::from),
521+
(i32, i32::try_from_object, i32::swap_bytes, PyInt::from),
522+
(u32, u32::try_from_object, u32::swap_bytes, PyInt::from),
523+
(i64, i64::try_from_object, i64::swap_bytes, PyInt::from),
524+
(u64, u64::try_from_object, u64::swap_bytes, PyInt::from),
525+
(
526+
f32,
527+
f32_try_into_from_object,
528+
f32_swap_bytes,
529+
pyfloat_from_f32
530+
),
531+
(f64, f64_try_into_from_object, f64_swap_bytes, PyFloat::from),
515532
);
516533

517534
fn f32_swap_bytes(x: f32) -> f32 {
@@ -530,6 +547,10 @@ mod array {
530547
ArgIntoFloat::try_from_object(vm, obj).map(|x| x.to_f64())
531548
}
532549

550+
fn pyfloat_from_f32(value: f32) -> PyFloat {
551+
PyFloat::from(value as f64)
552+
}
553+
533554
impl ArrayElement for WideChar {
534555
fn try_into_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
535556
PyStrRef::try_from_object(vm, obj)?
@@ -542,6 +563,9 @@ mod array {
542563
fn byteswap(self) -> Self {
543564
Self(self.0.swap_bytes())
544565
}
566+
fn to_object(self, _vm: &VirtualMachine) -> PyObjectRef {
567+
unreachable!()
568+
}
545569
}
546570

547571
fn u32_to_char(ch: u32) -> Result<char, String> {
@@ -731,6 +755,38 @@ mod array {
731755
}
732756
}
733757

758+
fn _wchar_bytes_to_string(
759+
bytes: &[u8],
760+
item_size: usize,
761+
vm: &VirtualMachine,
762+
) -> PyResult<String> {
763+
if item_size == 2 {
764+
// safe because every configuration of bytes for the types we support are valid
765+
let utf16 = unsafe {
766+
std::slice::from_raw_parts(
767+
bytes.as_ptr() as *const u16,
768+
bytes.len() / std::mem::size_of::<u16>(),
769+
)
770+
};
771+
Ok(String::from_utf16_lossy(utf16))
772+
} else {
773+
// safe because every configuration of bytes for the types we support are valid
774+
let chars = unsafe {
775+
std::slice::from_raw_parts(
776+
bytes.as_ptr() as *const u32,
777+
bytes.len() / std::mem::size_of::<u32>(),
778+
)
779+
};
780+
chars
781+
.iter()
782+
.map(|&ch| {
783+
// cpython issue 17223
784+
u32_to_char(ch).map_err(|msg| vm.new_value_error(msg))
785+
})
786+
.try_collect()
787+
}
788+
}
789+
734790
fn _unicode_to_wchar_bytes(utf8: &str, item_size: usize) -> Vec<u8> {
735791
if item_size == 2 {
736792
utf8.encode_utf16()
@@ -771,31 +827,7 @@ mod array {
771827
));
772828
}
773829
let bytes = array.get_bytes();
774-
if self.itemsize() == 2 {
775-
// safe because every configuration of bytes for the types we support are valid
776-
let utf16 = unsafe {
777-
std::slice::from_raw_parts(
778-
bytes.as_ptr() as *const u16,
779-
bytes.len() / std::mem::size_of::<u16>(),
780-
)
781-
};
782-
Ok(String::from_utf16_lossy(utf16))
783-
} else {
784-
// safe because every configuration of bytes for the types we support are valid
785-
let chars = unsafe {
786-
std::slice::from_raw_parts(
787-
bytes.as_ptr() as *const u32,
788-
bytes.len() / std::mem::size_of::<u32>(),
789-
)
790-
};
791-
chars
792-
.iter()
793-
.map(|&ch| {
794-
// cpython issue 17223
795-
u32_to_char(ch).map_err(|msg| vm.new_value_error(msg))
796-
})
797-
.try_collect()
798-
}
830+
Self::_wchar_bytes_to_string(bytes, self.itemsize(), vm)
799831
}
800832

801833
fn _from_bytes(&self, b: &[u8], itemsize: usize, vm: &VirtualMachine) -> PyResult<()> {
@@ -1079,6 +1111,52 @@ mod array {
10791111
}
10801112
Ok(true)
10811113
}
1114+
1115+
#[pymethod(magic)]
1116+
fn reduce_ex(
1117+
zelf: PyRef<Self>,
1118+
proto: usize,
1119+
vm: &VirtualMachine,
1120+
) -> PyResult<(PyObjectRef, PyObjectRef, Option<PyDictRef>)> {
1121+
if proto < 3 {
1122+
return Self::reduce(zelf, vm);
1123+
}
1124+
let array = zelf.read();
1125+
let cls = zelf.as_object().clone_class().into_object();
1126+
let typecode = vm.ctx.new_utf8_str(array.typecode_str());
1127+
let bytes = vm.ctx.new_bytes(array.get_bytes().to_vec());
1128+
let code = MachineFormatCode::from_typecode(array.typecode()).unwrap();
1129+
let code = PyInt::from(u8::from(code)).into_object(vm);
1130+
let module = vm.import("array", None, 0)?;
1131+
let func = vm.get_attribute(module, "_array_reconstructor")?;
1132+
Ok((
1133+
func,
1134+
vm.ctx.new_tuple(vec![cls, typecode, code, bytes]),
1135+
zelf.as_object().dict(),
1136+
))
1137+
}
1138+
1139+
#[pymethod(magic)]
1140+
fn reduce(
1141+
zelf: PyRef<Self>,
1142+
vm: &VirtualMachine,
1143+
) -> PyResult<(PyObjectRef, PyObjectRef, Option<PyDictRef>)> {
1144+
let array = zelf.read();
1145+
let cls = zelf.as_object().clone_class().into_object();
1146+
let typecode = vm.ctx.new_utf8_str(array.typecode_str());
1147+
let values = if array.typecode() == 'u' {
1148+
let s = Self::_wchar_bytes_to_string(array.get_bytes(), array.itemsize(), vm)?;
1149+
s.chars().map(|x| x.into_pyobject(vm)).collect()
1150+
} else {
1151+
array.get_objects(vm)
1152+
};
1153+
let values = vm.ctx.new_list(values);
1154+
Ok((
1155+
cls,
1156+
vm.ctx.new_tuple(vec![typecode, values]),
1157+
zelf.as_object().dict(),
1158+
))
1159+
}
10821160
}
10831161

10841162
impl Comparable for PyArray {

0 commit comments

Comments
 (0)