Skip to content

Commit 033378d

Browse files
authored
some modification in matmul function (#1927)
* clean-up * address comments
1 parent 56a03e8 commit 033378d

File tree

4 files changed

+76
-47
lines changed

4 files changed

+76
-47
lines changed

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,16 +271,38 @@ std::tuple<sycl::event, sycl::event, bool>
271271

272272
standardize_strides_to_nonzero(a_stride, a_shape);
273273
standardize_strides_to_nonzero(b_stride, b_shape);
274+
standardize_strides_to_nonzero(c_stride, c_shape);
274275
const bool A_base_is_f_contig =
275276
a_stride[1] == 1 && a_stride[2] == a_shape[1];
277+
const bool A_base_is_c_contig =
278+
a_stride[1] == a_shape[2] && a_stride[2] == 1;
276279
const bool B_base_is_f_contig =
277280
b_stride[1] == 1 && b_stride[2] == b_shape[1];
281+
const bool B_base_is_c_contig =
282+
b_stride[1] == b_shape[2] && b_stride[2] == 1;
283+
const bool C_base_is_f_contig =
284+
c_stride[1] == 1 && c_stride[2] == c_shape[1];
285+
const bool C_base_is_c_contig =
286+
c_stride[1] == c_shape[2] && c_stride[2] == 1;
278287

279288
bool is_row_major = true;
280289
if (A_base_is_f_contig && B_base_is_f_contig) {
281290
is_row_major = false;
282291
}
283292

293+
if (!A_base_is_f_contig and !A_base_is_c_contig) {
294+
throw py::value_error("The 2D base of the first input array is not "
295+
"c-contiguous nor f-contiguous.");
296+
}
297+
if (!B_base_is_f_contig and !B_base_is_c_contig) {
298+
throw py::value_error("The 2D base of the second input array is not "
299+
"c-contiguous nor f-contiguous.");
300+
}
301+
if (!C_base_is_f_contig and !C_base_is_c_contig) {
302+
throw py::value_error("The 2D base of result array is not c-contiguous "
303+
"nor f-contiguous.");
304+
}
305+
284306
oneapi::mkl::transpose transA;
285307
oneapi::mkl::transpose transB;
286308
if (is_row_major) {
@@ -346,10 +368,10 @@ std::tuple<sycl::event, sycl::event, bool>
346368
strideb, stridec, transA, transB, a_typeless_ptr,
347369
b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
348370

349-
sycl::event args_batch_ev = dpctl::utils::keep_args_alive(
371+
sycl::event args_ev = dpctl::utils::keep_args_alive(
350372
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
351373

352-
return std::make_tuple(args_batch_ev, gemm_batch_ev, is_row_major);
374+
return std::make_tuple(args_ev, gemm_batch_ev, is_row_major);
353375
}
354376

355377
template <typename fnT, typename Tab, typename Tc>

dpnp/backend/extensions/blas/gemv.hpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,6 @@ extern std::pair<sycl::event, sycl::event>
4040
const bool transpose,
4141
const std::vector<sycl::event> &depends);
4242

43-
extern std::pair<sycl::event, sycl::event>
44-
gemv_batch(sycl::queue &exec_q,
45-
const dpctl::tensor::usm_ndarray &matrixA,
46-
const dpctl::tensor::usm_ndarray &vectorX,
47-
const dpctl::tensor::usm_ndarray &vectorY,
48-
const bool transpose,
49-
const std::vector<sycl::event> &depends);
50-
5143
extern void init_gemv_dispatch_vector(void);
5244
extern void init_gemv_batch_dispatch_vector(void);
5345
} // namespace dpnp::extensions::blas

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,12 @@ def _create_result_array(
134134
"""
135135
Create the result array.
136136
137-
If `out` is not ``None`` and its features match the specified `shape`, `dtype,
138-
`usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and
139-
does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
137+
If `out` is not ``None`` and its shape and dtype match the desired `shape`
138+
and `dtype`, and its 2-D base is contiguous and it does not have any memory
139+
overlap with `x1` and `x2`, `out` itself is returned.
140140
If these conditions are not satisfied, an empty array is returned with the
141141
specified `shape`, `dtype`, `usm_type`, and `sycl_queue`.
142+
142143
"""
143144

144145
if out is not None:
@@ -150,7 +151,6 @@ def _create_result_array(
150151
if (
151152
out.dtype == dtype
152153
and out.shape == shape
153-
and out.usm_type == usm_type
154154
and contig_flag
155155
and not ti._array_overlap(x1_usm, out_usm)
156156
and not ti._array_overlap(x2_usm, out_usm)
@@ -325,10 +325,13 @@ def _get_result_shape(x1, x2, out, np_flag):
325325

326326
def _gemm_batch_matmul(exec_q, x1, x2, res):
327327
# arrays here are already at least 3D, make them 3D
328-
x1 = dpnp.reshape(x1, (-1, x1.shape[-2], x1.shape[-1]))
329-
x2 = dpnp.reshape(x2, (-1, x2.shape[-2], x2.shape[-1]))
328+
x1_shape = x1.shape
329+
x2_shape = x2.shape
330+
x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1]))
331+
x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1]))
330332
orig_shape = res.shape
331-
res = dpnp.reshape(res, (-1, res.shape[-2], res.shape[-1]))
333+
res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1]))
334+
res_shape = res.shape
332335

333336
# gemm_batch does not handle negative strides, make a copy if needed
334337
x1 = _copy_array(x1, copy_flag=x1.strides[0] < 0)
@@ -338,16 +341,16 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
338341
_manager = dpu.SequentialOrderManager[exec_q]
339342

340343
# onemkl::blas::gemm_batch throws an exception (Provided range is out
341-
# of integer limits) if the batch_size is too large (>=4096*4096), so
342-
# we need to split the batch into smaller chunks
343-
chunk = 2048 * 2048
344-
batch_size = res.shape[0]
344+
# of integer limits) if the batch_size is too large, so we need to
345+
# split the batch into smaller chunks, the size depends on the device
346+
chunk = 4096 * 4096 - 2
347+
batch_size = res_shape[0]
345348
for i in range(0, batch_size, chunk):
346-
if x1.shape[0] == 1:
349+
if x1_shape[0] == 1:
347350
# x1 is repeatedly multiplied with each matrix in x2
348351
x1_usm = dpnp.get_usm_ndarray(x1)
349352
x2_usm = dpnp.get_usm_ndarray(x2[i : i + chunk, ...])
350-
elif x2.shape[0] == 1:
353+
elif x2_shape[0] == 1:
351354
x1_usm = dpnp.get_usm_ndarray(x1[i : i + chunk, ...])
352355
x2_usm = dpnp.get_usm_ndarray(x2)
353356
else:
@@ -364,25 +367,36 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
364367
)
365368
_manager.add_event_pair(ht_ev, blas_ev)
366369

367-
res_shape = res.shape
368370
_, res_is_c_contig, res_is_f_contig = _define_contig_flag(res)
369371
if row_major:
370372
if res_is_f_contig:
371-
res = dpnp.reshape(
372-
dpnp.ravel(res, order="F"),
373-
(res_shape[1], res_shape[2], batch_size),
374-
).transpose(2, 0, 1)
373+
# Considering the multiplication for one of the batches,
374+
# we have result[0, 1] = a[0, :]*b[1, :]. In row_major mode,
375+
# it is assumed result array is c-contiguous, i.e. the value of
376+
# result[0, 1] occupies the second place in memory.
377+
# However, the result array is batches of 2D f-contiguous arrays,
378+
# i.e. the second place in memory corresponds to res[1, 0].
379+
# So, we need to read data of each 2D array in the batch in
380+
# "F" order and write it in "C" order
381+
res = (
382+
res.ravel(order="F")
383+
.reshape(res_shape[1], res_shape[2], batch_size)
384+
.transpose(2, 0, 1)
385+
)
375386
else:
376387
if res_is_c_contig:
377-
res = dpnp.reshape(
378-
dpnp.ravel(res, order="C"),
379-
(batch_size, res_shape[2], res_shape[1]),
380-
).transpose(0, 2, 1)
388+
# read data of each 2D array in the batch in "C" order and
389+
# write it in "F" order
390+
res = (
391+
res.ravel(order="C")
392+
.reshape(batch_size, res_shape[2], res_shape[1])
393+
.transpose(0, 2, 1)
394+
)
381395

382396
if res_shape != orig_shape:
383397
res = res.reshape(orig_shape)
384398

385-
return dpnp.ascontiguousarray(res)
399+
return res
386400

387401

388402
def _gemm_matmul(exec_q, x1, x2, res):
@@ -400,13 +414,13 @@ def _gemm_matmul(exec_q, x1, x2, res):
400414
if row_major:
401415
if res.flags.f_contiguous is True:
402416
# read data in "F" order and write it in "C" order
403-
res = dpnp.reshape(dpnp.ravel(res, order="F"), res.shape, order="C")
417+
res = dpnp.ravel(res, order="F").reshape(res.shape, order="C")
404418
else:
405419
if res.flags.c_contiguous is True:
406420
# read data in "C" order and write it in "F" order
407-
res = dpnp.reshape(dpnp.ravel(res, order="C"), res.shape, order="F")
421+
res = dpnp.ravel(res, order="C").reshape(res.shape, order="F")
408422

409-
return dpnp.ascontiguousarray(res)
423+
return res
410424

411425

412426
def _shape_error(a, b, core_dim, err_msg):
@@ -767,9 +781,9 @@ def dpnp_matmul(
767781
call_flag = "multiply"
768782
elif x1_is_1D and x2_is_1D:
769783
call_flag = "dot"
770-
x1 = dpnp.reshape(x1, x1_shape[-1])
771-
if x2_ndim != 1:
772-
x2 = dpnp.reshape(x2, x2_shape[-2])
784+
# arrays are inherently 1D, make them 1D
785+
x1 = dpnp.ravel(x1)
786+
x2 = dpnp.ravel(x2)
773787
elif x1_base_is_1D and x2_base_is_1D:
774788
# TODO: implement a batch version of dot to use it here
775789
call_flag = "gemm_batch"
@@ -912,12 +926,11 @@ def dpnp_matmul(
912926
# we need to update it to match the passed `order`.
913927
if order not in ["k", "K"]:
914928
return dpnp.array(result, copy=False, order=order)
915-
return result
929+
# dpnp.ascontiguousarray changes 0-D array to 1-D array
930+
if result.ndim == 0:
931+
return result
932+
return dpnp.ascontiguousarray(result)
916933

917-
# TODO: There is opportunity to improve performance when out keyword is
918-
# present. For some cases, out is NOT result but they have the same base
919-
# (They are views of the same data). In this case, we can avoid copying
920-
# result to out.
921934
result = dpnp.get_result_array(result, out, casting=casting)
922935
if axes is not None and out is result:
923936
# out and out_orig contain the same data but they have different shape

tests/test_mathematical.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3037,6 +3037,7 @@ def test_matmul_strided3(self, stride, transpose):
30373037
@pytest.mark.parametrize("incy", [-2, 2], ids=["-2", "2"])
30383038
@pytest.mark.parametrize("transpose", [False, True], ids=["False", "True"])
30393039
def test_matmul_strided_mat_vec(self, shape, incx, incy, transpose):
3040+
# vector is strided
30403041
if transpose:
30413042
s1 = shape[-2]
30423043
s2 = shape[-1]
@@ -3069,6 +3070,7 @@ def test_matmul_strided_mat_vec(self, shape, incx, incy, transpose):
30693070
@pytest.mark.parametrize("incy", [-2, 2], ids=["-2", "2"])
30703071
@pytest.mark.parametrize("transpose", [False, True], ids=["False", "True"])
30713072
def test_matmul_strided_vec_mat(self, shape, incx, incy, transpose):
3073+
# vector is strided
30723074
if transpose:
30733075
s1 = shape[-2]
30743076
s2 = shape[-1]
@@ -3217,9 +3219,9 @@ def test_matmul_out_0D(self, out_shape):
32173219
@pytest.mark.parametrize(
32183220
"shape_pair",
32193221
[
3220-
((4096, 4096, 2, 2), (4096, 4096, 2, 2)),
3221-
((2, 2), (4096, 4096, 2, 2)),
3222-
((4096, 4096, 2, 2), (2, 2)),
3222+
((5000, 5000, 2, 2), (5000, 5000, 2, 2)),
3223+
((2, 2), (5000, 5000, 2, 2)),
3224+
((5000, 5000, 2, 2), (2, 2)),
32233225
],
32243226
)
32253227
def test_matmul_large(self, shape_pair):

0 commit comments

Comments
 (0)