Skip to content

Commit b49dfda

Browse files
committed
Add optional serde serialization support
- Update ci to run serde tests - Add serialization support for Enums except the enum `arrayfire::Scalar` - Structs with serde support added - [x] Array - [x] Dim4 - [x] Seq - [x] RandomEngine - Structs without serde support - Features - currently not possible as `af_features` can't be recreated from individual `af_arrays` with current upstream API - Indexer - not possible with current API. Also, any subarray when fetched to host for serialization results in separate owned copy this making serde support for this unnecessary. - Callback - Event - Window
1 parent 97d097b commit b49dfda

File tree

8 files changed

+271
-5
lines changed

8 files changed

+271
-5
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ jobs:
4848
export AF_PATH=${GITHUB_WORKSPACE}/afbin
4949
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${AF_PATH}/lib64
5050
echo "Using cargo version: $(cargo --version)"
51-
cargo build --all
52-
cargo test --no-fail-fast
51+
cargo build --all --features="afserde"
52+
cargo test --no-fail-fast --features="afserde"
5353
5454
format:
5555
name: Format Check

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,19 @@ statistics = []
4646
vision = []
4747
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
4848
"ml", "macros", "random", "signal", "sparse", "statistics", "vision"]
49+
afserde = ["serde"]
4950

5051
[dependencies]
5152
libc = "0.2"
5253
num = "0.2"
5354
lazy_static = "1.0"
5455
half = "1.5.0"
56+
serde = { version = "1.0", features = ["derive"], optional = true }
5557

5658
[dev-dependencies]
5759
half = "1.5.0"
60+
serde_json = "1.0"
61+
bincode = "1.3"
5862

5963
[build-dependencies]
6064
serde_json = "1.0"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Only, Major(M) & Minor(m) version numbers need to match. *p1* and *p2* are patch
1616

1717
## Supported platforms
1818

19-
Linux, Windows and OSX. Rust 1.15.1 or higher is required.
19+
Linux, Windows and OSX. Rust 1.31 or newer is required.
2020

2121
## Use from Crates.io [![][6]][7] [![][8]][9]
2222

src/core/array.rs

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@ use super::error::HANDLE_ERROR;
44
use super::util::{af_array, dim_t, void_ptr, HasAfEnum};
55

66
use libc::{c_char, c_int, c_longlong, c_uint, c_void};
7+
#[cfg(feature = "afserde")]
8+
use serde::de::{Deserializer, Error, Unexpected};
9+
#[cfg(feature = "afserde")]
10+
use serde::ser::Serializer;
11+
#[cfg(feature = "afserde")]
12+
use serde::{Deserialize, Serialize};
13+
use std::clone::Clone;
14+
use std::default::Default;
715
use std::ffi::CString;
16+
use std::fmt::Debug;
817
use std::marker::PhantomData;
918

1019
// Some unused functions from array.h in C-API of ArrayFire
@@ -851,12 +860,73 @@ pub fn is_eval_manual() -> bool {
851860
}
852861
}
853862

863+
#[cfg(feature = "afserde")]
864+
#[derive(Debug, Serialize, Deserialize)]
865+
struct ArrayOnHost<T: HasAfEnum + Debug> {
866+
dtype: DType,
867+
shape: Dim4,
868+
data: Vec<T>,
869+
}
870+
871+
/// Serialize Implementation of Array
872+
#[cfg(feature = "afserde")]
873+
impl<T> Serialize for Array<T>
874+
where
875+
T: Default + Clone + Serialize + HasAfEnum + Debug,
876+
{
877+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
878+
where
879+
S: Serializer,
880+
{
881+
let mut vec = vec![T::default(); self.elements()];
882+
self.host(&mut vec);
883+
let arr_on_host = ArrayOnHost {
884+
dtype: self.get_type(),
885+
shape: self.dims().clone(),
886+
data: vec,
887+
};
888+
arr_on_host.serialize(serializer)
889+
}
890+
}
891+
892+
/// Deserialize Implementation of Array
893+
#[cfg(feature = "afserde")]
894+
impl<'de, T> Deserialize<'de> for Array<T>
895+
where
896+
T: Deserialize<'de> + HasAfEnum + Debug,
897+
{
898+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
899+
where
900+
D: Deserializer<'de>,
901+
{
902+
match ArrayOnHost::<T>::deserialize(deserializer) {
903+
Ok(arr_on_host) => {
904+
let read_dtype = arr_on_host.dtype;
905+
let expected_dtype = T::get_af_dtype();
906+
if expected_dtype != read_dtype {
907+
let error_msg = format!(
908+
"data type is {:?}, deserialized type is {:?}",
909+
expected_dtype, read_dtype
910+
);
911+
return Err(Error::invalid_value(Unexpected::Enum, &error_msg.as_str()));
912+
}
913+
Ok(Array::<T>::new(
914+
&arr_on_host.data,
915+
arr_on_host.shape.clone(),
916+
))
917+
}
918+
Err(err) => Err(err),
919+
}
920+
}
921+
}
922+
854923
#[cfg(test)]
855924
mod tests {
925+
use super::super::super::algorithm::sum_all;
856926
use super::super::array::print;
857927
use super::super::data::constant;
858928
use super::super::device::{info, set_device, sync};
859-
use crate::dim4;
929+
use crate::{dim4, randu};
860930
use std::sync::{mpsc, Arc, RwLock};
861931
use std::thread;
862932

@@ -1082,4 +1152,36 @@ mod tests {
10821152
// 8.0000 8.0000 8.0000
10831153
// ANCHOR_END: accum_using_channel
10841154
}
1155+
1156+
#[test]
1157+
#[cfg(feature = "afserde")]
1158+
fn array_serde_json() {
1159+
use super::Array;
1160+
1161+
let input = randu!(u8; 2, 2);
1162+
let serd = match serde_json::to_string(&input) {
1163+
Ok(serialized_str) => serialized_str,
1164+
Err(e) => e.to_string(),
1165+
};
1166+
1167+
let deserd: Array<u8> = serde_json::from_str(&serd).unwrap();
1168+
1169+
assert_eq!(sum_all(&(input - deserd)), (0u32, 0u32));
1170+
}
1171+
1172+
#[test]
1173+
#[cfg(feature = "afserde")]
1174+
fn array_serde_bincode() {
1175+
use super::Array;
1176+
1177+
let input = randu!(u8; 2, 2);
1178+
let encoded = match bincode::serialize(&input) {
1179+
Ok(encoded) => encoded,
1180+
Err(_) => vec![],
1181+
};
1182+
1183+
let decoded: Array<u8> = bincode::deserialize(&encoded).unwrap();
1184+
1185+
assert_eq!(sum_all(&(input - decoded)), (0u32, 0u32));
1186+
}
10851187
}

0 commit comments

Comments
 (0)