Skip to content

GEMM to use trans only if matrix if not C-contig #1160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 64 additions & 38 deletions dpnp/backend/kernels/dpnp_krnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,49 +323,75 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
// check if GEMM can be executed (strides)
// TODO: rewrite the condition in general case for ndims > 2
// (looks like there are such another cases)
if ((ext_input1_ndim == 2 && ext_input2_ndim == 2) &&
(ext_input1_strides[0] == 1 || ext_input1_strides[1] == 1) &&
(ext_input2_strides[0] == 1 || ext_input2_strides[1] == 1))

if (ext_input1_ndim == 2 && ext_input2_ndim == 2)
{
// there is a difference of behavior with trans and sizes params in previous version of GEMM
// only new version is supported, in case of old version computation goes in common way
#if INTEL_MKL_VERSION >= 20210004
oneapi::mkl::transpose trans1 =
ext_input1_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
oneapi::mkl::transpose trans2 =
ext_input2_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;

const size_t size_m = ext_input1_shape[0];
const size_t size_n = ext_input2_shape[1];
const size_t size_k = ext_input1_shape[1];

const std::int64_t lda =
trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0] : ext_input1_strides[1];
const std::int64_t ldb =
trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0] : ext_input2_strides[1];
;
// defenition of ldc will be another for result with non-standard (c-contiguous) strides
// const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
const std::int64_t ldc = size_n;

sycl::event event = mkl_blas_rm::gemm(q,
trans1,
trans2,
size_m,
size_n,
size_k,
_DataType_output(1), // alpha
input1,
lda,
input2,
ldb,
_DataType_output(0), // beta
result,
ldc);
event.wait();
return event_ref;
// is mat1 F-contiguous, C-contiguous
bool mat1_f_contig = (
((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == 1)) &&
((ext_input1_shape[1] == 1) || (ext_input1_strides[1] == ext_input1_shape[0])));
bool mat1_c_contig = (
((ext_input1_shape[1] == 1) || (ext_input1_strides[1] == 1)) &&
((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == ext_input1_shape[1])));
// is mat2 F-contiguous, C-contiguous
bool mat2_f_contig = (
((ext_input2_shape[0] == 1) || (ext_input2_strides[0] == 1)) &&
((ext_input2_shape[1] == 1) || (ext_input2_strides[1] == ext_input2_shape[0])));
bool mat2_c_contig = (
((ext_input2_shape[1] == 1) || (ext_input2_strides[1] == 1)) &&
((ext_input2_shape[0] == 1) || (ext_input2_strides[0] == ext_input2_shape[1])));

if ((mat1_f_contig || mat1_c_contig) && (mat2_f_contig || mat2_c_contig)) {
oneapi::mkl::transpose trans1 =
(mat1_f_contig && !mat1_c_contig) ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
oneapi::mkl::transpose trans2 =
(mat2_f_contig && !mat2_c_contig) ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;

const size_t size_m = ext_input1_shape[0];
const size_t size_n = ext_input2_shape[1];
const size_t size_k = ext_input1_shape[1];

const std::int64_t lda =
trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0] : ext_input1_strides[1];
const std::int64_t ldb =
trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0] : ext_input2_strides[1];

// definition of ldc will be another for result with non-standard (c-contiguous) strides
// const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
const std::int64_t ldc = size_n;

try {
sycl::event event = mkl_blas_rm::gemm(q,
trans1,
trans2,
size_m,
size_n,
size_k,
_DataType_output(1), // alpha
input1,
lda,
input2,
ldb,
_DataType_output(0), // beta
result,
ldc);
event.wait();
delete[] ext_input1_shape;
delete[] ext_input1_strides;
delete[] ext_input2_shape;
delete[] ext_input2_strides;
delete[] ext_result_shape;

return event_ref;
} catch (const std::exception &e) {
// do nothing, proceed to general case
}
#endif
}
}
}
}

std::vector<sycl::event> dot_events;
Expand Down