Skip to content

Commit 326c451

Browse files
authored
Removed extra copy for transpose arrays in dot() (#1477)
* Removed extra copy for strided arrays in dot() * Added support of strided arrays * Added support of strided out array * Fix handling of 1d and 2d arrays
1 parent 771653b commit 326c451

File tree

4 files changed

+94
-77
lines changed

4 files changed

+94
-77
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ env:
2222
test_special.py
2323
test_umath.py
2424
test_usm_type.py
25+
third_party/cupy/linalg_tests/test_product.py
2526
third_party/cupy/math_tests/test_explog.py
2627
third_party/cupy/math_tests/test_misc.py
2728
third_party/cupy/math_tests/test_trigonometric.py

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 87 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <dpnp_iface.hpp>
3535

3636
namespace mkl_blas = oneapi::mkl::blas;
37+
namespace mkl_blas_cm = oneapi::mkl::blas::column_major;
3738
namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
3839
namespace mkl_lapack = oneapi::mkl::lapack;
3940

@@ -227,12 +228,10 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
227228
DPCTLSyclEventRef event_ref = nullptr;
228229
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
229230

230-
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(q_ref, input1_in,
231-
input1_size);
232-
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(q_ref, input2_in,
233-
input2_size);
234-
_DataType_input1 *input1 = input1_ptr.get_ptr();
235-
_DataType_input2 *input2 = input2_ptr.get_ptr();
231+
_DataType_input1 *input1 =
232+
static_cast<_DataType_input1 *>(const_cast<void *>(input1_in));
233+
_DataType_input2 *input2 =
234+
static_cast<_DataType_input2 *>(const_cast<void *>(input2_in));
236235
_DataType_output *result = reinterpret_cast<_DataType_output *>(result_out);
237236

238237
if (!input1_size || !input2_size) {
@@ -257,10 +256,12 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
257256
// if both arrays are vectors
258257
if ((input1_ndim == 1) && (input2_ndim == 1)) {
259258
assert(input1_size == input2_size);
259+
260260
sycl::event event = dot(q, result, input1, input2, input1_strides[0],
261261
input2_strides[0], input1_size);
262-
event.wait();
263-
return event_ref;
262+
263+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
264+
return DPCTLEvent_Copy(event_ref);
264265
}
265266

266267
// 1D vector
@@ -297,13 +298,17 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
297298
size_t ext_result_ndim =
298299
((input1_ndim == 1) || (input2_ndim == 1)) ? 2 : result_ndim;
299300
shape_elem_type *ext_result_shape = new shape_elem_type[ext_result_ndim];
301+
shape_elem_type *ext_result_strides = new shape_elem_type[ext_result_ndim];
300302
if ((input1_ndim == 1) || (input2_ndim == 1)) {
301303
ext_result_shape[0] = ext_input1_shape[0];
302304
ext_result_shape[1] = ext_input2_shape[1];
305+
ext_result_strides[0] = 0;
306+
ext_result_strides[1] = result_strides[0];
303307
}
304308
else {
305309
for (size_t i = 0; i < ext_result_ndim; ++i) {
306310
ext_result_shape[i] = result_shape[i];
311+
ext_result_strides[i] = result_strides[i];
307312
}
308313
}
309314

@@ -316,80 +321,89 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
316321
// check if GEMM can be executed (strides)
317322
// TODO: rewrite the condition in general case for ndims > 2
318323
// (looks like there are such another cases)
319-
320324
if (ext_input1_ndim == 2 && ext_input2_ndim == 2) {
321-
// there is a difference of behavior with trans and sizes params in previous
322-
// version of GEMM only new version is supported, in case of old version
323-
// computation goes in common way
324-
#if INTEL_MKL_VERSION >= 20210004
325-
// is mat1 F-contiguous, C-contiguous
326-
bool mat1_f_contig =
327-
(((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == 1)) &&
328-
((ext_input1_shape[1] == 1) ||
329-
(ext_input1_strides[1] == ext_input1_shape[0])));
330-
bool mat1_c_contig =
331-
(((ext_input1_shape[1] == 1) || (ext_input1_strides[1] == 1)) &&
332-
((ext_input1_shape[0] == 1) ||
333-
(ext_input1_strides[0] == ext_input1_shape[1])));
334-
// is mat2 F-contiguous, C-contiguous
335-
bool mat2_f_contig =
336-
(((ext_input2_shape[0] == 1) || (ext_input2_strides[0] == 1)) &&
337-
((ext_input2_shape[1] == 1) ||
338-
(ext_input2_strides[1] == ext_input2_shape[0])));
339-
bool mat2_c_contig =
340-
(((ext_input2_shape[1] == 1) || (ext_input2_strides[1] == 1)) &&
341-
((ext_input2_shape[0] == 1) ||
342-
(ext_input2_strides[0] == ext_input2_shape[1])));
343-
344-
if ((mat1_f_contig || mat1_c_contig) &&
345-
(mat2_f_contig || mat2_c_contig)) {
346-
oneapi::mkl::transpose trans1 =
347-
(mat1_f_contig && !mat1_c_contig)
348-
? oneapi::mkl::transpose::trans
349-
: oneapi::mkl::transpose::nontrans;
350-
oneapi::mkl::transpose trans2 =
351-
(mat2_f_contig && !mat2_c_contig)
352-
? oneapi::mkl::transpose::trans
353-
: oneapi::mkl::transpose::nontrans;
325+
// OneMKL gemm supports only arrays contiguous on inner dimension,
326+
// so stride for at least one dimension should be equal to 1
327+
if ((ext_input1_strides[0] == 1 || ext_input1_strides[1] == 1) &&
328+
(ext_input2_strides[0] == 1 || ext_input2_strides[1] == 1) &&
329+
(ext_result_strides[0] == 1 || ext_result_strides[1] == 1))
330+
{
331+
const bool isRowmA =
332+
(ext_input1_strides[1] == 1 || ext_input1_strides[0] == 0);
333+
const bool isRowmB =
334+
(ext_input2_strides[1] == 1 || ext_input2_strides[1] == 0);
335+
const bool isRowmC =
336+
(ext_result_strides[1] == 1 || ext_result_strides[0] == 0);
337+
338+
oneapi::mkl::transpose transA =
339+
(isRowmA != isRowmC) ? oneapi::mkl::transpose::trans
340+
: oneapi::mkl::transpose::nontrans;
341+
oneapi::mkl::transpose transB =
342+
(isRowmB != isRowmC) ? oneapi::mkl::transpose::trans
343+
: oneapi::mkl::transpose::nontrans;
354344

355345
const size_t size_m = ext_input1_shape[0];
356346
const size_t size_n = ext_input2_shape[1];
357347
const size_t size_k = ext_input1_shape[1];
358348

359-
const std::int64_t lda =
360-
trans1 == oneapi::mkl::transpose::nontrans
361-
? ext_input1_strides[0]
362-
: ext_input1_strides[1];
363-
const std::int64_t ldb =
364-
trans2 == oneapi::mkl::transpose::nontrans
365-
? ext_input2_strides[0]
366-
: ext_input2_strides[1];
367-
368-
// definition of ldc will be another for result with
369-
// non-standard (c-contiguous) strides const std::int64_t ldc =
370-
// result_strides[0] == 1 ? result_strides[1] :
371-
// result_strides[0];
372-
const std::int64_t ldc = size_n;
349+
auto getLdaLdc = [](const bool isRown, shape_elem_type *strides,
350+
shape_elem_type *shapes) {
351+
if (isRown) {
352+
return (strides[0] != 0) ? strides[0] : shapes[1];
353+
}
354+
return strides[1];
355+
};
356+
357+
const std::int64_t lda = static_cast<std::int64_t>(
358+
getLdaLdc(isRowmA, ext_input1_strides, ext_input1_shape));
359+
const std::int64_t ldb = static_cast<std::int64_t>(
360+
isRowmB ? ext_input2_strides[0] : ext_input2_strides[1]);
361+
const std::int64_t ldc = static_cast<std::int64_t>(
362+
getLdaLdc(isRowmC, ext_result_strides, ext_result_shape));
363+
364+
constexpr _DataType_output alpha = 1;
365+
constexpr _DataType_output beta = 0;
366+
367+
std::stringstream error_msg;
368+
std::int64_t info = 0;
373369

374370
try {
375-
sycl::event event = mkl_blas_rm::gemm(
376-
q, trans1, trans2, size_m, size_n, size_k,
377-
_DataType_output(1), // alpha
378-
input1, lda, input2, ldb,
379-
_DataType_output(0), // beta
380-
result, ldc);
381-
event.wait();
382-
delete[] ext_input1_shape;
383-
delete[] ext_input1_strides;
384-
delete[] ext_input2_shape;
385-
delete[] ext_input2_strides;
386-
delete[] ext_result_shape;
387-
388-
return event_ref;
371+
if (isRowmC) {
372+
mkl_blas_rm::gemm(q, transA, transB, size_m, size_n,
373+
size_k, alpha, input1, lda, input2,
374+
ldb, beta, result, ldc)
375+
.wait();
376+
}
377+
else {
378+
mkl_blas_cm::gemm(q, transA, transB, size_m, size_n,
379+
size_k, alpha, input1, lda, input2,
380+
ldb, beta, result, ldc)
381+
.wait();
382+
}
383+
} catch (mkl_lapack::exception const &e) {
384+
error_msg << "Unexpected MKL exception caught during "
385+
"gemm() call:\nreason: "
386+
<< e.what() << "\ninfo: " << e.info();
387+
info = e.info();
389388
} catch (const std::exception &e) {
390-
// do nothing, proceed to general case
389+
error_msg << "Unexpected SYCL exception caught during "
390+
"gemm() call:\n"
391+
<< e.what();
392+
info = -1;
391393
}
392-
#endif
394+
395+
if (info != 0) // an unexpected error occurred
396+
{
397+
throw std::runtime_error(error_msg.str());
398+
}
399+
400+
delete[] ext_input1_shape;
401+
delete[] ext_input1_strides;
402+
delete[] ext_input2_shape;
403+
delete[] ext_input2_strides;
404+
delete[] ext_result_shape;
405+
delete[] ext_result_strides;
406+
return event_ref;
393407
}
394408
}
395409
}
@@ -437,6 +451,7 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
437451
delete[] ext_input2_shape;
438452
delete[] ext_input2_strides;
439453
delete[] ext_result_shape;
454+
delete[] ext_result_strides;
440455

441456
return event_ref;
442457
}

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_dot_t)(c_dpctl.DPCTLSyclQueueR
5555
const shape_elem_type *, const shape_elem_type * ,
5656
void * , const size_t, const size_t,
5757
const shape_elem_type *, const shape_elem_type * ,
58-
const c_dpctl.DPCTLEventVectorRef)
58+
const c_dpctl.DPCTLEventVectorRef) except +
5959
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQueueRef,
6060
void * , const size_t, const size_t,
6161
const shape_elem_type *, const shape_elem_type * ,

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,16 @@ def dot(x1, x2, out=None, **kwargs):
108108
else (None, None)
109109
)
110110

111-
# TODO: copy_when_strides=False (now it's done for faster implementation with transpose arrays)
112111
x1_desc = dpnp.get_dpnp_descriptor(
113112
x1,
114-
copy_when_strides=True,
113+
copy_when_strides=False,
115114
copy_when_nondefault_queue=False,
116115
alloc_usm_type=usm_type,
117116
alloc_queue=queue,
118117
)
119118
x2_desc = dpnp.get_dpnp_descriptor(
120119
x2,
121-
copy_when_strides=True,
120+
copy_when_strides=False,
122121
copy_when_nondefault_queue=False,
123122
alloc_usm_type=usm_type,
124123
alloc_queue=queue,
@@ -131,7 +130,9 @@ def dot(x1, x2, out=None, **kwargs):
131130
)
132131
out_desc = (
133132
dpnp.get_dpnp_descriptor(
134-
out, copy_when_nondefault_queue=False
133+
out,
134+
copy_when_strides=False,
135+
copy_when_nondefault_queue=False,
135136
)
136137
or None
137138
)

0 commit comments

Comments
 (0)