@@ -5,7 +5,7 @@ use crate::array::Array;
5
5
use crate :: defines:: { AfError , BinaryOp } ;
6
6
use crate :: error:: HANDLE_ERROR ;
7
7
use crate :: util:: { AfArray , MutAfArray , MutDouble , MutUint } ;
8
- use crate :: util:: { HasAfEnum , RealNumber , Scanable } ;
8
+ use crate :: util:: { HasAfEnum , RealNumber , ReduceByKeyInput , Scanable } ;
9
9
10
10
#[ allow( dead_code) ]
11
11
extern "C" {
@@ -59,6 +59,71 @@ extern "C" {
59
59
op : c_uint ,
60
60
inclusive : c_int ,
61
61
) -> c_int ;
62
+ fn af_all_true_by_key (
63
+ keys_out : MutAfArray ,
64
+ vals_out : MutAfArray ,
65
+ keys : AfArray ,
66
+ vals : AfArray ,
67
+ dim : c_int ,
68
+ ) -> c_int ;
69
+ fn af_any_true_by_key (
70
+ keys_out : MutAfArray ,
71
+ vals_out : MutAfArray ,
72
+ keys : AfArray ,
73
+ vals : AfArray ,
74
+ dim : c_int ,
75
+ ) -> c_int ;
76
+ fn af_count_by_key (
77
+ keys_out : MutAfArray ,
78
+ vals_out : MutAfArray ,
79
+ keys : AfArray ,
80
+ vals : AfArray ,
81
+ dim : c_int ,
82
+ ) -> c_int ;
83
+ fn af_max_by_key (
84
+ keys_out : MutAfArray ,
85
+ vals_out : MutAfArray ,
86
+ keys : AfArray ,
87
+ vals : AfArray ,
88
+ dim : c_int ,
89
+ ) -> c_int ;
90
+ fn af_min_by_key (
91
+ keys_out : MutAfArray ,
92
+ vals_out : MutAfArray ,
93
+ keys : AfArray ,
94
+ vals : AfArray ,
95
+ dim : c_int ,
96
+ ) -> c_int ;
97
+ fn af_product_by_key (
98
+ keys_out : MutAfArray ,
99
+ vals_out : MutAfArray ,
100
+ keys : AfArray ,
101
+ vals : AfArray ,
102
+ dim : c_int ,
103
+ ) -> c_int ;
104
+ fn af_product_by_key_nan (
105
+ keys_out : MutAfArray ,
106
+ vals_out : MutAfArray ,
107
+ keys : AfArray ,
108
+ vals : AfArray ,
109
+ dim : c_int ,
110
+ nan_val : c_double ,
111
+ ) -> c_int ;
112
+ fn af_sum_by_key (
113
+ keys_out : MutAfArray ,
114
+ vals_out : MutAfArray ,
115
+ keys : AfArray ,
116
+ vals : AfArray ,
117
+ dim : c_int ,
118
+ ) -> c_int ;
119
+ fn af_sum_by_key_nan (
120
+ keys_out : MutAfArray ,
121
+ vals_out : MutAfArray ,
122
+ keys : AfArray ,
123
+ vals : AfArray ,
124
+ dim : c_int ,
125
+ nan_val : c_double ,
126
+ ) -> c_int ;
62
127
}
63
128
64
129
macro_rules! dim_reduce_func_def {
@@ -1137,3 +1202,193 @@ where
1137
1202
}
1138
1203
temp. into ( )
1139
1204
}
1205
+
1206
+ macro_rules! dim_reduce_by_key_func_def {
1207
+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1208
+ #[ doc=$brief_str]
1209
+ /// # Parameters
1210
+ ///
1211
+ /// - `keys` - key Array
1212
+ /// - `vals` - value Array
1213
+ /// - `dim` - Dimension along which the input Array is reduced
1214
+ ///
1215
+ /// # Return Values
1216
+ ///
1217
+ /// Tuple of Arrays, with output keys and values after reduction
1218
+ ///
1219
+ #[ doc=$ex_str]
1220
+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1221
+ dim: i32
1222
+ ) -> ( Array <KeyType >, Array <$out_type>)
1223
+ where
1224
+ KeyType : ReduceByKeyInput ,
1225
+ ValueType : HasAfEnum ,
1226
+ $out_type: HasAfEnum ,
1227
+ {
1228
+ let mut out_keys: i64 = 0 ;
1229
+ let mut out_vals: i64 = 0 ;
1230
+ unsafe {
1231
+ let err_val = $ffi_name(
1232
+ & mut out_keys as MutAfArray ,
1233
+ & mut out_vals as MutAfArray ,
1234
+ keys. get( ) as AfArray ,
1235
+ vals. get( ) as AfArray ,
1236
+ dim as c_int,
1237
+ ) ;
1238
+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1239
+ }
1240
+ ( out_keys. into( ) , out_vals. into( ) )
1241
+ }
1242
+ } ;
1243
+ }
1244
+
1245
+ dim_reduce_by_key_func_def ! (
1246
+ "
1247
+ Key based AND of elements along a given dimension
1248
+
1249
+ All positive non-zero values are considered true, while negative and zero
1250
+ values are considered as false.
1251
+ " ,
1252
+ "
1253
+ # Examples
1254
+ ```rust
1255
+ use arrayfire::{Dim4, print, randu, all_true_by_key};
1256
+ let dims = Dim4::new(&[5, 3, 1, 1]);
1257
+ let vals = randu::<f32>(dims);
1258
+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1259
+ print(&vals);
1260
+ print(&keys);
1261
+ let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1262
+ print(&out_keys);
1263
+ print(&out_vals);
1264
+ ```
1265
+ " ,
1266
+ all_true_by_key,
1267
+ af_all_true_by_key,
1268
+ ValueType :: AggregateOutType
1269
+ ) ;
1270
+
1271
+ dim_reduce_by_key_func_def ! (
1272
+ "
1273
+ Key based OR of elements along a given dimension
1274
+
1275
+ All positive non-zero values are considered true, while negative and zero
1276
+ values are considered as false.
1277
+ " ,
1278
+ "
1279
+ # Examples
1280
+ ```rust
1281
+ use arrayfire::{Dim4, print, randu, any_true_by_key};
1282
+ let dims = Dim4::new(&[5, 3, 1, 1]);
1283
+ let vals = randu::<f32>(dims);
1284
+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1285
+ print(&vals);
1286
+ print(&keys);
1287
+ let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1288
+ print(&out_keys);
1289
+ print(&out_vals);
1290
+ ```
1291
+ " ,
1292
+ any_true_by_key,
1293
+ af_any_true_by_key,
1294
+ ValueType :: AggregateOutType
1295
+ ) ;
1296
+
1297
+ dim_reduce_by_key_func_def ! (
1298
+ "Find total count of elements with similar keys along a given dimension" ,
1299
+ "" ,
1300
+ count_by_key,
1301
+ af_count_by_key,
1302
+ ValueType :: AggregateOutType
1303
+ ) ;
1304
+
1305
+ dim_reduce_by_key_func_def ! (
1306
+ "Find maximum among values of similar keys along a given dimension" ,
1307
+ "" ,
1308
+ max_by_key,
1309
+ af_max_by_key,
1310
+ ValueType :: AggregateOutType
1311
+ ) ;
1312
+
1313
+ dim_reduce_by_key_func_def ! (
1314
+ "Find minimum among values of similar keys along a given dimension" ,
1315
+ "" ,
1316
+ min_by_key,
1317
+ af_min_by_key,
1318
+ ValueType :: AggregateOutType
1319
+ ) ;
1320
+
1321
+ dim_reduce_by_key_func_def ! (
1322
+ "Find product of all values with similar keys along a given dimension" ,
1323
+ "" ,
1324
+ product_by_key,
1325
+ af_product_by_key,
1326
+ ValueType :: ProductOutType
1327
+ ) ;
1328
+
1329
+ dim_reduce_by_key_func_def ! (
1330
+ "Find sum of all values with similar keys along a given dimension" ,
1331
+ "" ,
1332
+ sum_by_key,
1333
+ af_sum_by_key,
1334
+ ValueType :: AggregateOutType
1335
+ ) ;
1336
+
1337
+ macro_rules! dim_reduce_by_key_nan_func_def {
1338
+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1339
+ #[ doc=$brief_str]
1340
+ ///
1341
+ /// This version of sum by key can replaced all NaN values in the input
1342
+ /// with a user provided value before performing the reduction operation.
1343
+ /// # Parameters
1344
+ ///
1345
+ /// - `keys` - key Array
1346
+ /// - `vals` - value Array
1347
+ /// - `dim` - Dimension along which the input Array is reduced
1348
+ ///
1349
+ /// # Return Values
1350
+ ///
1351
+ /// Tuple of Arrays, with output keys and values after reduction
1352
+ ///
1353
+ #[ doc=$ex_str]
1354
+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1355
+ dim: i32 , replace_value: f64
1356
+ ) -> ( Array <KeyType >, Array <$out_type>)
1357
+ where
1358
+ KeyType : ReduceByKeyInput ,
1359
+ ValueType : HasAfEnum ,
1360
+ $out_type: HasAfEnum ,
1361
+ {
1362
+ let mut out_keys: i64 = 0 ;
1363
+ let mut out_vals: i64 = 0 ;
1364
+ unsafe {
1365
+ let err_val = $ffi_name(
1366
+ & mut out_keys as MutAfArray ,
1367
+ & mut out_vals as MutAfArray ,
1368
+ keys. get( ) as AfArray ,
1369
+ vals. get( ) as AfArray ,
1370
+ dim as c_int,
1371
+ replace_value as c_double,
1372
+ ) ;
1373
+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1374
+ }
1375
+ ( out_keys. into( ) , out_vals. into( ) )
1376
+ }
1377
+ } ;
1378
+ }
1379
+
1380
+ dim_reduce_by_key_nan_func_def ! (
1381
+ "Compute sum of all values with similar keys along a given dimension" ,
1382
+ "" ,
1383
+ sum_by_key_nan,
1384
+ af_sum_by_key_nan,
1385
+ ValueType :: AggregateOutType
1386
+ ) ;
1387
+
1388
+ dim_reduce_by_key_nan_func_def ! (
1389
+ "Compute product of all values with similar keys along a given dimension" ,
1390
+ "" ,
1391
+ product_by_key_nan,
1392
+ af_product_by_key_nan,
1393
+ ValueType :: ProductOutType
1394
+ ) ;
0 commit comments