Skip to content

Commit 08b9de1

Browse files
authored
Merge 54f9f5b into 0457fe1
2 parents 0457fe1 + 54f9f5b commit 08b9de1

File tree

7 files changed

+326
-29
lines changed

7 files changed

+326
-29
lines changed

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas;
4646
namespace py = pybind11;
4747
namespace type_utils = dpctl::tensor::type_utils;
4848

49-
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
49+
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
5050
oneapi::mkl::transpose,
5151
oneapi::mkl::transpose,
5252
const std::int64_t,
@@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
6464
[dpctl_td_ns::num_types];
6565

6666
template <typename Tab, typename Tc>
67-
static sycl::event gemm_impl(sycl::queue exec_q,
67+
static sycl::event gemm_impl(sycl::queue &exec_q,
6868
oneapi::mkl::transpose transA,
6969
oneapi::mkl::transpose transB,
7070
const std::int64_t m,
@@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
130130
}
131131

132132
std::pair<sycl::event, sycl::event>
133-
gemm(sycl::queue exec_q,
133+
gemm(sycl::queue &exec_q,
134134
dpctl::tensor::usm_ndarray matrixA,
135135
dpctl::tensor::usm_ndarray matrixB,
136136
dpctl::tensor::usm_ndarray resultC,

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ namespace ext
3939
namespace blas
4040
{
4141
extern std::pair<sycl::event, sycl::event>
42-
gemm(sycl::queue exec_q,
42+
gemm(sycl::queue &exec_q,
4343
dpctl::tensor::usm_ndarray matrixA,
4444
dpctl::tensor::usm_ndarray matrixB,
4545
dpctl::tensor::usm_ndarray resultC,
4646
const std::vector<sycl::event> &depends);
4747

4848
extern std::pair<sycl::event, sycl::event>
49-
gemm_batch(sycl::queue exec_q,
49+
gemm_batch(sycl::queue &exec_q,
5050
dpctl::tensor::usm_ndarray matrixA,
5151
dpctl::tensor::usm_ndarray matrixB,
5252
dpctl::tensor::usm_ndarray resultC,

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace py = pybind11;
4747
namespace type_utils = dpctl::tensor::type_utils;
4848

4949
typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
50-
sycl::queue,
50+
sycl::queue &,
5151
const std::int64_t,
5252
const std::int64_t,
5353
const std::int64_t,
@@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t
6969
gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
7070

7171
template <typename Tab, typename Tc>
72-
static sycl::event gemm_batch_impl(sycl::queue exec_q,
72+
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
7373
const std::int64_t m,
7474
const std::int64_t n,
7575
const std::int64_t k,
@@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,
145145
}
146146

147147
std::pair<sycl::event, sycl::event>
148-
gemm_batch(sycl::queue exec_q,
148+
gemm_batch(sycl::queue &exec_q,
149149
dpctl::tensor::usm_ndarray matrixA,
150150
dpctl::tensor::usm_ndarray matrixB,
151151
dpctl::tensor::usm_ndarray resultC,

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,18 +266,53 @@ def matmul(
266266
order="K",
267267
dtype=None,
268268
subok=True,
269+
signature=None,
270+
extobj=None,
271+
axes=None,
272+
axis=None,
269273
):
270274
"""
271275
Matrix product of two arrays.
272276
273277
For full documentation refer to :obj:`numpy.matmul`.
274278
279+
Parameters
280+
----------
281+
x1 : {dpnp.ndarray, usm_ndarray}
282+
First input array.
283+
x2 : {dpnp.ndarray, usm_ndarray}
284+
Second input array.
285+
out : {dpnp.ndarray, usm_ndarray}, optional
286+
Alternative output array in which to place the result. It must have
287+
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
288+
(of the calculated values) will be cast if necessary. Default: ``None``.
289+
dtype : dtype, optional
290+
Type to use in computing the matrix product. By default, the returned
291+
array will have data type that is determined by considering
292+
Promotion Type Rule and device capabilities.
293+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
294+
Controls what kind of data casting may occur. Default: ``"same_kind"``.
295+
order : {"C", "F", "A", "K", None}, optional
296+
Memory layout of the newly output array, if parameter `out` is ``None``.
297+
Default: "K".
298+
axes : list of tuples, optional
299+
A list of tuples with indices of axes the matrix product should operate on.
300+
For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements
301+
are 2d matrices and these are taken to be stored in the two last axes of each
302+
argument. The corresponding axes keyword would be [(-2, -1), (-2, -1), (-2, -1)].
303+
Default: ``None``.
304+
305+
Returns
306+
-------
307+
out : dpnp.ndarray
308+
Returns the matrix product of the inputs.
309+
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.
310+
275311
Limitations
276312
-----------
277-
Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray`
278-
or :class:`dpctl.tensor.usm_ndarray`.
279-
Keyword argument `subok` is currently unsupported.
280-
Input array data types are limited by supported DPNP :ref:`Data types`.
313+
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
314+
only supported with their default value.
315+
Otherwise ``NotImplementedError`` exception will be raised.
281316
282317
See Also
283318
--------
@@ -338,15 +373,67 @@ def matmul(
338373
raise NotImplementedError(
339374
"subok keyword argument is only supported by its default value."
340375
)
376+
elif signature is not None:
377+
raise NotImplementedError(
378+
"signature keyword argument is only supported by its default value."
379+
)
380+
elif extobj is not None:
381+
raise NotImplementedError(
382+
"extobj keyword argument is only supported by its default value."
383+
)
384+
elif axis is not None:
385+
raise NotImplementedError(
386+
"axis keyword argument is only supported by its default value."
387+
)
341388
else:
342-
return dpnp_matmul(
389+
if axes is not None:
390+
if not isinstance(axes, list):
391+
raise TypeError("Axes should be a list.")
392+
else:
393+
if len(axes) != 3:
394+
raise ValueError(
395+
"Axes should be a list of three tuples for inputs and output."
396+
)
397+
398+
for i in range(3):
399+
if not isinstance(axes[i], tuple):
400+
raise TypeError(f"Axes item {i} should be a tuple.")
401+
if len(axes[i]) != 2:
402+
raise ValueError(
403+
f"Axes item {i} should be a tuple with 2 elements."
404+
)
405+
406+
for j in range(2):
407+
if not isinstance(axes[i][j], int):
408+
raise TypeError("Axes must be an integer.")
409+
410+
axes_x1, axes_x2, axes_res = axes
411+
# Move the axes that are going to be used in matrix product,
412+
# to the end of "x1" and "x2"
413+
x1 = dpnp.moveaxis(x1, axes_x1, (-2, -1))
414+
x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1))
415+
out_orig = out
416+
if out is not None:
417+
dpnp.check_supported_arrays_type(x1, x2)
418+
# out that is passed to the backend should have the correct shape
419+
out = dpnp.moveaxis(out, axes_res, (-2, -1))
420+
421+
result = dpnp_matmul(
343422
x1,
344423
x2,
345424
out=out,
346425
casting=casting,
347426
order=order,
348427
dtype=dtype,
349428
)
429+
if axes is not None:
430+
if out is result:
431+
# out and out_orig contain the same data but they have different shape
432+
return out_orig
433+
# Move the result to the appropriate axes of out array
434+
result = dpnp.moveaxis(result, (-2, -1), axes_res)
435+
436+
return result
350437

351438

352439
def outer(x1, x2, out=None):

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,21 +116,9 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
116116
x2_strides = x2.strides
117117
res_strides = res.strides
118118

119-
# when shape along any particular dimension is 1,
120-
# the stride along that dimension is not a
121-
# meaningful number and is undefined. Here, we
122-
# standardizing strides before continuing,
123-
# setting stride to 0 if the shape along that axis is <=1
124-
if x1_is_2D:
125-
x1_strides = tuple(
126-
str_i if sh_i > 1 else 0
127-
for sh_i, str_i in zip(x1.shape, x1_strides)
128-
)
129-
if x2_is_2D:
130-
x2_strides = tuple(
131-
str_i if sh_i > 1 else 0
132-
for sh_i, str_i in zip(x2.shape, x2_strides)
133-
)
119+
# need to standardize to use in ti._contract_iter2
120+
x1_strides = _standardize_strides(x1_strides, x1_is_2D, x1.shape, x1.ndim)
121+
x2_strides = _standardize_strides(x2_strides, x2_is_2D, x2.shape, x2.ndim)
134122

135123
batch_size = res.shape[:-2][0]
136124
stridea = x1_strides[0]
@@ -220,6 +208,37 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
220208
return op_dtype, res_dtype
221209

222210

211+
def _standardize_strides(strides, inherently_2D, shape, ndim):
212+
"""
213+
Standardizing the strides.
214+
215+
When shape of an array along any particular dimension is 1, the stride
216+
along that dimension is undefined. This functions standardize the strides
217+
in the following way:
218+
For N-D arrays that are inherently 2D (all dimesnsion are one except for two of them),
219+
we use zero as the stride for dimensions equal one.
220+
For other N-D arrays, the non-zero value of strides is calculated and used.
221+
222+
"""
223+
224+
if inherently_2D:
225+
stndrd_strides = tuple(
226+
str_i if sh_i > 1 else 0 for sh_i, str_i in zip(shape, strides)
227+
)
228+
else:
229+
stndrd_strides = [
230+
numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i]
231+
for i in range(ndim - 1)
232+
]
233+
# last dimension
234+
stndrd_strides.append(
235+
1 if strides[ndim - 1] == 0 else strides[ndim - 1]
236+
)
237+
stndrd_strides = tuple(stndrd_strides)
238+
239+
return stndrd_strides
240+
241+
223242
def dpnp_dot(a, b, /, out=None, *, conjugate=False):
224243
"""
225244
Return the dot product of two arrays.

tests/test_mathematical.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2634,6 +2634,57 @@ def test_matmul_dtype(self, dtype, shape_pair):
26342634
expected = numpy.matmul(a1, a2, dtype=dtype)
26352635
assert_dtype_allclose(result, expected)
26362636

2637+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
    "axes",
    [
        [(-3, -1), (0, 2), (-2, -3)],
        [(3, 1), (2, 0), (3, 1)],
        [(3, 1), (2, 0), (0, 1)],
    ],
)
def test_matmul_axes(self, dtype, axes):
    """Compare dpnp.matmul against numpy.matmul when `axes` is given."""
    a = numpy.array(
        numpy.random.uniform(-10, 10, 120), dtype=dtype
    ).reshape(2, 5, 3, 4)
    b = numpy.array(
        numpy.random.uniform(-10, 10, 120), dtype=dtype
    ).reshape(4, 2, 5, 3)
    ia = dpnp.array(a)
    ib = dpnp.array(b)

    # the same `axes` list is handed to both implementations
    result = dpnp.matmul(ia, ib, axes=axes)
    expected = numpy.matmul(a, b, axes=axes)
    # TODO: investigate the effect of factor, see SAT-6700
    assert_dtype_allclose(result, expected, factor=24)
2661+
2662+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
    "axes, out_shape",
    [
        ([(-3, -1), (0, 2), (-2, -3)], (2, 5, 5, 3)),
        ([(3, 1), (2, 0), (3, 1)], (2, 4, 3, 4)),
        ([(3, 1), (2, 0), (1, 2)], (2, 4, 4, 3)),
    ],
)
def test_matmul_axes_out(self, dtype, axes, out_shape):
    """Check dpnp.matmul with both `axes` and a preallocated `out` array."""
    lhs_values = numpy.random.uniform(-10, 10, 120)
    lhs = numpy.array(lhs_values, dtype=dtype).reshape(2, 5, 3, 4)
    rhs_values = numpy.random.uniform(-10, 10, 120)
    rhs = numpy.array(rhs_values, dtype=dtype).reshape(4, 2, 5, 3)

    dp_lhs = dpnp.array(lhs)
    dp_rhs = dpnp.array(rhs)
    out_dp = dpnp.empty(out_shape, dtype=dtype)

    result = dpnp.matmul(dp_lhs, dp_rhs, axes=axes, out=out_dp)
    # the very object passed as `out` must be the one returned
    assert result is out_dp

    expected = numpy.matmul(lhs, rhs, axes=axes)
    # TODO: investigate the effect of factor, see SAT-6700
    assert_dtype_allclose(result, expected, factor=24)
2687+
26372688
@pytest.mark.parametrize("dtype1", get_all_dtypes(no_bool=True))
26382689
@pytest.mark.parametrize(
26392690
"dtype2", get_all_dtypes(no_bool=True, no_none=True)
@@ -2822,9 +2873,52 @@ def test_matmul_casting(self):
28222873
with pytest.raises(TypeError):
28232874
dpnp.matmul(a1, a2, out=res, casting="safe")
28242875

2825-
def test_matmul_subok(self):
2876+
def test_matmul_not_implemented(self):
    """Non-default subok/signature/extobj/axis must raise NotImplementedError."""
    a1 = dpnp.arange(2 * 4).reshape(2, 4)
    a2 = dpnp.arange(4 * 3).reshape(4, 3)

    def custom_error_callback(err):
        print("Custom error callback triggered with error:", err)

    # one non-default keyword argument per call, checked in this order
    unsupported_kwargs = (
        {"subok": False},
        {"signature": (dpnp.float32, dpnp.float32, dpnp.float32)},
        {"extobj": [32, 1, custom_error_callback]},
        {"axis": 2},
    )
    for kwargs in unsupported_kwargs:
        with pytest.raises(NotImplementedError):
            dpnp.matmul(a1, a2, **kwargs)
2896+
2897+
def test_matmul_axes(self):
    """Malformed `axes` arguments must be rejected by dpnp.matmul."""
    a1 = dpnp.arange(120).reshape(2, 5, 3, 4)
    a2 = dpnp.arange(120).reshape(4, 2, 5, 3)

    invalid_cases = (
        # axes must be a list
        (((3, 1), (2, 0), (0, 1)), TypeError),
        # axes must be a list of three tuples
        ([(3, 1), (2, 0)], ValueError),
        # axes items should be a tuple
        ([(3, 1), (2, 0), [0, 1]], TypeError),
        # axes items should be a tuple with 2 elements
        ([(3, 1), (2, 0), (0, 1, 2)], ValueError),
        # each axis must be an integer
        ([(3, 1), (2, 0), (0.0, 1)], TypeError),
    )
    for axes, expected_error in invalid_cases:
        with pytest.raises(expected_error):
            dpnp.matmul(a1, a2, axes=axes)

0 commit comments

Comments
 (0)