
Commit cd24184

Excess memcpy to shared memory in elementwise and bitwise functions (#1328)
1 parent: 2224ce2

3 files changed, +130 -99 lines
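For context: before this commit the bitwise and elementwise kernels wrapped every input pointer, shape and strides array in a DPNPC_ptr_adapter, which staged the data in a freshly allocated SYCL memory block (the "excess memcpy to shared memory" named in the title) before a kernel could launch. The commit drops those adapters and casts the caller's USM pointers directly. The sketch below only illustrates the before/after pattern; it is not code from this commit: old_style, new_style and the bitwise OR are illustrative stand-ins for the generated __name__/__operation__ macro parameters, and it assumes the caller already hands over USM-allocated buffers.

    #include <sycl/sycl.hpp>
    #include <cstring>

    // Old pattern (simplified): stage the input in a fresh shared allocation,
    // i.e. an extra memcpy per argument, before the kernel can run.
    void old_style(sycl::queue& q, const int* usm_in, int* usm_out, size_t n) {
        int* staged = sycl::malloc_shared<int>(n, q);  // extra allocation
        std::memcpy(staged, usm_in, n * sizeof(int));  // the excess copy being removed
        q.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
            usm_out[i] = staged[i] | 0x1;              // stand-in for __operation__
        }).wait();
        sycl::free(staged, q);
    }

    // New pattern (simplified): the USM pointer received from the caller is used
    // directly inside the kernel; no staging allocation, no host-side memcpy.
    void new_style(sycl::queue& q, const int* usm_in, int* usm_out, size_t n) {
        q.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
            usm_out[i] = usm_in[i] | 0x1;
        }).wait();
    }

    int main() {
        sycl::queue q;
        const size_t n = 8;
        int* in = sycl::malloc_shared<int>(n, q);
        int* out = sycl::malloc_shared<int>(n, q);
        for (size_t i = 0; i < n; ++i) in[i] = static_cast<int>(i);
        old_style(q, in, out, n);
        new_style(q, in, out, n);
        sycl::free(in, q);
        sycl::free(out, q);
        return 0;
    }

In the common case where the arguments already live in USM, the new pattern skips one staging allocation, one copy and one free per argument on every kernel invocation.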

.github/workflows/conda-package.yml (1 addition, 0 deletions)

@@ -12,6 +12,7 @@ env:
   CHANNELS: '-c dppy/label/dev -c intel -c main --override-channels'
   TEST_SCOPE: >-
       test_arraycreation.py
+      test_dot.py
       test_dparray.py
       test_fft.py
       test_linalg.py

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp (52 additions, 45 deletions)
@@ -1,5 +1,5 @@
 //*****************************************************************************
-// Copyright (c) 2016-2020, Intel Corporation
+// Copyright (c) 2016-2023, Intel Corporation
 // All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
@@ -148,53 +148,62 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
         \
         sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); \
         \
-        DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, input1_size); \
-        DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
-        DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
+        _DataType* input1_data = static_cast<_DataType*>(const_cast<void*>(input1_in)); \
+        _DataType* input2_data = static_cast<_DataType*>(const_cast<void*>(input2_in)); \
+        _DataType* result = static_cast<_DataType*>(result_out); \
         \
-        DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, input2_in, input2_size); \
-        DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(q_ref, input2_shape, input2_ndim, true); \
-        DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(q_ref, input2_strides, input2_ndim, true); \
+        shape_elem_type* input1_shape_offsets = new shape_elem_type[input1_ndim]; \
         \
-        DPNPC_ptr_adapter<_DataType> result_ptr(q_ref, result_out, result_size, false, true); \
-        DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
+        get_shape_offsets_inkernel(input1_shape, input1_ndim, input1_shape_offsets); \
+        bool use_strides = !array_equal(input1_strides, input1_ndim, input1_shape_offsets, input1_ndim); \
+        delete[] input1_shape_offsets; \
         \
-        _DataType* input1_data = input1_ptr.get_ptr(); \
-        shape_elem_type* input1_shape_data = input1_shape_ptr.get_ptr(); \
-        shape_elem_type* input1_strides_data = input1_strides_ptr.get_ptr(); \
+        shape_elem_type* input2_shape_offsets = new shape_elem_type[input2_ndim]; \
         \
-        _DataType* input2_data = input2_ptr.get_ptr(); \
-        shape_elem_type* input2_shape_data = input2_shape_ptr.get_ptr(); \
-        shape_elem_type* input2_strides_data = input2_strides_ptr.get_ptr(); \
-        \
-        _DataType* result = result_ptr.get_ptr(); \
-        shape_elem_type* result_strides_data = result_strides_ptr.get_ptr(); \
-        \
-        const size_t input1_shape_size_in_bytes = input1_ndim * sizeof(shape_elem_type); \
-        shape_elem_type* input1_shape_offsets = \
-            reinterpret_cast<shape_elem_type*>(sycl::malloc_shared(input1_shape_size_in_bytes, q)); \
-        get_shape_offsets_inkernel(input1_shape_data, input1_ndim, input1_shape_offsets); \
-        bool use_strides = !array_equal(input1_strides_data, input1_ndim, input1_shape_offsets, input1_ndim); \
-        sycl::free(input1_shape_offsets, q); \
-        \
-        const size_t input2_shape_size_in_bytes = input2_ndim * sizeof(shape_elem_type); \
-        shape_elem_type* input2_shape_offsets = \
-            reinterpret_cast<shape_elem_type*>(sycl::malloc_shared(input2_shape_size_in_bytes, q)); \
-        get_shape_offsets_inkernel(input2_shape_data, input2_ndim, input2_shape_offsets); \
-        use_strides = \
-            use_strides || !array_equal(input2_strides_data, input2_ndim, input2_shape_offsets, input2_ndim); \
-        sycl::free(input2_shape_offsets, q); \
+        get_shape_offsets_inkernel(input2_shape, input2_ndim, input2_shape_offsets); \
+        use_strides = use_strides || !array_equal(input2_strides, input2_ndim, input2_shape_offsets, input2_ndim); \
+        delete[] input2_shape_offsets; \
         \
         sycl::event event; \
         sycl::range<1> gws(result_size); \
         \
         if (use_strides) \
         { \
+            if ((result_ndim != input1_ndim) || (result_ndim != input2_ndim)) \
+            { \
+                throw std::runtime_error("Result ndim=" + std::to_string(result_ndim) + \
+                                         " mismatches with either input1 ndim=" + std::to_string(input1_ndim) + \
+                                         " or input2 ndim=" + std::to_string(input2_ndim)); \
+            } \
+            \
+            /* memory transfer optimization, use USM-host for temporary speeds up tranfer to device */ \
+            using usm_host_allocatorT = sycl::usm_allocator<shape_elem_type, sycl::usm::alloc::host>; \
+            \
+            size_t strides_size = 3 * result_ndim; \
+            shape_elem_type* dev_strides_data = sycl::malloc_device<shape_elem_type>(strides_size, q); \
+            \
+            /* create host temporary for packed strides managed by shared pointer */ \
+            auto strides_host_packed = \
+                std::vector<shape_elem_type, usm_host_allocatorT>(strides_size, usm_host_allocatorT(q)); \
+            \
+            /* packed vector is concatenation of result_strides, input1_strides and input2_strides */ \
+            std::copy(result_strides, result_strides + result_ndim, strides_host_packed.begin()); \
+            std::copy(input1_strides, input1_strides + result_ndim, strides_host_packed.begin() + result_ndim); \
+            std::copy(input2_strides, input2_strides + result_ndim, strides_host_packed.begin() + 2 * result_ndim); \
+            \
+            auto copy_strides_ev = \
+                q.copy<shape_elem_type>(strides_host_packed.data(), dev_strides_data, strides_host_packed.size()); \
+            \
             auto kernel_parallel_for_func = [=](sycl::id<1> global_id) { \
-                const size_t output_id = global_id[0]; /*for (size_t i = 0; i < result_size; ++i)*/ \
+                const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
                 { \
+                    const shape_elem_type* result_strides_data = &dev_strides_data[0]; \
+                    const shape_elem_type* input1_strides_data = &dev_strides_data[1]; \
+                    const shape_elem_type* input2_strides_data = &dev_strides_data[2]; \
+                    \
                     size_t input1_id = 0; \
                     size_t input2_id = 0; \
+                    \
                     for (size_t i = 0; i < result_ndim; ++i) \
                     { \
                         const size_t output_xyz_id = \
@@ -209,14 +218,19 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
                 } \
             }; \
             auto kernel_func = [&](sycl::handler& cgh) { \
+                cgh.depends_on(copy_strides_ev); \
                 cgh.parallel_for<class __name__##_strides_kernel<_DataType>>(gws, kernel_parallel_for_func); \
             }; \
-            event = q.submit(kernel_func); \
+            \
+            q.submit(kernel_func).wait(); \
+            \
+            sycl::free(dev_strides_data, q); \
+            return event_ref; \
         } \
         else \
         { \
             auto kernel_parallel_for_func = [=](sycl::id<1> global_id) { \
-                size_t i = global_id[0]; /*for (size_t i = 0; i < result_size; ++i)*/ \
+                size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
                 const _DataType input1_elem = (input1_size == 1) ? input1_data[0] : input1_data[i]; \
                 const _DataType input2_elem = (input2_size == 1) ? input2_data[0] : input2_data[i]; \
                 result[i] = __operation__; \
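In the strided branch shown above, the three strides arrays are now packed into one USM-host vector, copied once to a single device allocation, and the kernel is made to depend on that copy event before the queue is waited on and the buffer freed. Below is a condensed, standalone sketch of that pattern using the same SYCL 2020 APIs (sycl::usm_allocator, malloc_device, queue::copy, handler::depends_on); it packs a single strides array of long instead of three arrays of shape_elem_type, and the out buffer is illustrative, not from the diff.

    #include <sycl/sycl.hpp>
    #include <algorithm>
    #include <vector>

    int main() {
        sycl::queue q;
        const size_t ndim = 3;

        // Host-side strides (stand-in for result_strides / input1_strides / input2_strides).
        std::vector<long> result_strides = {12, 4, 1};

        // USM-host allocator: the temporary lives in host USM so the runtime can
        // transfer it to the device without an extra bounce buffer.
        using usm_host_allocatorT = sycl::usm_allocator<long, sycl::usm::alloc::host>;
        std::vector<long, usm_host_allocatorT> strides_host_packed(ndim, usm_host_allocatorT(q));
        std::copy(result_strides.begin(), result_strides.end(), strides_host_packed.begin());

        // One device allocation and one asynchronous copy for the packed strides.
        long* dev_strides_data = sycl::malloc_device<long>(ndim, q);
        sycl::event copy_strides_ev =
            q.copy<long>(strides_host_packed.data(), dev_strides_data, strides_host_packed.size());

        long* out = sycl::malloc_shared<long>(ndim, q);

        // The compute kernel only starts once the strides have landed on the device.
        q.submit([&](sycl::handler& cgh) {
            cgh.depends_on(copy_strides_ev);
            cgh.parallel_for(sycl::range<1>(ndim), [=](sycl::id<1> i) {
                out[i] = dev_strides_data[i];
            });
        }).wait();

        sycl::free(dev_strides_data, q);
        sycl::free(out, q);
        return 0;
    }

Keeping the temporary in USM-host memory is what the usm_host_allocatorT alias in the diff is for: per its own comment, it speeds up the transfer of the packed strides to the device.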
@@ -226,16 +240,8 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
             }; \
             event = q.submit(kernel_func); \
         } \
-        input1_ptr.depends_on(event); \
-        input1_shape_ptr.depends_on(event); \
-        input1_strides_ptr.depends_on(event); \
-        input2_ptr.depends_on(event); \
-        input2_shape_ptr.depends_on(event); \
-        input2_strides_ptr.depends_on(event); \
-        result_ptr.depends_on(event); \
-        result_strides_ptr.depends_on(event); \
-        event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
         \
+        event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
         return DPCTLEvent_Copy(event_ref); \
     } \
     \
@@ -278,6 +284,7 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
                                                           where, \
                                                           dep_event_vec_ref); \
         DPCTLEvent_WaitAndThrow(event_ref); \
+        DPCTLEvent_Delete(event_ref); \
     } \
     \
     template <typename _DataType> \
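The strided branch only runs when use_strides ends up true, i.e. when an input's strides differ from the row-major strides implied by its shape; that is what the get_shape_offsets_inkernel and array_equal calls near the top of the diff compute. A small self-contained sketch of the same check, using local stand-in helpers rather than the dpnp ones:

    #include <cstdio>
    #include <vector>

    // Stand-in for dpnp's get_shape_offsets_inkernel(): strides of a C-contiguous
    // array with the given shape (innermost dimension has stride 1).
    std::vector<long> contiguous_strides(const std::vector<long>& shape) {
        std::vector<long> strides(shape.size(), 1);
        for (size_t i = shape.size(); i-- > 1;) {
            strides[i - 1] = strides[i] * shape[i];
        }
        return strides;
    }

    // Stand-in for the array_equal() comparison: the strided kernel is only needed
    // when the actual strides differ from the contiguous ones implied by the shape.
    bool needs_strided_kernel(const std::vector<long>& shape, const std::vector<long>& strides) {
        return contiguous_strides(shape) != strides;
    }

    int main() {
        std::vector<long> shape = {2, 3, 4};
        std::printf("contiguous input: use_strides=%d\n",
                    needs_strided_kernel(shape, {12, 4, 1}));  // 0: fast path
        std::printf("transposed input: use_strides=%d\n",
                    needs_strided_kernel(shape, {1, 2, 6}));   // 1: strided kernel
        return 0;
    }

A C-contiguous array therefore takes the fast path with no stride packing at all; a transposed or otherwise non-contiguous view falls back to the strided kernel shown earlier in the diff.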
