Skip to content

Commit d6fc0c3

Browse files
authored
Merge pull request #104 from 9prady9/Array_arith_traits
Additional arithmetic ops/traits implementations for Array and &Array
2 parents 456a335 + 3ba93c8 commit d6fc0c3

File tree

1 file changed

+50
-42
lines changed

1 file changed

+50
-42
lines changed

src/arith/mod.rs

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,12 @@ macro_rules! binary_func {
180180
///
181181
/// This is an element wise binary operation.
182182
#[allow(unused_mut)]
183-
pub fn $fn_name(lhs: &Array, rhs: &Array) -> Array {
183+
pub fn $fn_name(lhs: &Array, rhs: &Array, batch: bool) -> Array {
184184
unsafe {
185185
let mut temp: i64 = 0;
186186
let err_val = $ffi_fn(&mut temp as MutAfArray,
187187
lhs.get() as AfArray, rhs.get() as AfArray,
188-
0);
188+
batch as c_int);
189189
HANDLE_ERROR(AfError::from(err_val));
190190
Array::from(temp)
191191
}
@@ -217,6 +217,8 @@ macro_rules! convertable_type_def {
217217
)
218218
}
219219

220+
convertable_type_def!(Complex<f64>);
221+
convertable_type_def!(Complex<f32>);
220222
convertable_type_def!(u64);
221223
convertable_type_def!(i64);
222224
convertable_type_def!(f64);
@@ -350,45 +352,33 @@ pub fn clamp<T, U> (input: &Array, arg1: &T, arg2: &U, batch: bool) -> Array
350352
}
351353

352354
macro_rules! arith_scalar_func {
353-
($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => (
355+
($rust_type: ty, $op_name:ident, $fn_name: ident) => (
354356
impl<'f> $op_name<$rust_type> for &'f Array {
355357
type Output = Array;
356358

357359
fn $fn_name(self, rhs: $rust_type) -> Array {
358-
let cnst_arr = constant(rhs, self.dims());
359-
unsafe {
360-
let mut temp: i64 = 0;
361-
let err_val = $ffi_fn(&mut temp as MutAfArray, self.get() as AfArray,
362-
cnst_arr.get() as AfArray, 0);
363-
HANDLE_ERROR(AfError::from(err_val));
364-
Array::from(temp)
365-
}
360+
let temp = rhs.clone();
361+
$fn_name(self, &temp, false)
366362
}
367363
}
368364

369365
impl $op_name<$rust_type> for Array {
370366
type Output = Array;
371367

372368
fn $fn_name(self, rhs: $rust_type) -> Array {
373-
let cnst_arr = constant(rhs, self.dims());
374-
unsafe {
375-
let mut temp: i64 = 0;
376-
let err_val = $ffi_fn(&mut temp as MutAfArray, self.get() as AfArray,
377-
cnst_arr.get() as AfArray, 0);
378-
HANDLE_ERROR(AfError::from(err_val));
379-
Array::from(temp)
380-
}
369+
let temp = rhs.clone();
370+
$fn_name(&self, &temp, false)
381371
}
382372
}
383373
)
384374
}
385375

386376
macro_rules! arith_scalar_spec {
387377
($ty_name:ty) => (
388-
arith_scalar_func!($ty_name, Add, add, af_add);
389-
arith_scalar_func!($ty_name, Sub, sub, af_sub);
390-
arith_scalar_func!($ty_name, Mul, mul, af_mul);
391-
arith_scalar_func!($ty_name, Div, div, af_div);
378+
arith_scalar_func!($ty_name, Add, add);
379+
arith_scalar_func!($ty_name, Sub, sub);
380+
arith_scalar_func!($ty_name, Mul, mul);
381+
arith_scalar_func!($ty_name, Div, div);
392382
)
393383
}
394384

@@ -403,33 +393,51 @@ arith_scalar_spec!(i32);
403393
arith_scalar_spec!(u8);
404394

405395
macro_rules! arith_func {
406-
($op_name:ident, $fn_name:ident, $ffi_fn: ident) => (
396+
($op_name:ident, $fn_name:ident, $delegate:ident) => (
407397
impl $op_name<Array> for Array {
408398
type Output = Array;
409399

410400
fn $fn_name(self, rhs: Array) -> Array {
411-
unsafe {
412-
let mut temp: i64 = 0;
413-
let err_val = $ffi_fn(&mut temp as MutAfArray,
414-
self.get() as AfArray, rhs.get() as AfArray, 0);
415-
HANDLE_ERROR(AfError::from(err_val));
416-
Array::from(temp)
417-
}
401+
$delegate(&self, &rhs, false)
402+
}
403+
}
404+
405+
impl<'a> $op_name<&'a Array> for Array {
406+
type Output = Array;
407+
408+
fn $fn_name(self, rhs: &'a Array) -> Array {
409+
$delegate(&self, rhs, false)
410+
}
411+
}
412+
413+
impl<'a> $op_name<Array> for &'a Array {
414+
type Output = Array;
415+
416+
fn $fn_name(self, rhs: Array) -> Array {
417+
$delegate(self, &rhs, false)
418+
}
419+
}
420+
421+
impl<'a, 'b> $op_name<&'a Array> for &'b Array {
422+
type Output = Array;
423+
424+
fn $fn_name(self, rhs: &'a Array) -> Array {
425+
$delegate(self, rhs, false)
418426
}
419427
}
420428
)
421429
}
422430

423-
arith_func!(Add, add, af_add);
424-
arith_func!(Sub, sub, af_sub);
425-
arith_func!(Mul, mul, af_mul);
426-
arith_func!(Div, div, af_div);
427-
arith_func!(Rem, rem, af_rem);
428-
arith_func!(BitAnd, bitand, af_bitand);
429-
arith_func!(BitOr, bitor, af_bitor);
430-
arith_func!(BitXor, bitxor, af_bitxor);
431-
arith_func!(Shl, shl, af_bitshiftl);
432-
arith_func!(Shr, shr, af_bitshiftr);
431+
arith_func!(Add , add , add );
432+
arith_func!(Sub , sub , sub );
433+
arith_func!(Mul , mul , mul );
434+
arith_func!(Div , div , div );
435+
arith_func!(Rem , rem , rem );
436+
arith_func!(Shl , shl , shiftl);
437+
arith_func!(Shr , shr , shiftr);
438+
arith_func!(BitAnd, bitand, bitand);
439+
arith_func!(BitOr , bitor , bitor );
440+
arith_func!(BitXor, bitxor, bitxor);
433441

434442
#[cfg(op_assign)]
435443
mod op_assign {
@@ -477,7 +485,7 @@ macro_rules! bit_assign_func {
477485
let mut idxrs = Indexer::new();
478486
idxrs.set_index(&Seq::<f32>::default(), 0, Some(false));
479487
idxrs.set_index(&Seq::<f32>::default(), 1, Some(false));
480-
let tmp = assign_gen(self as &Array, &idxrs, & $func(self as &Array, &rhs));
488+
let tmp = assign_gen(self as &Array, &idxrs, & $func(self as &Array, &rhs, false));
481489
mem::replace(self, tmp);
482490
}
483491
}

0 commit comments

Comments
 (0)