Skip to content

Commit 3461932

Browse files
vtavananpolina4antonwolfy
authored
add support for axes keyword to dpnp.matmul (#1705)
* update matmul for cupy tests * address comments * address more comments * fix an error * use a function for error msg --------- Co-authored-by: Natalia Polina <[email protected]> Co-authored-by: Anton <[email protected]>
1 parent 2c38ae8 commit 3461932

File tree

7 files changed

+595
-61
lines changed

7 files changed

+595
-61
lines changed

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas;
4646
namespace py = pybind11;
4747
namespace type_utils = dpctl::tensor::type_utils;
4848

49-
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
49+
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
5050
oneapi::mkl::transpose,
5151
oneapi::mkl::transpose,
5252
const std::int64_t,
@@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
6464
[dpctl_td_ns::num_types];
6565

6666
template <typename Tab, typename Tc>
67-
static sycl::event gemm_impl(sycl::queue exec_q,
67+
static sycl::event gemm_impl(sycl::queue &exec_q,
6868
oneapi::mkl::transpose transA,
6969
oneapi::mkl::transpose transB,
7070
const std::int64_t m,
@@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
130130
}
131131

132132
std::pair<sycl::event, sycl::event>
133-
gemm(sycl::queue exec_q,
133+
gemm(sycl::queue &exec_q,
134134
dpctl::tensor::usm_ndarray matrixA,
135135
dpctl::tensor::usm_ndarray matrixB,
136136
dpctl::tensor::usm_ndarray resultC,

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ namespace ext
3939
namespace blas
4040
{
4141
extern std::pair<sycl::event, sycl::event>
42-
gemm(sycl::queue exec_q,
42+
gemm(sycl::queue &exec_q,
4343
dpctl::tensor::usm_ndarray matrixA,
4444
dpctl::tensor::usm_ndarray matrixB,
4545
dpctl::tensor::usm_ndarray resultC,
4646
const std::vector<sycl::event> &depends);
4747

4848
extern std::pair<sycl::event, sycl::event>
49-
gemm_batch(sycl::queue exec_q,
49+
gemm_batch(sycl::queue &exec_q,
5050
dpctl::tensor::usm_ndarray matrixA,
5151
dpctl::tensor::usm_ndarray matrixB,
5252
dpctl::tensor::usm_ndarray resultC,

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace py = pybind11;
4747
namespace type_utils = dpctl::tensor::type_utils;
4848

4949
typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
50-
sycl::queue,
50+
sycl::queue &,
5151
const std::int64_t,
5252
const std::int64_t,
5353
const std::int64_t,
@@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t
6969
gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
7070

7171
template <typename Tab, typename Tc>
72-
static sycl::event gemm_batch_impl(sycl::queue exec_q,
72+
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
7373
const std::int64_t m,
7474
const std::int64_t n,
7575
const std::int64_t k,
@@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,
145145
}
146146

147147
std::pair<sycl::event, sycl::event>
148-
gemm_batch(sycl::queue exec_q,
148+
gemm_batch(sycl::queue &exec_q,
149149
dpctl::tensor::usm_ndarray matrixA,
150150
dpctl::tensor::usm_ndarray matrixB,
151151
dpctl::tensor::usm_ndarray resultC,

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,18 +266,53 @@ def matmul(
266266
order="K",
267267
dtype=None,
268268
subok=True,
269+
signature=None,
270+
extobj=None,
271+
axes=None,
272+
axis=None,
269273
):
270274
"""
271275
Matrix product of two arrays.
272276
273277
For full documentation refer to :obj:`numpy.matmul`.
274278
279+
Parameters
280+
----------
281+
x1 : {dpnp_array, usm_ndarray}
282+
First input array.
283+
x2 : {dpnp_array, usm_ndarray}
284+
Second input array.
285+
out : {dpnp.ndarray, usm_ndarray}, optional
286+
Alternative output array in which to place the result. It must have
287+
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
288+
(of the calculated values) will be cast if necessary. Default: ``None``.
289+
dtype : dtype, optional
290+
Type to use in computing the matrix product. By default, the returned
291+
array will have data type that is determined by considering
292+
Promotion Type Rule and device capabilities.
293+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
294+
Controls what kind of data casting may occur. Default: ``"same_kind"``.
295+
order : {"C", "F", "A", "K", None}, optional
296+
Memory layout of the newly output array, if parameter `out` is ``None``.
297+
Default: "K".
298+
axes : list of tuples, optional
299+
A list of tuples with indices of axes the matrix product should operate on.
300+
For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements
301+
are 2d matrices and these are taken to be stored in the two last axes of each
302+
argument. The corresponding axes keyword would be [(-2, -1), (-2, -1), (-2, -1)].
303+
Default: ``None``.
304+
305+
Returns
306+
-------
307+
out : dpnp.ndarray
308+
Returns the matrix product of the inputs.
309+
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.
310+
275311
Limitations
276312
-----------
277-
Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray`
278-
or :class:`dpctl.tensor.usm_ndarray`.
279-
Keyword argument `subok` is currently unsupported.
280-
Input array data types are limited by supported DPNP :ref:`Data types`.
313+
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
314+
only supported with their default value.
315+
Otherwise ``NotImplementedError`` exception will be raised.
281316
282317
See Also
283318
--------
@@ -338,6 +373,18 @@ def matmul(
338373
raise NotImplementedError(
339374
"subok keyword argument is only supported by its default value."
340375
)
376+
elif signature is not None:
377+
raise NotImplementedError(
378+
"signature keyword argument is only supported by its default value."
379+
)
380+
elif extobj is not None:
381+
raise NotImplementedError(
382+
"extobj keyword argument is only supported by its default value."
383+
)
384+
elif axis is not None:
385+
raise NotImplementedError(
386+
"axis keyword argument is only supported by its default value."
387+
)
341388
else:
342389
return dpnp_matmul(
343390
x1,
@@ -346,6 +393,7 @@ def matmul(
346393
casting=casting,
347394
order=order,
348395
dtype=dtype,
396+
axes=axes,
349397
)
350398

351399

0 commit comments

Comments
 (0)