Skip to content

Commit 10489e8

Browse files
committed
STRIDES minor fixes
1 parent 908e897 commit 10489e8

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,12 +391,12 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
391391
DPNPC_ptr_adapter<size_t> input1_strides_ptr(input1_strides, input1_ndim); \
392392
\
393393
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size); \
394-
DPNPC_ptr_adapter<size_t> input2_shape_ptr(input2_shape, input2_ndim); \
395-
DPNPC_ptr_adapter<size_t> input2_strides_ptr(input2_strides, input2_ndim); \
394+
DPNPC_ptr_adapter<size_t> input2_shape_ptr(input2_shape, input2_ndim, true); \
395+
DPNPC_ptr_adapter<size_t> input2_strides_ptr(input2_strides, input2_ndim, true); \
396396
\
397-
DPNPC_ptr_adapter<_DataType_output> result_ptr(result_out, result_size); \
398-
DPNPC_ptr_adapter<size_t> result_shape_ptr(result_shape, result_ndim); \
399-
DPNPC_ptr_adapter<size_t> result_strides_ptr(result_strides, result_ndim); \
397+
DPNPC_ptr_adapter<_DataType_output> result_ptr(result_out, result_size, false, true); \
398+
DPNPC_ptr_adapter<size_t> result_shape_ptr(result_shape, result_ndim, true); \
399+
DPNPC_ptr_adapter<size_t> result_strides_ptr(result_strides, result_ndim, true); \
400400
\
401401
_DataType_input1* input1_data = input1_ptr.get_ptr(); \
402402
size_t* input1_shape_data = input1_shape_ptr.get_ptr(); \
@@ -411,7 +411,18 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
411411
size_t* result_strides_data = result_strides_ptr.get_ptr(); \
412412
\
413413
bool use_broadcasting = !array_equal(input1_shape_data, input1_ndim, input2_shape_data, input2_ndim); \
414-
bool use_strides = !array_equal(input1_strides_data, input1_ndim, input2_strides_data, input2_ndim); \
414+
\
415+
const size_t input1_shape_size_in_bytes = input1_ndim * sizeof(size_t); \
416+
size_t* input1_shape_offsets = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(input1_shape_size_in_bytes)); \
417+
get_shape_offsets_inkernel(input1_shape_data, input1_ndim, input1_shape_offsets); \
418+
bool use_strides = !array_equal(input1_strides_data, input1_ndim, input1_shape_offsets, input1_ndim); \
419+
dpnp_memory_free_c(input1_shape_offsets); \
420+
\
421+
const size_t input2_shape_size_in_bytes = input2_ndim * sizeof(size_t); \
422+
size_t* input2_shape_offsets = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(input2_shape_size_in_bytes)); \
423+
get_shape_offsets_inkernel(input2_shape_data, input2_ndim, input2_shape_offsets); \
424+
use_strides = use_strides || !array_equal(input2_strides_data, input2_ndim, input2_shape_offsets, input2_ndim);\
425+
dpnp_memory_free_c(input2_shape_offsets); \
415426
\
416427
cl::sycl::event event; \
417428
cl::sycl::range<1> gws(result_size); \

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ ctypedef void(*fptr_1out_t)(void * , size_t)
238238
ctypedef void(*fptr_1in_1out_t)(void *, void * , size_t)
239239
ctypedef void(*fptr_2in_1out_t)(void * , const void * , const size_t, const long * , const size_t,
240240
const void *, const size_t, const long * , const size_t, const long * )
241-
ctypedef void(*fptr_2in_1out__strides_t)(void *, const size_t, const size_t, const long * , const long * ,
241+
ctypedef void(*fptr_2in_1out_strides_t)(void *, const size_t, const size_t, const long * , const long * ,
242242
void *, const size_t, const size_t, const long * , const long * ,
243243
void *, const size_t, const size_t, const long * , const long * ,
244244
const long * )

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
393393
cdef shape_type_c result_strides = strides_to_vector(result.strides, result_shape)
394394

395395
""" Call FPTR function """
396-
cdef fptr_2in_1out__strides_t func = <fptr_2in_1out__strides_t > kernel_data.ptr
396+
cdef fptr_2in_1out_strides_t func = <fptr_2in_1out_strides_t > kernel_data.ptr
397397
func(result.get_data(),
398398
result.size,
399399
result.ndim,

0 commit comments

Comments
 (0)