Skip to content

Update indexing tutorials #239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/core/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,3 +958,22 @@ pub fn pad<T: HasAfEnum>(
temp.into()
}
}

#[cfg(test)]
mod tests {
use super::reorder_v2;

use super::super::random::randu;

use crate::dim4;

#[test]
fn check_reorder_api() {
let a = randu::<f32>(dim4!(4, 5, 2, 3));

let _transposed = reorder_v2(&a, 1, 0, None);
let _swap_0_2 = reorder_v2(&a, 2, 1, Some(vec![0]));
let _swap_1_2 = reorder_v2(&a, 0, 2, Some(vec![1]));
let _swap_0_3 = reorder_v2(&a, 3, 1, Some(vec![2, 0]));
}
}
164 changes: 164 additions & 0 deletions src/core/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,167 @@ impl SeqInternal {
}
}
}

#[cfg(test)]
mod tests {
use super::super::array::Array;
use super::super::data::constant;
use super::super::dim4::Dim4;
use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
use super::super::random::randu;
use super::super::seq::Seq;

use crate::{dim4, seq, view};

#[test]
fn non_macro_seq_index() {
// ANCHOR: non_macro_seq_index
let dims = Dim4::new(&[5, 5, 1, 1]);
let a = randu::<f32>(dims);
//af_print!("a", a);
//a
//[5 5 1 1]
// 0.3990 0.5160 0.8831 0.9107 0.6688
// 0.6720 0.3932 0.0621 0.9159 0.8434
// 0.5339 0.2706 0.7089 0.0231 0.1328
// 0.1386 0.9455 0.9434 0.2330 0.2657
// 0.7353 0.1587 0.1227 0.2220 0.2299

// Index array using sequences
let seqs = &[Seq::new(1u32, 3, 1), Seq::default()];
let _sub = index(&a, seqs);
//af_print!("a(seq(1,3,1), span)", sub);
// [3 5 1 1]
// 0.6720 0.3932 0.0621 0.9159 0.8434
// 0.5339 0.2706 0.7089 0.0231 0.1328
// 0.1386 0.9455 0.9434 0.2330 0.2657
// ANCHOR_END: non_macro_seq_index
}

#[test]
fn seq_index() {
// ANCHOR: seq_index
let dims = dim4!(5, 5, 1, 1);
let a = randu::<f32>(dims);
let first3 = seq!(1:3:1);
let allindim2 = seq!();
let _sub = view!(a[first3, allindim2]);
// ANCHOR_END: seq_index
}

#[test]
fn non_macro_seq_assign() {
// ANCHOR: non_macro_seq_assign
let mut a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
//print(&a);
// 2.0 2.0 2.0
// 2.0 2.0 2.0
// 2.0 2.0 2.0
// 2.0 2.0 2.0
// 2.0 2.0 2.0

let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
assign_seq(&mut a, seqs, &b);
//print(&a);
// 2.0 2.0 2.0
// 1.0 1.0 1.0
// 1.0 1.0 1.0
// 1.0 1.0 1.0
// 2.0 2.0 2.0
// ANCHOR_END: non_macro_seq_assign
}

#[test]
fn non_macro_seq_array_index() {
// ANCHOR: non_macro_seq_array_index
let values: [f32; 3] = [1.0, 2.0, 3.0];
let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
let seq4gen = Seq::new(0.0, 2.0, 1.0);
let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
// [5 3 1 1]
// 0.0000 0.2190 0.3835
// 0.1315 0.0470 0.5194
// 0.7556 0.6789 0.8310
// 0.4587 0.6793 0.0346
// 0.5328 0.9347 0.0535

let mut idxrs = Indexer::default();
idxrs.set_index(&indices, 0, None); // 2nd arg is indexing dimension
idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd arg indicates batch operation

let _sub2 = index_gen(&a, idxrs);
//println!("a(indices, seq(0, 2, 1))"); print(&sub2);
// [3 3 1 1]
// 0.1315 0.0470 0.5194
// 0.7556 0.6789 0.8310
// 0.4587 0.6793 0.0346
// ANCHOR_END: non_macro_seq_array_index
}

#[test]
fn seq_array_index() {
// ANCHOR: seq_array_index
let values: [f32; 3] = [1.0, 2.0, 3.0];
let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
let seq4gen = seq!(0:2:1);
let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
let _sub2 = view!(a[indices, seq4gen]);
// ANCHOR_END: seq_array_index
}

#[test]
fn non_macro_seq_array_assign() {
// ANCHOR: non_macro_seq_array_assign
let values: [f32; 3] = [1.0, 2.0, 3.0];
let indices = Array::new(&values, dim4!(3, 1, 1, 1));
let seq4gen = seq!(0:2:1);
let mut a = randu::<f32>(dim4!(5, 3, 1, 1));
// [5 3 1 1]
// 0.0000 0.2190 0.3835
// 0.1315 0.0470 0.5194
// 0.7556 0.6789 0.8310
// 0.4587 0.6793 0.0346
// 0.5328 0.9347 0.0535

let b = constant(2.0 as f32, dim4!(3, 3, 1, 1));

let mut idxrs = Indexer::default();
idxrs.set_index(&indices, 0, None); // 2nd arg is indexing dimension
idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd arg indicates batch operation

let _sub2 = assign_gen(&mut a, &idxrs, &b);
//println!("a(indices, seq(0, 2, 1))"); print(&sub2);
// [5 3 1 1]
// 0.0000 0.2190 0.3835
// 2.0000 2.0000 2.0000
// 2.0000 2.0000 2.0000
// 2.0000 2.0000 2.0000
// 0.5328 0.9347 0.0535
// ANCHOR_END: non_macro_seq_array_assign
}

#[test]
fn setrow() {
// ANCHOR: setrow
let a = randu::<f32>(dim4!(5, 5, 1, 1));
//print(&a);
// [5 5 1 1]
// 0.6010 0.5497 0.1583 0.3636 0.6755
// 0.0278 0.2864 0.3712 0.4165 0.6105
// 0.9806 0.3410 0.3543 0.5814 0.5232
// 0.2126 0.7509 0.6450 0.8962 0.5567
// 0.0655 0.4105 0.9675 0.3712 0.7896
let _r = row(&a, 4);
// [1 5 1 1]
// 0.0655 0.4105 0.9675 0.3712 0.7896
let _c = col(&a, 4);
// [5 1 1 1]
// 0.6755
// 0.6105
// 0.5232
// 0.5567
// 0.7896
// ANCHOR_END: setrow
}
}
85 changes: 80 additions & 5 deletions src/core/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ macro_rules! seq {
/// - A simple Array identifier.
/// - An Array with slicing info for indexing.
/// - An Array with slicing info and other arrays used for indexing.
///
/// Examples on how to use this macro are provided in the [tutorials book][1]
///
/// [1]: http://arrayfire.org/arrayfire-rust/book/indexing.html
#[macro_export]
macro_rules! view {
(@af_max_dims) => {
Expand All @@ -194,12 +198,13 @@ macro_rules! view {
};
( $array_ident:ident [ $($start:literal : $end:literal : $step:literal),+ ] ) => {
{
#[allow(non_snake_case)]
let AF_MAX_DIMS: usize = view!(@af_max_dims);
let mut seq_vec = Vec::<$crate::Seq<i32>>::with_capacity(AF_MAX_DIMS);
$(
seq_vec.push($crate::seq!($start:$end:$step));
)*
for span_place_holder in seq_vec.len()..AF_MAX_DIMS {
for _span_place_holder in seq_vec.len()..AF_MAX_DIMS {
seq_vec.push($crate::seq!());
}
$crate::index(&$array_ident, &seq_vec)
Expand All @@ -218,18 +223,88 @@ macro_rules! view {
};
($array_ident:ident [ $($_e:expr),+ ]) => {
{
#[allow(non_snake_case)]
let AF_MAX_DIMS: u32 = view!(@af_max_dims);
let span = $crate::seq!();
let mut idxrs = $crate::Indexer::default();

view!(@set_indexer 0, idxrs, $($_e),*);

let mut dimIx = idxrs.len() as u32;
while dimIx < AF_MAX_DIMS {
idxrs.set_index(&span, dimIx, None);
dimIx += 1;
let mut dim_ix = idxrs.len() as u32;
while dim_ix < AF_MAX_DIMS {
idxrs.set_index(&span, dim_ix, None);
dim_ix += 1;
}
$crate::index_gen(&$array_ident, idxrs)
}
};
}

#[cfg(test)]
mod tests {
use super::super::array::Array;
use super::super::index::index;
use super::super::random::randu;

#[test]
fn dim4_construction() {
let dim1d = dim4!(2);
let dim2d = dim4!(2, 3);
let dim3d = dim4!(2, 3, 4);
let dim4d = dim4!(2, 3, 4, 2);
let _dimn = dim4!(dim1d[0], dim2d[1], dim3d[2], dim4d[3]);
}

#[test]
fn seq_construction() {
let default_seq = seq!();
let _range_1_to_10_step_1 = seq!(0:9:1);
let _range_1_to_10_step_1_2 = seq!(f32; 0.0:9.0:1.5);
let _range_from_exprs = seq!(default_seq.begin(), default_seq.end(), default_seq.step());
let _range_from_exprs2 = seq!(f32; default_seq.begin() as f32,
default_seq.end() as f32, default_seq.step() as f32);
}

#[test]
fn seq_view() {
let mut dim4d = dim4!(5, 3, 2, 1);
dim4d[2] = 1;

let a = randu::<f32>(dim4d);
let seqs = &[seq!(1:3:1), seq!()];
let sub = index(&a, seqs);
af_print!("A", a);
af_print!("Indexed A", sub);
}

#[test]
fn view_macro() {
let dims = dim4!(5, 5, 2, 1);
let a = randu::<f32>(dims);
let b = a.clone();
let c = a.clone();
let d = a.clone();
let e = a.clone();

let v = view!(a);
af_print!("v = a[None]", v);

let m = view!(c[1:3:1, 1:3:2]);
af_print!("m = c[:, :]", m);

let x = seq!(1:3:1);
let y = seq!(1:3:2);
let u = view!(b[x, y]);
af_print!("u = b[seq(), seq()]", u);

let values: [u32; 3] = [1, 2, 3];
let indices = Array::new(&values, dim4!(3, 1, 1, 1));
let indices2 = Array::new(&values, dim4!(3, 1, 1, 1));

let w = view!(d[indices, indices2]);
af_print!("w = d[Array, Array]", w);

let z = view!(e[indices, y]);
af_print!("z = e[Array, Seq]", z);
}
}
13 changes: 0 additions & 13 deletions tests/data.rs

This file was deleted.

32 changes: 0 additions & 32 deletions tests/index_macro.rs

This file was deleted.

11 changes: 5 additions & 6 deletions tests/scalar_arith.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use ::arrayfire::*;
use float_cmp::approx_eq;

#[allow(non_snake_case)]
#[test]
fn check_scalar_arith() {
let dims = Dim4::new(&[5, 5, 1, 1]);
let A = randu::<f32>(dims);
let a = randu::<f32>(dims);
let s: f32 = 2.0;
let scalar_as_lhs = s * &A;
let scalar_as_rhs = &A * s;
let C = constant(s, dims);
let no_scalars = A * C;
let scalar_as_lhs = s * &a;
let scalar_as_rhs = &a * s;
let c = constant(s, dims);
let no_scalars = a * c;
let scalar_res_comp = eq(&scalar_as_lhs, &scalar_as_rhs, false);
let res_comp = eq(&scalar_as_lhs, &no_scalars, false);
let scalar_res = all_true_all(&scalar_res_comp);
Expand Down
17 changes: 0 additions & 17 deletions tests/seq_macro.rs

This file was deleted.

Loading