@@ -4,7 +4,16 @@ use super::error::HANDLE_ERROR;
4
4
use super :: util:: { af_array, dim_t, void_ptr, HasAfEnum } ;
5
5
6
6
use libc:: { c_char, c_int, c_longlong, c_uint, c_void} ;
7
+ #[ cfg( feature = "afserde" ) ]
8
+ use serde:: de:: { Deserializer , Error , Unexpected } ;
9
+ #[ cfg( feature = "afserde" ) ]
10
+ use serde:: ser:: Serializer ;
11
+ #[ cfg( feature = "afserde" ) ]
12
+ use serde:: { Deserialize , Serialize } ;
13
+ use std:: clone:: Clone ;
14
+ use std:: default:: Default ;
7
15
use std:: ffi:: CString ;
16
+ use std:: fmt:: Debug ;
8
17
use std:: marker:: PhantomData ;
9
18
10
19
// Some unused functions from array.h in C-API of ArrayFire
@@ -851,12 +860,72 @@ pub fn is_eval_manual() -> bool {
851
860
}
852
861
}
853
862
863
+ #[ derive( Debug , Serialize , Deserialize ) ]
864
+ struct ArrayOnHost < T : HasAfEnum + Debug > {
865
+ dtype : DType ,
866
+ shape : Dim4 ,
867
+ data : Vec < T > ,
868
+ }
869
+
870
+ /// Serialize Implementation of Array
871
+ #[ cfg( feature = "afserde" ) ]
872
+ impl < T > Serialize for Array < T >
873
+ where
874
+ T : Default + Clone + Serialize + HasAfEnum + Debug ,
875
+ {
876
+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
877
+ where
878
+ S : Serializer ,
879
+ {
880
+ let mut vec = vec ! [ T :: default ( ) ; self . elements( ) ] ;
881
+ self . host ( & mut vec) ;
882
+ let arr_on_host = ArrayOnHost {
883
+ dtype : self . get_type ( ) ,
884
+ shape : self . dims ( ) . clone ( ) ,
885
+ data : vec,
886
+ } ;
887
+ arr_on_host. serialize ( serializer)
888
+ }
889
+ }
890
+
891
+ /// Deserialize Implementation of Array
892
+ #[ cfg( feature = "afserde" ) ]
893
+ impl < ' de , T > Deserialize < ' de > for Array < T >
894
+ where
895
+ T : Deserialize < ' de > + HasAfEnum + Debug ,
896
+ {
897
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
898
+ where
899
+ D : Deserializer < ' de > ,
900
+ {
901
+ match ArrayOnHost :: < T > :: deserialize ( deserializer) {
902
+ Ok ( arr_on_host) => {
903
+ let read_dtype = arr_on_host. dtype ;
904
+ let expected_dtype = T :: get_af_dtype ( ) ;
905
+ if expected_dtype != read_dtype {
906
+ let error_msg = format ! (
907
+ "data type is {:?}, deserialized type is {:?}" ,
908
+ expected_dtype, read_dtype
909
+ ) ;
910
+ return Err ( Error :: invalid_value ( Unexpected :: Enum , & error_msg. as_str ( ) ) ) ;
911
+ }
912
+ Ok ( Array :: < T > :: new (
913
+ & arr_on_host. data ,
914
+ arr_on_host. shape . clone ( ) ,
915
+ ) )
916
+ }
917
+ Err ( err) => Err ( err) ,
918
+ }
919
+ }
920
+ }
921
+
854
922
#[ cfg( test) ]
855
923
mod tests {
924
+ use super :: super :: super :: algorithm:: sum_all;
856
925
use super :: super :: array:: print;
857
926
use super :: super :: data:: constant;
858
927
use super :: super :: device:: { info, set_device, sync} ;
859
- use crate :: dim4;
928
+ use crate :: { dim4, randu } ;
860
929
use std:: sync:: { mpsc, Arc , RwLock } ;
861
930
use std:: thread;
862
931
@@ -1082,4 +1151,36 @@ mod tests {
1082
1151
// 8.0000 8.0000 8.0000
1083
1152
// ANCHOR_END: accum_using_channel
1084
1153
}
1154
+
1155
+ #[ test]
1156
+ #[ cfg( feature = "afserde" ) ]
1157
+ fn array_serde_json ( ) {
1158
+ use super :: Array ;
1159
+
1160
+ let input = randu ! ( u8 ; 2 , 2 ) ;
1161
+ let serd = match serde_json:: to_string ( & input) {
1162
+ Ok ( serialized_str) => serialized_str,
1163
+ Err ( e) => e. to_string ( ) ,
1164
+ } ;
1165
+
1166
+ let deserd: Array < u8 > = serde_json:: from_str ( & serd) . unwrap ( ) ;
1167
+
1168
+ assert_eq ! ( sum_all( & ( input - deserd) ) , ( 0u32 , 0u32 ) ) ;
1169
+ }
1170
+
1171
+ #[ test]
1172
+ #[ cfg( feature = "afserde" ) ]
1173
+ fn array_serde_bincode ( ) {
1174
+ use super :: Array ;
1175
+
1176
+ let input = randu ! ( u8 ; 2 , 2 ) ;
1177
+ let encoded = match bincode:: serialize ( & input) {
1178
+ Ok ( encoded) => encoded,
1179
+ Err ( _) => vec ! [ ] ,
1180
+ } ;
1181
+
1182
+ let decoded: Array < u8 > = bincode:: deserialize ( & encoded) . unwrap ( ) ;
1183
+
1184
+ assert_eq ! ( sum_all( & ( input - decoded) ) , ( 0u32 , 0u32 ) ) ;
1185
+ }
1085
1186
}
0 commit comments