34
34
#include < dpnp_iface.hpp>
35
35
36
36
namespace mkl_blas = oneapi::mkl::blas;
37
+ namespace mkl_blas_cm = oneapi::mkl::blas::column_major;
37
38
namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
38
39
namespace mkl_lapack = oneapi::mkl::lapack;
39
40
@@ -227,12 +228,10 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
227
228
DPCTLSyclEventRef event_ref = nullptr ;
228
229
sycl::queue q = *(reinterpret_cast <sycl::queue *>(q_ref));
229
230
230
- DPNPC_ptr_adapter<_DataType_input1> input1_ptr (q_ref, input1_in,
231
- input1_size);
232
- DPNPC_ptr_adapter<_DataType_input2> input2_ptr (q_ref, input2_in,
233
- input2_size);
234
- _DataType_input1 *input1 = input1_ptr.get_ptr ();
235
- _DataType_input2 *input2 = input2_ptr.get_ptr ();
231
+ _DataType_input1 *input1 =
232
+ static_cast <_DataType_input1 *>(const_cast <void *>(input1_in));
233
+ _DataType_input2 *input2 =
234
+ static_cast <_DataType_input2 *>(const_cast <void *>(input2_in));
236
235
_DataType_output *result = reinterpret_cast <_DataType_output *>(result_out);
237
236
238
237
if (!input1_size || !input2_size) {
@@ -257,10 +256,12 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
257
256
// if both arrays are vectors
258
257
if ((input1_ndim == 1 ) && (input2_ndim == 1 )) {
259
258
assert (input1_size == input2_size);
259
+
260
260
sycl::event event = dot (q, result, input1, input2, input1_strides[0 ],
261
261
input2_strides[0 ], input1_size);
262
- event.wait ();
263
- return event_ref;
262
+
263
+ event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
264
+ return DPCTLEvent_Copy (event_ref);
264
265
}
265
266
266
267
// 1D vector
@@ -297,13 +298,17 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
297
298
size_t ext_result_ndim =
298
299
((input1_ndim == 1 ) || (input2_ndim == 1 )) ? 2 : result_ndim;
299
300
shape_elem_type *ext_result_shape = new shape_elem_type[ext_result_ndim];
301
+ shape_elem_type *ext_result_strides = new shape_elem_type[ext_result_ndim];
300
302
if ((input1_ndim == 1 ) || (input2_ndim == 1 )) {
301
303
ext_result_shape[0 ] = ext_input1_shape[0 ];
302
304
ext_result_shape[1 ] = ext_input2_shape[1 ];
305
+ ext_result_strides[0 ] = 0 ;
306
+ ext_result_strides[1 ] = result_strides[0 ];
303
307
}
304
308
else {
305
309
for (size_t i = 0 ; i < ext_result_ndim; ++i) {
306
310
ext_result_shape[i] = result_shape[i];
311
+ ext_result_strides[i] = result_strides[i];
307
312
}
308
313
}
309
314
@@ -316,80 +321,89 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
316
321
// check if GEMM can be executed (strides)
317
322
// TODO: rewrite the condition in general case for ndims > 2
318
323
// (looks like there are such another cases)
319
-
320
324
if (ext_input1_ndim == 2 && ext_input2_ndim == 2 ) {
321
- // there is a difference of behavior with trans and sizes params in previous
322
- // version of GEMM only new version is supported, in case of old version
323
- // computation goes in common way
324
- #if INTEL_MKL_VERSION >= 20210004
325
- // is mat1 F-contiguous, C-contiguous
326
- bool mat1_f_contig =
327
- (((ext_input1_shape[0 ] == 1 ) || (ext_input1_strides[0 ] == 1 )) &&
328
- ((ext_input1_shape[1 ] == 1 ) ||
329
- (ext_input1_strides[1 ] == ext_input1_shape[0 ])));
330
- bool mat1_c_contig =
331
- (((ext_input1_shape[1 ] == 1 ) || (ext_input1_strides[1 ] == 1 )) &&
332
- ((ext_input1_shape[0 ] == 1 ) ||
333
- (ext_input1_strides[0 ] == ext_input1_shape[1 ])));
334
- // is mat2 F-contiguous, C-contiguous
335
- bool mat2_f_contig =
336
- (((ext_input2_shape[0 ] == 1 ) || (ext_input2_strides[0 ] == 1 )) &&
337
- ((ext_input2_shape[1 ] == 1 ) ||
338
- (ext_input2_strides[1 ] == ext_input2_shape[0 ])));
339
- bool mat2_c_contig =
340
- (((ext_input2_shape[1 ] == 1 ) || (ext_input2_strides[1 ] == 1 )) &&
341
- ((ext_input2_shape[0 ] == 1 ) ||
342
- (ext_input2_strides[0 ] == ext_input2_shape[1 ])));
343
-
344
- if ((mat1_f_contig || mat1_c_contig) &&
345
- (mat2_f_contig || mat2_c_contig)) {
346
- oneapi::mkl::transpose trans1 =
347
- (mat1_f_contig && !mat1_c_contig)
348
- ? oneapi::mkl::transpose::trans
349
- : oneapi::mkl::transpose::nontrans;
350
- oneapi::mkl::transpose trans2 =
351
- (mat2_f_contig && !mat2_c_contig)
352
- ? oneapi::mkl::transpose::trans
353
- : oneapi::mkl::transpose::nontrans;
325
+ // OneMKL gemm suports only arrays contiguous on inner dimension,
326
+ // so stride for at least one dimension should be equal to 1
327
+ if ((ext_input1_strides[0 ] == 1 || ext_input1_strides[1 ] == 1 ) &&
328
+ (ext_input2_strides[0 ] == 1 || ext_input2_strides[1 ] == 1 ) &&
329
+ (ext_result_strides[0 ] == 1 || ext_result_strides[1 ] == 1 ))
330
+ {
331
+ const bool isRowmA =
332
+ (ext_input1_strides[1 ] == 1 || ext_input1_strides[0 ] == 0 );
333
+ const bool isRowmB =
334
+ (ext_input2_strides[1 ] == 1 || ext_input2_strides[1 ] == 0 );
335
+ const bool isRowmC =
336
+ (ext_result_strides[1 ] == 1 || ext_result_strides[0 ] == 0 );
337
+
338
+ oneapi::mkl::transpose transA =
339
+ (isRowmA != isRowmC) ? oneapi::mkl::transpose::trans
340
+ : oneapi::mkl::transpose::nontrans;
341
+ oneapi::mkl::transpose transB =
342
+ (isRowmB != isRowmC) ? oneapi::mkl::transpose::trans
343
+ : oneapi::mkl::transpose::nontrans;
354
344
355
345
const size_t size_m = ext_input1_shape[0 ];
356
346
const size_t size_n = ext_input2_shape[1 ];
357
347
const size_t size_k = ext_input1_shape[1 ];
358
348
359
- const std::int64_t lda =
360
- trans1 == oneapi::mkl::transpose::nontrans
361
- ? ext_input1_strides[0 ]
362
- : ext_input1_strides[1 ];
363
- const std::int64_t ldb =
364
- trans2 == oneapi::mkl::transpose::nontrans
365
- ? ext_input2_strides[0 ]
366
- : ext_input2_strides[1 ];
367
-
368
- // definition of ldc will be another for result with
369
- // non-standard (c-contiguous) strides const std::int64_t ldc =
370
- // result_strides[0] == 1 ? result_strides[1] :
371
- // result_strides[0];
372
- const std::int64_t ldc = size_n;
349
+ auto getLdaLdc = [](const bool isRown, shape_elem_type *strides,
350
+ shape_elem_type *shapes) {
351
+ if (isRown) {
352
+ return (strides[0 ] != 0 ) ? strides[0 ] : shapes[1 ];
353
+ }
354
+ return strides[1 ];
355
+ };
356
+
357
+ const std::int64_t lda = static_cast <std::int64_t >(
358
+ getLdaLdc (isRowmA, ext_input1_strides, ext_input1_shape));
359
+ const std::int64_t ldb = static_cast <std::int64_t >(
360
+ isRowmB ? ext_input2_strides[0 ] : ext_input2_strides[1 ]);
361
+ const std::int64_t ldc = static_cast <std::int64_t >(
362
+ getLdaLdc (isRowmC, ext_result_strides, ext_result_shape));
363
+
364
+ constexpr _DataType_output alpha = 1 ;
365
+ constexpr _DataType_output beta = 0 ;
366
+
367
+ std::stringstream error_msg;
368
+ std::int64_t info = 0 ;
373
369
374
370
try {
375
- sycl::event event = mkl_blas_rm::gemm (
376
- q, trans1, trans2, size_m, size_n, size_k,
377
- _DataType_output (1 ), // alpha
378
- input1, lda, input2, ldb,
379
- _DataType_output (0 ), // beta
380
- result, ldc);
381
- event.wait ();
382
- delete[] ext_input1_shape;
383
- delete[] ext_input1_strides;
384
- delete[] ext_input2_shape;
385
- delete[] ext_input2_strides;
386
- delete[] ext_result_shape;
387
-
388
- return event_ref;
371
+ if (isRowmC) {
372
+ mkl_blas_rm::gemm (q, transA, transB, size_m, size_n,
373
+ size_k, alpha, input1, lda, input2,
374
+ ldb, beta, result, ldc)
375
+ .wait ();
376
+ }
377
+ else {
378
+ mkl_blas_cm::gemm (q, transA, transB, size_m, size_n,
379
+ size_k, alpha, input1, lda, input2,
380
+ ldb, beta, result, ldc)
381
+ .wait ();
382
+ }
383
+ } catch (mkl_lapack::exception const &e) {
384
+ error_msg << " Unexpected MKL exception caught during "
385
+ " gemm() call:\n reason: "
386
+ << e.what () << " \n info: " << e.info ();
387
+ info = e.info ();
389
388
} catch (const std::exception &e) {
390
- // do nothing, proceed to general case
389
+ error_msg << " Unexpected SYCL exception caught during "
390
+ " gemm() call:\n "
391
+ << e.what ();
392
+ info = -1 ;
391
393
}
392
- #endif
394
+
395
+ if (info != 0 ) // an unexected error occurs
396
+ {
397
+ throw std::runtime_error (error_msg.str ());
398
+ }
399
+
400
+ delete[] ext_input1_shape;
401
+ delete[] ext_input1_strides;
402
+ delete[] ext_input2_shape;
403
+ delete[] ext_input2_strides;
404
+ delete[] ext_result_shape;
405
+ delete[] ext_result_strides;
406
+ return event_ref;
393
407
}
394
408
}
395
409
}
@@ -437,6 +451,7 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
437
451
delete[] ext_input2_shape;
438
452
delete[] ext_input2_strides;
439
453
delete[] ext_result_shape;
454
+ delete[] ext_result_strides;
440
455
441
456
return event_ref;
442
457
}
0 commit comments