Skip to content

Commit f0c4d15

Browse files
committed
Add optional serde serialization support
- Update ci to run serde tests - Add serialization support for Enums TODO - Structs
1 parent 97d097b commit f0c4d15

File tree

6 files changed

+184
-4
lines changed

6 files changed

+184
-4
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ jobs:
4848
export AF_PATH=${GITHUB_WORKSPACE}/afbin
4949
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${AF_PATH}/lib64
5050
echo "Using cargo version: $(cargo --version)"
51-
cargo build --all
52-
cargo test --no-fail-fast
51+
cargo build --all --features="afserde"
52+
cargo test --no-fail-fast --features="afserde"
5353
5454
format:
5555
name: Format Check

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,19 @@ statistics = []
4646
vision = []
4747
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
4848
"ml", "macros", "random", "signal", "sparse", "statistics", "vision"]
49+
afserde = ["serde"]
4950

5051
[dependencies]
5152
libc = "0.2"
5253
num = "0.2"
5354
lazy_static = "1.0"
5455
half = "1.5.0"
56+
serde = { version = "1.0", features = ["derive"], optional = true }
5557

5658
[dev-dependencies]
5759
half = "1.5.0"
60+
serde_json = "1.0"
61+
bincode = "1.3"
5862

5963
[build-dependencies]
6064
serde_json = "1.0"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Only, Major(M) & Minor(m) version numbers need to match. *p1* and *p2* are patch
1616

1717
## Supported platforms
1818

19-
Linux, Windows and OSX. Rust 1.15.1 or higher is required.
19+
Linux, Windows and OSX. Rust 1.31 or newer is required.
2020

2121
## Use from Crates.io [![][6]][7] [![][8]][9]
2222

src/core/array.rs

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@ use super::error::HANDLE_ERROR;
44
use super::util::{af_array, dim_t, void_ptr, HasAfEnum};
55

66
use libc::{c_char, c_int, c_longlong, c_uint, c_void};
7+
#[cfg(feature = "afserde")]
8+
use serde::de::{Deserializer, Error, Unexpected};
9+
#[cfg(feature = "afserde")]
10+
use serde::ser::Serializer;
11+
#[cfg(feature = "afserde")]
12+
use serde::{Deserialize, Serialize};
13+
use std::clone::Clone;
14+
use std::default::Default;
715
use std::ffi::CString;
16+
use std::fmt::Debug;
817
use std::marker::PhantomData;
918

1019
// Some unused functions from array.h in C-API of ArrayFire
@@ -851,12 +860,72 @@ pub fn is_eval_manual() -> bool {
851860
}
852861
}
853862

863+
#[derive(Debug, Serialize, Deserialize)]
864+
struct ArrayOnHost<T: HasAfEnum + Debug> {
865+
dtype: DType,
866+
shape: Dim4,
867+
data: Vec<T>,
868+
}
869+
870+
/// Serialize Implementation of Array
871+
#[cfg(feature = "afserde")]
872+
impl<T> Serialize for Array<T>
873+
where
874+
T: Default + Clone + Serialize + HasAfEnum + Debug,
875+
{
876+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
877+
where
878+
S: Serializer,
879+
{
880+
let mut vec = vec![T::default(); self.elements()];
881+
self.host(&mut vec);
882+
let arr_on_host = ArrayOnHost {
883+
dtype: self.get_type(),
884+
shape: self.dims().clone(),
885+
data: vec,
886+
};
887+
arr_on_host.serialize(serializer)
888+
}
889+
}
890+
891+
/// Deserialize Implementation of Array
892+
#[cfg(feature = "afserde")]
893+
impl<'de, T> Deserialize<'de> for Array<T>
894+
where
895+
T: Deserialize<'de> + HasAfEnum + Debug,
896+
{
897+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
898+
where
899+
D: Deserializer<'de>,
900+
{
901+
match ArrayOnHost::<T>::deserialize(deserializer) {
902+
Ok(arr_on_host) => {
903+
let read_dtype = arr_on_host.dtype;
904+
let expected_dtype = T::get_af_dtype();
905+
if expected_dtype != read_dtype {
906+
let error_msg = format!(
907+
"data type is {:?}, deserialized type is {:?}",
908+
expected_dtype, read_dtype
909+
);
910+
return Err(Error::invalid_value(Unexpected::Enum, &error_msg.as_str()));
911+
}
912+
Ok(Array::<T>::new(
913+
&arr_on_host.data,
914+
arr_on_host.shape.clone(),
915+
))
916+
}
917+
Err(err) => Err(err),
918+
}
919+
}
920+
}
921+
854922
#[cfg(test)]
855923
mod tests {
924+
use super::super::super::algorithm::sum_all;
856925
use super::super::array::print;
857926
use super::super::data::constant;
858927
use super::super::device::{info, set_device, sync};
859-
use crate::dim4;
928+
use crate::{dim4, randu};
860929
use std::sync::{mpsc, Arc, RwLock};
861930
use std::thread;
862931

@@ -1082,4 +1151,36 @@ mod tests {
10821151
// 8.0000 8.0000 8.0000
10831152
// ANCHOR_END: accum_using_channel
10841153
}
1154+
1155+
#[test]
1156+
#[cfg(feature = "afserde")]
1157+
fn array_serde_json() {
1158+
use super::Array;
1159+
1160+
let input = randu!(u8; 2, 2);
1161+
let serd = match serde_json::to_string(&input) {
1162+
Ok(serialized_str) => serialized_str,
1163+
Err(e) => e.to_string(),
1164+
};
1165+
1166+
let deserd: Array<u8> = serde_json::from_str(&serd).unwrap();
1167+
1168+
assert_eq!(sum_all(&(input - deserd)), (0u32, 0u32));
1169+
}
1170+
1171+
#[test]
1172+
#[cfg(feature = "afserde")]
1173+
fn array_serde_bincode() {
1174+
use super::Array;
1175+
1176+
let input = randu!(u8; 2, 2);
1177+
let encoded = match bincode::serialize(&input) {
1178+
Ok(encoded) => encoded,
1179+
Err(_) => vec![],
1180+
};
1181+
1182+
let decoded: Array<u8> = bincode::deserialize(&encoded).unwrap();
1183+
1184+
assert_eq!(sum_all(&(input - decoded)), (0u32, 0u32));
1185+
}
10851186
}

0 commit comments

Comments
 (0)