@@ -4,7 +4,16 @@ use super::error::HANDLE_ERROR;
4
4
use super :: util:: { af_array, dim_t, void_ptr, HasAfEnum } ;
5
5
6
6
use libc:: { c_char, c_int, c_longlong, c_uint, c_void} ;
7
+ #[ cfg( feature = "afserde" ) ]
8
+ use serde:: de:: { Deserializer , Error , Unexpected } ;
9
+ #[ cfg( feature = "afserde" ) ]
10
+ use serde:: ser:: Serializer ;
11
+ #[ cfg( feature = "afserde" ) ]
12
+ use serde:: { Deserialize , Serialize } ;
13
+ use std:: clone:: Clone ;
14
+ use std:: default:: Default ;
7
15
use std:: ffi:: CString ;
16
+ use std:: fmt:: Debug ;
8
17
use std:: marker:: PhantomData ;
9
18
10
19
// Some unused functions from array.h in C-API of ArrayFire
@@ -851,12 +860,73 @@ pub fn is_eval_manual() -> bool {
851
860
}
852
861
}
853
862
863
+ #[ cfg( feature = "afserde" ) ]
864
+ #[ derive( Debug , Serialize , Deserialize ) ]
865
+ struct ArrayOnHost < T : HasAfEnum + Debug > {
866
+ dtype : DType ,
867
+ shape : Dim4 ,
868
+ data : Vec < T > ,
869
+ }
870
+
871
+ /// Serialize Implementation of Array
872
+ #[ cfg( feature = "afserde" ) ]
873
+ impl < T > Serialize for Array < T >
874
+ where
875
+ T : Default + Clone + Serialize + HasAfEnum + Debug ,
876
+ {
877
+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
878
+ where
879
+ S : Serializer ,
880
+ {
881
+ let mut vec = vec ! [ T :: default ( ) ; self . elements( ) ] ;
882
+ self . host ( & mut vec) ;
883
+ let arr_on_host = ArrayOnHost {
884
+ dtype : self . get_type ( ) ,
885
+ shape : self . dims ( ) . clone ( ) ,
886
+ data : vec,
887
+ } ;
888
+ arr_on_host. serialize ( serializer)
889
+ }
890
+ }
891
+
892
+ /// Deserialize Implementation of Array
893
+ #[ cfg( feature = "afserde" ) ]
894
+ impl < ' de , T > Deserialize < ' de > for Array < T >
895
+ where
896
+ T : Deserialize < ' de > + HasAfEnum + Debug ,
897
+ {
898
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
899
+ where
900
+ D : Deserializer < ' de > ,
901
+ {
902
+ match ArrayOnHost :: < T > :: deserialize ( deserializer) {
903
+ Ok ( arr_on_host) => {
904
+ let read_dtype = arr_on_host. dtype ;
905
+ let expected_dtype = T :: get_af_dtype ( ) ;
906
+ if expected_dtype != read_dtype {
907
+ let error_msg = format ! (
908
+ "data type is {:?}, deserialized type is {:?}" ,
909
+ expected_dtype, read_dtype
910
+ ) ;
911
+ return Err ( Error :: invalid_value ( Unexpected :: Enum , & error_msg. as_str ( ) ) ) ;
912
+ }
913
+ Ok ( Array :: < T > :: new (
914
+ & arr_on_host. data ,
915
+ arr_on_host. shape . clone ( ) ,
916
+ ) )
917
+ }
918
+ Err ( err) => Err ( err) ,
919
+ }
920
+ }
921
+ }
922
+
854
923
#[ cfg( test) ]
855
924
mod tests {
925
+ use super :: super :: super :: algorithm:: sum_all;
856
926
use super :: super :: array:: print;
857
927
use super :: super :: data:: constant;
858
928
use super :: super :: device:: { info, set_device, sync} ;
859
- use crate :: dim4;
929
+ use crate :: { dim4, randu } ;
860
930
use std:: sync:: { mpsc, Arc , RwLock } ;
861
931
use std:: thread;
862
932
@@ -1082,4 +1152,36 @@ mod tests {
1082
1152
// 8.0000 8.0000 8.0000
1083
1153
// ANCHOR_END: accum_using_channel
1084
1154
}
1155
+
1156
+ #[ test]
1157
+ #[ cfg( feature = "afserde" ) ]
1158
+ fn array_serde_json ( ) {
1159
+ use super :: Array ;
1160
+
1161
+ let input = randu ! ( u8 ; 2 , 2 ) ;
1162
+ let serd = match serde_json:: to_string ( & input) {
1163
+ Ok ( serialized_str) => serialized_str,
1164
+ Err ( e) => e. to_string ( ) ,
1165
+ } ;
1166
+
1167
+ let deserd: Array < u8 > = serde_json:: from_str ( & serd) . unwrap ( ) ;
1168
+
1169
+ assert_eq ! ( sum_all( & ( input - deserd) ) , ( 0u32 , 0u32 ) ) ;
1170
+ }
1171
+
1172
+ #[ test]
1173
+ #[ cfg( feature = "afserde" ) ]
1174
+ fn array_serde_bincode ( ) {
1175
+ use super :: Array ;
1176
+
1177
+ let input = randu ! ( u8 ; 2 , 2 ) ;
1178
+ let encoded = match bincode:: serialize ( & input) {
1179
+ Ok ( encoded) => encoded,
1180
+ Err ( _) => vec ! [ ] ,
1181
+ } ;
1182
+
1183
+ let decoded: Array < u8 > = bincode:: deserialize ( & encoded) . unwrap ( ) ;
1184
+
1185
+ assert_eq ! ( sum_all( & ( input - decoded) ) , ( 0u32 , 0u32 ) ) ;
1186
+ }
1085
1187
}
0 commit comments