@@ -391,12 +391,12 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
391
391
DPNPC_ptr_adapter<size_t > input1_strides_ptr (input1_strides, input1_ndim); \
392
392
\
393
393
DPNPC_ptr_adapter<_DataType_input2> input2_ptr (input2_in, input2_size); \
394
- DPNPC_ptr_adapter<size_t > input2_shape_ptr (input2_shape, input2_ndim); \
395
- DPNPC_ptr_adapter<size_t > input2_strides_ptr (input2_strides, input2_ndim); \
394
+ DPNPC_ptr_adapter<size_t > input2_shape_ptr (input2_shape, input2_ndim, true ); \
395
+ DPNPC_ptr_adapter<size_t > input2_strides_ptr (input2_strides, input2_ndim, true ); \
396
396
\
397
- DPNPC_ptr_adapter<_DataType_output> result_ptr (result_out, result_size); \
398
- DPNPC_ptr_adapter<size_t > result_shape_ptr (result_shape, result_ndim); \
399
- DPNPC_ptr_adapter<size_t > result_strides_ptr (result_strides, result_ndim); \
397
+ DPNPC_ptr_adapter<_DataType_output> result_ptr (result_out, result_size, false , true ); \
398
+ DPNPC_ptr_adapter<size_t > result_shape_ptr (result_shape, result_ndim, true ); \
399
+ DPNPC_ptr_adapter<size_t > result_strides_ptr (result_strides, result_ndim, true ); \
400
400
\
401
401
_DataType_input1* input1_data = input1_ptr.get_ptr (); \
402
402
size_t * input1_shape_data = input1_shape_ptr.get_ptr (); \
@@ -411,7 +411,18 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
411
411
size_t * result_strides_data = result_strides_ptr.get_ptr (); \
412
412
\
413
413
bool use_broadcasting = !array_equal (input1_shape_data, input1_ndim, input2_shape_data, input2_ndim); \
414
- bool use_strides = !array_equal (input1_strides_data, input1_ndim, input2_strides_data, input2_ndim); \
414
+ \
415
+ const size_t input1_shape_size_in_bytes = input1_ndim * sizeof (size_t ); \
416
+ size_t * input1_shape_offsets = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (input1_shape_size_in_bytes)); \
417
+ get_shape_offsets_inkernel (input1_shape_data, input1_ndim, input1_shape_offsets); \
418
+ bool use_strides = !array_equal (input1_strides_data, input1_ndim, input1_shape_offsets, input1_ndim); \
419
+ dpnp_memory_free_c (input1_shape_offsets); \
420
+ \
421
+ const size_t input2_shape_size_in_bytes = input2_ndim * sizeof (size_t ); \
422
+ size_t * input2_shape_offsets = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (input2_shape_size_in_bytes)); \
423
+ get_shape_offsets_inkernel (input2_shape_data, input2_ndim, input2_shape_offsets); \
424
+ use_strides = use_strides || !array_equal (input2_strides_data, input2_ndim, input2_shape_offsets, input2_ndim);\
425
+ dpnp_memory_free_c (input2_shape_offsets); \
415
426
\
416
427
cl::sycl::event event; \
417
428
cl::sycl::range<1 > gws (result_size); \
0 commit comments