@@ -180,12 +180,12 @@ macro_rules! binary_func {
180
180
///
181
181
/// This is an element wise binary operation.
182
182
#[ allow( unused_mut) ]
183
- pub fn $fn_name( lhs: & Array , rhs: & Array ) -> Array {
183
+ pub fn $fn_name( lhs: & Array , rhs: & Array , batch : bool ) -> Array {
184
184
unsafe {
185
185
let mut temp: i64 = 0 ;
186
186
let err_val = $ffi_fn( & mut temp as MutAfArray ,
187
187
lhs. get( ) as AfArray , rhs. get( ) as AfArray ,
188
- 0 ) ;
188
+ batch as c_int ) ;
189
189
HANDLE_ERROR ( AfError :: from( err_val) ) ;
190
190
Array :: from( temp)
191
191
}
@@ -217,6 +217,8 @@ macro_rules! convertable_type_def {
217
217
)
218
218
}
219
219
220
+ convertable_type_def ! ( Complex <f64 >) ;
221
+ convertable_type_def ! ( Complex <f32 >) ;
220
222
convertable_type_def ! ( u64 ) ;
221
223
convertable_type_def ! ( i64 ) ;
222
224
convertable_type_def ! ( f64 ) ;
@@ -350,45 +352,33 @@ pub fn clamp<T, U> (input: &Array, arg1: &T, arg2: &U, batch: bool) -> Array
350
352
}
351
353
352
354
macro_rules! arith_scalar_func {
353
- ( $rust_type: ty, $op_name: ident, $fn_name: ident, $ffi_fn : ident ) => (
355
+ ( $rust_type: ty, $op_name: ident, $fn_name: ident) => (
354
356
impl <' f> $op_name<$rust_type> for & ' f Array {
355
357
type Output = Array ;
356
358
357
359
fn $fn_name( self , rhs: $rust_type) -> Array {
358
- let cnst_arr = constant( rhs, self . dims( ) ) ;
359
- unsafe {
360
- let mut temp: i64 = 0 ;
361
- let err_val = $ffi_fn( & mut temp as MutAfArray , self . get( ) as AfArray ,
362
- cnst_arr. get( ) as AfArray , 0 ) ;
363
- HANDLE_ERROR ( AfError :: from( err_val) ) ;
364
- Array :: from( temp)
365
- }
360
+ let temp = rhs. clone( ) ;
361
+ $fn_name( self , & temp, false )
366
362
}
367
363
}
368
364
369
365
impl $op_name<$rust_type> for Array {
370
366
type Output = Array ;
371
367
372
368
fn $fn_name( self , rhs: $rust_type) -> Array {
373
- let cnst_arr = constant( rhs, self . dims( ) ) ;
374
- unsafe {
375
- let mut temp: i64 = 0 ;
376
- let err_val = $ffi_fn( & mut temp as MutAfArray , self . get( ) as AfArray ,
377
- cnst_arr. get( ) as AfArray , 0 ) ;
378
- HANDLE_ERROR ( AfError :: from( err_val) ) ;
379
- Array :: from( temp)
380
- }
369
+ let temp = rhs. clone( ) ;
370
+ $fn_name( & self , & temp, false )
381
371
}
382
372
}
383
373
)
384
374
}
385
375
386
376
macro_rules! arith_scalar_spec {
387
377
( $ty_name: ty) => (
388
- arith_scalar_func!( $ty_name, Add , add, af_add ) ;
389
- arith_scalar_func!( $ty_name, Sub , sub, af_sub ) ;
390
- arith_scalar_func!( $ty_name, Mul , mul, af_mul ) ;
391
- arith_scalar_func!( $ty_name, Div , div, af_div ) ;
378
+ arith_scalar_func!( $ty_name, Add , add) ;
379
+ arith_scalar_func!( $ty_name, Sub , sub) ;
380
+ arith_scalar_func!( $ty_name, Mul , mul) ;
381
+ arith_scalar_func!( $ty_name, Div , div) ;
392
382
)
393
383
}
394
384
@@ -403,33 +393,51 @@ arith_scalar_spec!(i32);
403
393
arith_scalar_spec ! ( u8 ) ;
404
394
405
395
macro_rules! arith_func {
406
- ( $op_name: ident, $fn_name: ident, $ffi_fn : ident) => (
396
+ ( $op_name: ident, $fn_name: ident, $delegate : ident) => (
407
397
impl $op_name<Array > for Array {
408
398
type Output = Array ;
409
399
410
400
fn $fn_name( self , rhs: Array ) -> Array {
411
- unsafe {
412
- let mut temp: i64 = 0 ;
413
- let err_val = $ffi_fn( & mut temp as MutAfArray ,
414
- self . get( ) as AfArray , rhs. get( ) as AfArray , 0 ) ;
415
- HANDLE_ERROR ( AfError :: from( err_val) ) ;
416
- Array :: from( temp)
417
- }
401
+ $delegate( & self , & rhs, false )
402
+ }
403
+ }
404
+
405
+ impl <' a> $op_name<& ' a Array > for Array {
406
+ type Output = Array ;
407
+
408
+ fn $fn_name( self , rhs: & ' a Array ) -> Array {
409
+ $delegate( & self , rhs, false )
410
+ }
411
+ }
412
+
413
+ impl <' a> $op_name<Array > for & ' a Array {
414
+ type Output = Array ;
415
+
416
+ fn $fn_name( self , rhs: Array ) -> Array {
417
+ $delegate( self , & rhs, false )
418
+ }
419
+ }
420
+
421
+ impl <' a, ' b> $op_name<& ' a Array > for & ' b Array {
422
+ type Output = Array ;
423
+
424
+ fn $fn_name( self , rhs: & ' a Array ) -> Array {
425
+ $delegate( self , rhs, false )
418
426
}
419
427
}
420
428
)
421
429
}
422
430
423
- arith_func ! ( Add , add, af_add ) ;
424
- arith_func ! ( Sub , sub, af_sub ) ;
425
- arith_func ! ( Mul , mul, af_mul ) ;
426
- arith_func ! ( Div , div, af_div ) ;
427
- arith_func ! ( Rem , rem, af_rem ) ;
428
- arith_func ! ( BitAnd , bitand , af_bitand ) ;
429
- arith_func ! ( BitOr , bitor , af_bitor ) ;
430
- arith_func ! ( BitXor , bitxor , af_bitxor ) ;
431
- arith_func ! ( Shl , shl , af_bitshiftl ) ;
432
- arith_func ! ( Shr , shr , af_bitshiftr ) ;
431
+ arith_func ! ( Add , add , add ) ;
432
+ arith_func ! ( Sub , sub , sub ) ;
433
+ arith_func ! ( Mul , mul , mul ) ;
434
+ arith_func ! ( Div , div , div ) ;
435
+ arith_func ! ( Rem , rem , rem ) ;
436
+ arith_func ! ( Shl , shl , shiftl ) ;
437
+ arith_func ! ( Shr , shr , shiftr ) ;
438
+ arith_func ! ( BitAnd , bitand , bitand ) ;
439
+ arith_func ! ( BitOr , bitor , bitor ) ;
440
+ arith_func ! ( BitXor , bitxor , bitxor ) ;
433
441
434
442
#[ cfg( op_assign) ]
435
443
mod op_assign {
@@ -477,7 +485,7 @@ macro_rules! bit_assign_func {
477
485
let mut idxrs = Indexer :: new( ) ;
478
486
idxrs. set_index( & Seq :: <f32 >:: default ( ) , 0 , Some ( false ) ) ;
479
487
idxrs. set_index( & Seq :: <f32 >:: default ( ) , 1 , Some ( false ) ) ;
480
- let tmp = assign_gen( self as & Array , & idxrs, & $func( self as & Array , & rhs) ) ;
488
+ let tmp = assign_gen( self as & Array , & idxrs, & $func( self as & Array , & rhs, false ) ) ;
481
489
mem:: replace( self , tmp) ;
482
490
}
483
491
}
0 commit comments