Skip to content

Change lhs/inout parameter of assign functions to mutable ref #224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/acoustic_wave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ fn acoustic_wave_simulation() {
// Location of the source.
let seqs = &[Seq::new(700.0, 800.0, 1.0), Seq::new(800.0, 800.0, 1.0)];
// Set the pressure there.
p = assign_seq(
&p,
assign_seq(
&mut p,
seqs,
&index(&pulse, &[Seq::new(it as f64, it as f64, 1.0)]),
);
Expand Down
6 changes: 3 additions & 3 deletions examples/helloworld.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn main() {

let dims = Dim4::new(&[num_rows, num_cols, 1, 1]);

let a = randu::<f32>(dims);
let mut a = randu::<f32>(dims);
af_print!("Create a 5-by-3 float matrix on the GPU", a);

println!("Element-wise arithmetic");
Expand Down Expand Up @@ -67,8 +67,8 @@ fn main() {
let r_dims = Dim4::new(&[3, 1, 1, 1]);
let r_input: [f32; 3] = [1.0, 1.0, 1.0];
let r = Array::new(&r_input, r_dims);
let ur = set_row(&a, &r, num_rows - 1);
af_print!("Set last row to 1's", ur);
set_row(&mut a, &r, num_rows - 1);
af_print!("Set last row to 1's", a);

let d_dims = Dim4::new(&[2, 3, 1, 1]);
let d_input: [i32; 6] = [1, 2, 3, 4, 5, 6];
Expand Down
7 changes: 2 additions & 5 deletions src/arith/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,6 @@ mod op_assign {
use crate::array::Array;
use crate::index::{assign_gen, Indexer};
use crate::seq::Seq;
use std::mem;
use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign, ShlAssign, ShrAssign};

Expand All @@ -852,8 +851,7 @@ mod op_assign {
idxrs.set_index(&tmp_seq, n, Some(false));
}
let opres = $func(self as &Array<A>, &rhs, false).cast::<A>();
let tmp = assign_gen(self as &Array<A>, &idxrs, &opres);
let old = mem::replace(self, tmp);
assign_gen(self, &idxrs, &opres);
}
}
};
Expand Down Expand Up @@ -884,8 +882,7 @@ mod op_assign {
idxrs.set_index(&tmp_seq, n, Some(false));
}
let opres = $func(self as &Array<A>, &rhs, false).cast::<A>();
let tmp = assign_gen(self as &Array<A>, &idxrs, &opres);
let old = mem::replace(self, tmp);
assign_gen(self, &idxrs, &opres);
}
}
};
Expand Down
12 changes: 3 additions & 9 deletions src/defines.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
extern crate num;

use self::num::Complex;
use std::error::Error;
use std::fmt::Error as FmtError;
use std::fmt::{Display, Formatter};

Expand Down Expand Up @@ -80,13 +79,7 @@ impl Display for Backend {

impl Display for AfError {
fn fmt(&self, f: &mut Formatter) -> Result<(), FmtError> {
write!(f, "{}", self.description())
}
}

impl Error for AfError {
fn description(&self) -> &str {
match *self {
let text = match *self {
AfError::SUCCESS => "Function returned successfully",
AfError::ERR_NO_MEM => "System or Device ran out of memory",
AfError::ERR_DRIVER => "Error in the device driver",
Expand All @@ -104,7 +97,8 @@ impl Error for AfError {
AfError::ERR_NO_GFX => "This build of ArrayFire has no graphics support",
AfError::ERR_INTERNAL => "Error either in ArrayFire or in a project upstream",
AfError::ERR_UNKNOWN => "Unknown Error",
}
};
write!(f, "{}", text)
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ extern crate libc;
use self::libc::c_char;
use crate::defines::AfError;
use crate::util::{free_host, DimT, MutDimT};
use std::error::Error;
use std::ffi::CStr;
use std::ops::{Deref, DerefMut};
use std::sync::RwLock;
Expand Down Expand Up @@ -39,7 +38,7 @@ pub fn handle_error_general(error_code: AfError) {
AfError::SUCCESS => {} /* No-op */
_ => panic!(
"Error message: {}\nLast error: {}",
error_code.description(),
error_code,
get_last_error()
),
}
Expand All @@ -63,7 +62,7 @@ lazy_static! {
/// fn handle_error(error_code: AfError) {
/// match error_code {
/// AfError::SUCCESS => {}, /* No-op */
/// _ => panic!("Error message: {}", error_code.description()),
/// _ => panic!("Error message: {}", error_code),
/// }
/// }
///
Expand Down
78 changes: 35 additions & 43 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::util::{AfArray, AfIndex, DimT, HasAfEnum, MutAfArray, MutAfIndex};

use std::default::Default;
use std::marker::PhantomData;
use std::mem;

#[allow(dead_code)]
extern "C" {
Expand Down Expand Up @@ -276,7 +277,6 @@ where
/// print(&a);
/// print(&row(&a, 4));
/// ```
#[allow(dead_code)]
pub fn row<T>(input: &Array<T>, row_num: u64) -> Array<T>
where
T: HasAfEnum,
Expand All @@ -290,20 +290,18 @@ where
)
}

#[allow(dead_code)]
/// Set `row_num`^th row in `input` Array to a new Array `new_row`
pub fn set_row<T>(input: &Array<T>, new_row: &Array<T>, row_num: u64) -> Array<T>
/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: u64)
where
T: HasAfEnum,
{
let seqs = [
Seq::new(row_num as f64, row_num as f64, 1.0),
Seq::default(),
];
assign_seq(input, &seqs, new_row)
assign_seq(inout, &seqs, new_row)
}

#[allow(dead_code)]
/// Get an Array with all rows from `first` to `last` in the `input` Array
pub fn rows<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
where
Expand All @@ -315,14 +313,13 @@ where
)
}

#[allow(dead_code)]
/// Set rows from `first` to `last` in `input` Array with rows from Array `new_rows`
pub fn set_rows<T>(input: &Array<T>, new_rows: &Array<T>, first: u64, last: u64) -> Array<T>
/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: u64, last: u64)
where
T: HasAfEnum,
{
let seqs = [Seq::new(first as f64, last as f64, 1.0), Seq::default()];
assign_seq(input, &seqs, new_rows)
assign_seq(inout, &seqs, new_rows)
}

/// Extract `col_num` col from `input` Array
Expand All @@ -337,7 +334,6 @@ where
/// println!("Grab last col of the random matrix");
/// print(&col(&a, 4));
/// ```
#[allow(dead_code)]
pub fn col<T>(input: &Array<T>, col_num: u64) -> Array<T>
where
T: HasAfEnum,
Expand All @@ -351,20 +347,18 @@ where
)
}

#[allow(dead_code)]
/// Set `col_num`^th col in `input` Array to a new Array `new_col`
pub fn set_col<T>(input: &Array<T>, new_col: &Array<T>, col_num: u64) -> Array<T>
/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: u64)
where
T: HasAfEnum,
{
let seqs = [
Seq::default(),
Seq::new(col_num as f64, col_num as f64, 1.0),
];
assign_seq(input, &seqs, new_col)
assign_seq(inout, &seqs, new_col)
}

#[allow(dead_code)]
/// Get all cols from `first` to `last` in the `input` Array
pub fn cols<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
where
Expand All @@ -376,20 +370,18 @@ where
)
}

#[allow(dead_code)]
/// Set cols from `first` to `last` in `input` Array with cols from Array `new_cols`
pub fn set_cols<T>(input: &Array<T>, new_cols: &Array<T>, first: u64, last: u64) -> Array<T>
/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: u64, last: u64)
where
T: HasAfEnum,
{
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, 1.0)];
assign_seq(input, &seqs, new_cols)
assign_seq(inout, &seqs, new_cols)
}

#[allow(dead_code)]
/// Get `slice_num`^th slice from `input` Array
///
/// Note. Slices indicate that the indexing is along 3rd dimension
/// Slices indicate that the indexing is along 3rd dimension
pub fn slice<T>(input: &Array<T>, slice_num: u64) -> Array<T>
where
T: HasAfEnum,
Expand All @@ -402,11 +394,10 @@ where
index(input, &seqs)
}

#[allow(dead_code)]
/// Set slice `slice_num` in `input` Array to a new Array `new_slice`
/// Set slice `slice_num` in `inout` Array to a new Array `new_slice`
///
/// Slices indicate that the indexing is along 3rd dimension
pub fn set_slice<T>(input: &Array<T>, new_slice: &Array<T>, slice_num: u64) -> Array<T>
pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: u64)
where
T: HasAfEnum,
{
Expand All @@ -415,10 +406,9 @@ where
Seq::default(),
Seq::new(slice_num as f64, slice_num as f64, 1.0),
];
assign_seq(input, &seqs, new_slice)
assign_seq(inout, &seqs, new_slice)
}

#[allow(dead_code)]
/// Get slices from `first` to `last` in `input` Array
///
/// Slices indicate that the indexing is along 3rd dimension
Expand All @@ -434,11 +424,10 @@ where
index(input, &seqs)
}

#[allow(dead_code)]
/// Set `first` to `last` slices of `input` Array to a new Array `new_slices`
/// Set `first` to `last` slices of `inout` Array to a new Array `new_slices`
///
/// Slices indicate that the indexing is along 3rd dimension
pub fn set_slices<T>(input: &Array<T>, new_slices: &Array<T>, first: u64, last: u64) -> Array<T>
pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: u64, last: u64)
where
T: HasAfEnum,
{
Expand All @@ -447,7 +436,7 @@ where
Seq::default(),
Seq::new(first as f64, last as f64, 1.0),
];
assign_seq(input, &seqs, new_slices)
assign_seq(inout, &seqs, new_slices)
}

/// Lookup(hash) an Array using another Array
Expand Down Expand Up @@ -480,25 +469,26 @@ where
///
/// ```rust
/// use arrayfire::{constant, Dim4, Seq, assign_seq, print};
/// let a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
/// let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
/// let sub = assign_seq(&a, seqs, &b);
/// let mut a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
/// print(&a);
/// // 2.0 2.0 2.0
/// // 2.0 2.0 2.0
/// // 2.0 2.0 2.0
/// // 2.0 2.0 2.0
/// // 2.0 2.0 2.0
///
/// print(&sub);
/// let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
/// assign_seq(&mut a, seqs, &b);
///
/// print(&a);
/// // 2.0 2.0 2.0
/// // 1.0 1.0 1.0
/// // 1.0 1.0 1.0
/// // 1.0 1.0 1.0
/// // 2.0 2.0 2.0
/// ```
pub fn assign_seq<T: Copy, I>(lhs: &Array<I>, seqs: &[Seq<T>], rhs: &Array<I>) -> Array<I>
pub fn assign_seq<T: Copy, I>(lhs: &mut Array<I>, seqs: &[Seq<T>], rhs: &Array<I>)
where
c_double: From<T>,
I: HasAfEnum,
Expand All @@ -516,7 +506,8 @@ where
);
HANDLE_ERROR(AfError::from(err_val));
}
temp.into()
let modified = temp.into();
let _old_arr = mem::replace(lhs, modified);
}

/// Index an Array using any combination of Array's and Sequence's
Expand Down Expand Up @@ -574,7 +565,7 @@ where
/// let values: [f32; 3] = [1.0, 2.0, 3.0];
/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
/// let seq4gen = Seq::new(0.0, 2.0, 1.0);
/// let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
/// let mut a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
/// // [5 3 1 1]
/// // 0.0000 0.2190 0.3835
/// // 0.1315 0.0470 0.5194
Expand All @@ -588,16 +579,16 @@ where
/// idxrs.set_index(&indices, 0, None); // 2nd parameter is indexing dimension
/// idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd parameter indicates batch operation
///
/// let sub2 = assign_gen(&a, &idxrs, &b);
/// println!("a(indices, seq(0, 2, 1))"); print(&sub2);
/// assign_gen(&mut a, &idxrs, &b);
/// println!("a(indices, seq(0, 2, 1))"); print(&a);
/// // [5 3 1 1]
/// // 0.0000 0.2190 0.3835
/// // 2.0000 2.0000 2.0000
/// // 2.0000 2.0000 2.0000
/// // 2.0000 2.0000 2.0000
/// // 0.5328 0.9347 0.0535
/// ```
pub fn assign_gen<T>(lhs: &Array<T>, indices: &Indexer, rhs: &Array<T>) -> Array<T>
pub fn assign_gen<T>(lhs: &mut Array<T>, indices: &Indexer, rhs: &Array<T>)
where
T: HasAfEnum,
{
Expand All @@ -612,7 +603,8 @@ where
);
HANDLE_ERROR(AfError::from(err_val));
}
temp.into()
let modified = temp.into();
let _old_arr = mem::replace(lhs, modified);
}

#[repr(C)]
Expand Down
2 changes: 1 addition & 1 deletion tests/error_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ macro_rules! implement_handler {
pub fn $fn_name(error_code: AfError) {
match error_code {
AfError::SUCCESS => {} /* No-op */
_ => panic!("Error message: {}", error_code.description()),
_ => panic!("Error message: {}", error_code),
}
}
};
Expand Down