Skip to content

Commit e721533

Browse files
committed
Update API to reflect ArrayFire 3.7.0 release
1 parent 0e707f2 commit e721533

File tree

18 files changed

+1465
-12
lines changed

18 files changed

+1465
-12
lines changed

Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,25 @@ indexing = []
2727
graphics = []
2828
image = []
2929
lapack = []
30+
machine_learning = []
3031
macros = []
3132
random = []
3233
signal = []
3334
sparse = []
3435
statistics = []
3536
vision = []
3637
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
37-
"macros", "random", "signal", "sparse", "statistics", "vision"]
38+
"machine_learning", "macros", "random", "signal", "sparse", "statistics", "vision"]
3839

3940
[dependencies]
4041
libc = "0.2"
4142
num = "0.2"
4243
lazy_static = "1.0"
44+
half = "1.5.0"
4345

4446
[dev-dependencies]
4547
float-cmp = "0.6.0"
48+
half = "1.5.0"
4649

4750
[build-dependencies]
4851
serde_json = "1.0"
@@ -85,3 +88,7 @@ path = "examples/conway.rs"
8588
[[example]]
8689
name = "fft"
8790
path = "examples/fft.rs"
91+
92+
[[example]]
93+
name = "using_half"
94+
path = "examples/using_half.rs"

examples/using_half.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use arrayfire::*;
2+
use half::f16;
3+
4+
fn main() {
5+
set_device(0);
6+
info();
7+
8+
let values: Vec<_> = (1u8..101).map(f32::from).collect();
9+
10+
let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
11+
12+
let hvals = Array::new(&half_values, Dim4::new(&[10, 10, 1, 1]));
13+
14+
print(&hvals);
15+
}

src/algorithm/mod.rs

Lines changed: 256 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::array::Array;
55
use crate::defines::{AfError, BinaryOp};
66
use crate::error::HANDLE_ERROR;
77
use crate::util::{AfArray, MutAfArray, MutDouble, MutUint};
8-
use crate::util::{HasAfEnum, RealNumber, Scanable};
8+
use crate::util::{HasAfEnum, RealNumber, ReduceByKeyInput, Scanable};
99

1010
#[allow(dead_code)]
1111
extern "C" {
@@ -59,6 +59,71 @@ extern "C" {
5959
op: c_uint,
6060
inclusive: c_int,
6161
) -> c_int;
62+
fn af_all_true_by_key(
63+
keys_out: MutAfArray,
64+
vals_out: MutAfArray,
65+
keys: AfArray,
66+
vals: AfArray,
67+
dim: c_int,
68+
) -> c_int;
69+
fn af_any_true_by_key(
70+
keys_out: MutAfArray,
71+
vals_out: MutAfArray,
72+
keys: AfArray,
73+
vals: AfArray,
74+
dim: c_int,
75+
) -> c_int;
76+
fn af_count_by_key(
77+
keys_out: MutAfArray,
78+
vals_out: MutAfArray,
79+
keys: AfArray,
80+
vals: AfArray,
81+
dim: c_int,
82+
) -> c_int;
83+
fn af_max_by_key(
84+
keys_out: MutAfArray,
85+
vals_out: MutAfArray,
86+
keys: AfArray,
87+
vals: AfArray,
88+
dim: c_int,
89+
) -> c_int;
90+
fn af_min_by_key(
91+
keys_out: MutAfArray,
92+
vals_out: MutAfArray,
93+
keys: AfArray,
94+
vals: AfArray,
95+
dim: c_int,
96+
) -> c_int;
97+
fn af_product_by_key(
98+
keys_out: MutAfArray,
99+
vals_out: MutAfArray,
100+
keys: AfArray,
101+
vals: AfArray,
102+
dim: c_int,
103+
) -> c_int;
104+
fn af_product_by_key_nan(
105+
keys_out: MutAfArray,
106+
vals_out: MutAfArray,
107+
keys: AfArray,
108+
vals: AfArray,
109+
dim: c_int,
110+
nan_val: c_double,
111+
) -> c_int;
112+
fn af_sum_by_key(
113+
keys_out: MutAfArray,
114+
vals_out: MutAfArray,
115+
keys: AfArray,
116+
vals: AfArray,
117+
dim: c_int,
118+
) -> c_int;
119+
fn af_sum_by_key_nan(
120+
keys_out: MutAfArray,
121+
vals_out: MutAfArray,
122+
keys: AfArray,
123+
vals: AfArray,
124+
dim: c_int,
125+
nan_val: c_double,
126+
) -> c_int;
62127
}
63128

64129
macro_rules! dim_reduce_func_def {
@@ -1137,3 +1202,193 @@ where
11371202
}
11381203
temp.into()
11391204
}
1205+
1206+
macro_rules! dim_reduce_by_key_func_def {
1207+
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1208+
#[doc=$brief_str]
1209+
/// # Parameters
1210+
///
1211+
/// - `keys` - key Array
1212+
/// - `vals` - value Array
1213+
/// - `dim` - Dimension along which the input Array is reduced
1214+
///
1215+
/// # Return Values
1216+
///
1217+
/// Tuple of Arrays, with output keys and values after reduction
1218+
///
1219+
#[doc=$ex_str]
1220+
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
1221+
dim: i32
1222+
) -> (Array<KeyType>, Array<$out_type>)
1223+
where
1224+
KeyType: ReduceByKeyInput,
1225+
ValueType: HasAfEnum,
1226+
$out_type: HasAfEnum,
1227+
{
1228+
let mut out_keys: i64 = 0;
1229+
let mut out_vals: i64 = 0;
1230+
unsafe {
1231+
let err_val = $ffi_name(
1232+
&mut out_keys as MutAfArray,
1233+
&mut out_vals as MutAfArray,
1234+
keys.get() as AfArray,
1235+
vals.get() as AfArray,
1236+
dim as c_int,
1237+
);
1238+
HANDLE_ERROR(AfError::from(err_val));
1239+
}
1240+
(out_keys.into(), out_vals.into())
1241+
}
1242+
};
1243+
}
1244+
1245+
dim_reduce_by_key_func_def!(
1246+
"
1247+
Key based AND of elements along a given dimension
1248+
1249+
All positive non-zero values are considered true, while negative and zero
1250+
values are considered as false.
1251+
",
1252+
"
1253+
# Examples
1254+
```rust
1255+
use arrayfire::{Dim4, print, randu, all_true_by_key};
1256+
let dims = Dim4::new(&[5, 3, 1, 1]);
1257+
let vals = randu::<f32>(dims);
1258+
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1259+
print(&vals);
1260+
print(&keys);
1261+
let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1262+
print(&out_keys);
1263+
print(&out_vals);
1264+
```
1265+
",
1266+
all_true_by_key,
1267+
af_all_true_by_key,
1268+
ValueType::AggregateOutType
1269+
);
1270+
1271+
dim_reduce_by_key_func_def!(
1272+
"
1273+
Key based OR of elements along a given dimension
1274+
1275+
All positive non-zero values are considered true, while negative and zero
1276+
values are considered as false.
1277+
",
1278+
"
1279+
# Examples
1280+
```rust
1281+
use arrayfire::{Dim4, print, randu, any_true_by_key};
1282+
let dims = Dim4::new(&[5, 3, 1, 1]);
1283+
let vals = randu::<f32>(dims);
1284+
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1285+
print(&vals);
1286+
print(&keys);
1287+
let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1288+
print(&out_keys);
1289+
print(&out_vals);
1290+
```
1291+
",
1292+
any_true_by_key,
1293+
af_any_true_by_key,
1294+
ValueType::AggregateOutType
1295+
);
1296+
1297+
dim_reduce_by_key_func_def!(
1298+
"Find total count of elements with similar keys along a given dimension",
1299+
"",
1300+
count_by_key,
1301+
af_count_by_key,
1302+
ValueType::AggregateOutType
1303+
);
1304+
1305+
dim_reduce_by_key_func_def!(
1306+
"Find maximum among values of similar keys along a given dimension",
1307+
"",
1308+
max_by_key,
1309+
af_max_by_key,
1310+
ValueType::AggregateOutType
1311+
);
1312+
1313+
dim_reduce_by_key_func_def!(
1314+
"Find minimum among values of similar keys along a given dimension",
1315+
"",
1316+
min_by_key,
1317+
af_min_by_key,
1318+
ValueType::AggregateOutType
1319+
);
1320+
1321+
dim_reduce_by_key_func_def!(
1322+
"Find product of all values with similar keys along a given dimension",
1323+
"",
1324+
product_by_key,
1325+
af_product_by_key,
1326+
ValueType::ProductOutType
1327+
);
1328+
1329+
dim_reduce_by_key_func_def!(
1330+
"Find sum of all values with similar keys along a given dimension",
1331+
"",
1332+
sum_by_key,
1333+
af_sum_by_key,
1334+
ValueType::AggregateOutType
1335+
);
1336+
1337+
macro_rules! dim_reduce_by_key_nan_func_def {
1338+
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1339+
#[doc=$brief_str]
1340+
///
1341+
/// This version of sum by key can replaced all NaN values in the input
1342+
/// with a user provided value before performing the reduction operation.
1343+
/// # Parameters
1344+
///
1345+
/// - `keys` - key Array
1346+
/// - `vals` - value Array
1347+
/// - `dim` - Dimension along which the input Array is reduced
1348+
///
1349+
/// # Return Values
1350+
///
1351+
/// Tuple of Arrays, with output keys and values after reduction
1352+
///
1353+
#[doc=$ex_str]
1354+
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
1355+
dim: i32, replace_value: f64
1356+
) -> (Array<KeyType>, Array<$out_type>)
1357+
where
1358+
KeyType: ReduceByKeyInput,
1359+
ValueType: HasAfEnum,
1360+
$out_type: HasAfEnum,
1361+
{
1362+
let mut out_keys: i64 = 0;
1363+
let mut out_vals: i64 = 0;
1364+
unsafe {
1365+
let err_val = $ffi_name(
1366+
&mut out_keys as MutAfArray,
1367+
&mut out_vals as MutAfArray,
1368+
keys.get() as AfArray,
1369+
vals.get() as AfArray,
1370+
dim as c_int,
1371+
replace_value as c_double,
1372+
);
1373+
HANDLE_ERROR(AfError::from(err_val));
1374+
}
1375+
(out_keys.into(), out_vals.into())
1376+
}
1377+
};
1378+
}
1379+
1380+
dim_reduce_by_key_nan_func_def!(
1381+
"Compute sum of all values with similar keys along a given dimension",
1382+
"",
1383+
sum_by_key_nan,
1384+
af_sum_by_key_nan,
1385+
ValueType::AggregateOutType
1386+
);
1387+
1388+
dim_reduce_by_key_nan_func_def!(
1389+
"Compute product of all values with similar keys along a given dimension",
1390+
"",
1391+
product_by_key_nan,
1392+
af_product_by_key_nan,
1393+
ValueType::ProductOutType
1394+
);

src/arith/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ extern "C" {
8585
fn af_log10(out: MutAfArray, arr: AfArray) -> c_int;
8686
fn af_log2(out: MutAfArray, arr: AfArray) -> c_int;
8787
fn af_sqrt(out: MutAfArray, arr: AfArray) -> c_int;
88+
fn af_rsqrt(out: MutAfArray, arr: AfArray) -> c_int;
8889
fn af_cbrt(out: MutAfArray, arr: AfArray) -> c_int;
8990
fn af_factorial(out: MutAfArray, arr: AfArray) -> c_int;
9091
fn af_tgamma(out: MutAfArray, arr: AfArray) -> c_int;
@@ -199,6 +200,12 @@ unary_func!("Compute the natural logarithm", log, af_log, UnaryOutType);
199200
unary_func!("Compute sin", sin, af_sin, UnaryOutType);
200201
unary_func!("Compute sinh", sinh, af_sinh, UnaryOutType);
201202
unary_func!("Compute the square root", sqrt, af_sqrt, UnaryOutType);
203+
unary_func!(
204+
"Compute the reciprocal square root",
205+
rsqrt,
206+
af_rsqrt,
207+
UnaryOutType
208+
);
202209
unary_func!("Compute tan", tan, af_tan, UnaryOutType);
203210
unary_func!("Compute tanh", tanh, af_tanh, UnaryOutType);
204211

src/array.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,29 @@ where
166166
///
167167
/// # Examples
168168
///
169+
/// An example of creating an Array from f32 array
170+
///
169171
/// ```rust
170172
/// use arrayfire::{Array, Dim4, print};
171173
/// let values: [f32; 3] = [1.0, 2.0, 3.0];
172174
/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
173175
/// print(&indices);
174176
/// ```
177+
/// An example of creating an Array from half::f16 array
178+
///
179+
/// ```rust
180+
/// use arrayfire::{Array, Dim4, print};
181+
/// use half::f16;
182+
///
183+
/// let values: [f32; 3] = [1.0, 2.0, 3.0];
184+
///
185+
/// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
186+
///
187+
/// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1]));
188+
///
189+
/// print(&hvals);
190+
/// ```
191+
///
175192
#[allow(unused_mut)]
176193
pub fn new(slice: &[T], dims: Dim4) -> Self {
177194
let aftype = T::get_af_dtype();
@@ -218,7 +235,7 @@ where
218235
///
219236
/// ```rust
220237
/// use arrayfire::{Array, Dim4};
221-
/// let garbageVals = Array::<f32>::new_empty(Dim4::new(&[3, 1, 1, 1]));
238+
/// let garbage_vals = Array::<f32>::new_empty(Dim4::new(&[3, 1, 1, 1]));
222239
/// ```
223240
#[allow(unused_mut)]
224241
pub fn new_empty(dims: Dim4) -> Self {
@@ -353,6 +370,11 @@ where
353370
self.handle
354371
}
355372

373+
/// Returns the native FFI handle for Rust object `Array`
374+
pub fn set(&mut self, handle: i64) {
375+
self.handle = handle;
376+
}
377+
356378
/// Copies the data from the Array to the mutable slice `data`
357379
pub fn host<O: HasAfEnum>(&self, data: &mut [O]) {
358380
if data.len() != self.elements() {

0 commit comments

Comments
 (0)