Skip to content

Commit 786d450

Browse files
authored
Merge 295e0a7 into 98dc2f5
2 parents 98dc2f5 + 295e0a7 commit 786d450

File tree

7 files changed

+470
-46
lines changed

7 files changed

+470
-46
lines changed

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas;
4646
namespace py = pybind11;
4747
namespace type_utils = dpctl::tensor::type_utils;
4848

49-
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
49+
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
5050
oneapi::mkl::transpose,
5151
oneapi::mkl::transpose,
5252
const std::int64_t,
@@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
6464
[dpctl_td_ns::num_types];
6565

6666
template <typename Tab, typename Tc>
67-
static sycl::event gemm_impl(sycl::queue exec_q,
67+
static sycl::event gemm_impl(sycl::queue &exec_q,
6868
oneapi::mkl::transpose transA,
6969
oneapi::mkl::transpose transB,
7070
const std::int64_t m,
@@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
130130
}
131131

132132
std::pair<sycl::event, sycl::event>
133-
gemm(sycl::queue exec_q,
133+
gemm(sycl::queue &exec_q,
134134
dpctl::tensor::usm_ndarray matrixA,
135135
dpctl::tensor::usm_ndarray matrixB,
136136
dpctl::tensor::usm_ndarray resultC,

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ namespace ext
3939
namespace blas
4040
{
4141
extern std::pair<sycl::event, sycl::event>
42-
gemm(sycl::queue exec_q,
42+
gemm(sycl::queue &exec_q,
4343
dpctl::tensor::usm_ndarray matrixA,
4444
dpctl::tensor::usm_ndarray matrixB,
4545
dpctl::tensor::usm_ndarray resultC,
4646
const std::vector<sycl::event> &depends);
4747

4848
extern std::pair<sycl::event, sycl::event>
49-
gemm_batch(sycl::queue exec_q,
49+
gemm_batch(sycl::queue &exec_q,
5050
dpctl::tensor::usm_ndarray matrixA,
5151
dpctl::tensor::usm_ndarray matrixB,
5252
dpctl::tensor::usm_ndarray resultC,

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace py = pybind11;
4747
namespace type_utils = dpctl::tensor::type_utils;
4848

4949
typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
50-
sycl::queue,
50+
sycl::queue &,
5151
const std::int64_t,
5252
const std::int64_t,
5353
const std::int64_t,
@@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t
6969
gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
7070

7171
template <typename Tab, typename Tc>
72-
static sycl::event gemm_batch_impl(sycl::queue exec_q,
72+
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
7373
const std::int64_t m,
7474
const std::int64_t n,
7575
const std::int64_t k,
@@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,
145145
}
146146

147147
std::pair<sycl::event, sycl::event>
148-
gemm_batch(sycl::queue exec_q,
148+
gemm_batch(sycl::queue &exec_q,
149149
dpctl::tensor::usm_ndarray matrixA,
150150
dpctl::tensor::usm_ndarray matrixB,
151151
dpctl::tensor::usm_ndarray resultC,

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,18 +266,53 @@ def matmul(
266266
order="K",
267267
dtype=None,
268268
subok=True,
269+
signature=None,
270+
extobj=None,
271+
axes=None,
272+
axis=None,
269273
):
270274
"""
271275
Matrix product of two arrays.
272276
273277
For full documentation refer to :obj:`numpy.matmul`.
274278
279+
Parameters
280+
----------
281+
x1 : {dpnp.ndarray, usm_ndarray}
282+
First input array.
283+
x2 : {dpnp.ndarray, usm_ndarray}
284+
Second input array.
285+
out : {dpnp.ndarray, usm_ndarray}, optional
286+
Alternative output array in which to place the result. It must have
287+
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
288+
(of the calculated values) will be cast if necessary. Default: ``None``.
289+
dtype : dtype, optional
290+
Type to use in computing the matrix product. By default, the returned
291+
array will have a data type that is determined by considering
292+
Promotion Type Rule and device capabilities.
293+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
294+
Controls what kind of data casting may occur. Default: ``"same_kind"``.
295+
order : {"C", "F", "A", "K", None}, optional
296+
Memory layout of the newly created output array, if parameter `out` is ``None``.
297+
Default: "K".
298+
axes : list of tuples, optional
299+
A list of tuples with indices of axes the matrix product should operate on.
300+
For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements
301+
are 2d matrices and these are taken to be stored in the two last axes of each
302+
argument. The corresponding axes keyword would be [(-2, -1), (-2, -1), (-2, -1)].
303+
Default: ``None``.
304+
305+
Returns
306+
-------
307+
out : dpnp.ndarray
308+
Returns the matrix product of the inputs.
309+
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.
310+
275311
Limitations
276312
-----------
277-
Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray`
278-
or :class:`dpctl.tensor.usm_ndarray`.
279-
Keyword argument `subok` is currently unsupported.
280-
Input array data types are limited by supported DPNP :ref:`Data types`.
313+
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
314+
only supported with their default value.
315+
Otherwise ``NotImplementedError`` exception will be raised.
281316
282317
See Also
283318
--------
@@ -338,6 +373,18 @@ def matmul(
338373
raise NotImplementedError(
339374
"subok keyword argument is only supported by its default value."
340375
)
376+
elif signature is not None:
377+
raise NotImplementedError(
378+
"signature keyword argument is only supported by its default value."
379+
)
380+
elif extobj is not None:
381+
raise NotImplementedError(
382+
"extobj keyword argument is only supported by its default value."
383+
)
384+
elif axis is not None:
385+
raise NotImplementedError(
386+
"axis keyword argument is only supported by its default value."
387+
)
341388
else:
342389
return dpnp_matmul(
343390
x1,
@@ -346,6 +393,7 @@ def matmul(
346393
casting=casting,
347394
order=order,
348395
dtype=dtype,
396+
axes=axes,
349397
)
350398

351399

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 129 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import dpctl.tensor as dpt
2828
import dpctl.tensor._tensor_impl as ti
2929
import numpy
30+
from numpy.core.numeric import normalize_axis_tuple
3031

3132
import dpnp
3233
import dpnp.backend.extensions.blas._blas_impl as bi
@@ -43,7 +44,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
4344
If `out` is not ``None`` and its features match the specified `shape`, `dtype`,
4445
`usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and
4546
does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
46-
If these conditions are not statisfied, an empty array is returned with the
47+
If these conditions are not satisfied, an empty array is returned with the
4748
specified `shape`, `dtype`, `usm_type`, and `sycl_queue`.
4849
"""
4950

@@ -116,21 +117,9 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
116117
x2_strides = x2.strides
117118
res_strides = res.strides
118119

119-
# when shape along any particular dimension is 1,
120-
# the stride along that dimension is not a
121-
# meaningful number and is undefined. Here, we
122-
# standardizing strides before continuing,
123-
# setting stride to 0 if the shape along that axis is <=1
124-
if x1_is_2D:
125-
x1_strides = tuple(
126-
str_i if sh_i > 1 else 0
127-
for sh_i, str_i in zip(x1.shape, x1_strides)
128-
)
129-
if x2_is_2D:
130-
x2_strides = tuple(
131-
str_i if sh_i > 1 else 0
132-
for sh_i, str_i in zip(x2.shape, x2_strides)
133-
)
120+
# need to standardize to use in ti._contract_iter2
121+
x1_strides = _standardize_strides(x1_strides, x1_is_2D, x1.shape, x1.ndim)
122+
x2_strides = _standardize_strides(x2_strides, x2_is_2D, x2.shape, x2.ndim)
134123

135124
batch_size = res.shape[:-2][0]
136125
stridea = x1_strides[0]
@@ -220,6 +209,92 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
220209
return op_dtype, res_dtype
221210

222211

212+
def _standardize_strides(strides, inherently_2D, shape, ndim):
213+
"""
214+
Standardizing the strides.
215+
216+
When shape of an array along any particular dimension is 1, the stride
217+
along that dimension is undefined. This functions standardize the strides
218+
in the following way:
219+
For N-D arrays that are inherently 2D (all dimesnsion are one except for two of them),
220+
we use zero as the stride for dimensions equal one.
221+
For other N-D arrays, the non-zero value of strides is calculated and used.
222+
223+
"""
224+
225+
if inherently_2D:
226+
stndrd_strides = tuple(
227+
str_i if sh_i > 1 else 0 for sh_i, str_i in zip(shape, strides)
228+
)
229+
else:
230+
stndrd_strides = [
231+
numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i]
232+
for i in range(ndim - 1)
233+
]
234+
# last dimension
235+
stndrd_strides.append(
236+
1 if strides[ndim - 1] == 0 else strides[ndim - 1]
237+
)
238+
stndrd_strides = tuple(stndrd_strides)
239+
240+
return stndrd_strides
241+
242+
243+
def _validate_axes(x1, x2, axes):
244+
"""Check axes is valid for matmul function."""
245+
246+
def _validate_internal(axes, i, ndim):
247+
if ndim == 1:
248+
iter = 1
249+
if isinstance(axes, int):
250+
axes = (axes,)
251+
elif not isinstance(axes, tuple):
252+
raise TypeError(
253+
f"Axes item {i}: {type(axes)} object cannot be interpreted as an integer."
254+
)
255+
256+
if len(axes) != 1:
257+
raise ValueError(
258+
f"Axes item {i} should be a tuple with a single element, or an integer."
259+
)
260+
else:
261+
iter = 2
262+
if not isinstance(axes, tuple):
263+
raise TypeError(f"Axes item {i} should be a tuple.")
264+
if len(axes) != 2:
265+
raise ValueError(
266+
f"Axes item {i} should be a tuple with 2 elements."
267+
)
268+
269+
for j in range(iter):
270+
if not isinstance(axes[j], int):
271+
raise TypeError(
272+
f"Axes item {i}: {type(axes[j])} object cannot be interpreted as an integer."
273+
)
274+
return axes
275+
276+
if not isinstance(axes, list):
277+
raise TypeError("Axes should be a list.")
278+
else:
279+
if len(axes) != 3:
280+
raise ValueError(
281+
"Axes should be a list of three tuples for inputs and output."
282+
)
283+
284+
axes[0] = _validate_internal(axes[0], 0, x1.ndim)
285+
axes[1] = _validate_internal(axes[1], 1, x2.ndim)
286+
287+
if x1.ndim == 1 and x2.ndim == 1:
288+
if axes[2] != ():
289+
raise TypeError("Axes item 2 should be an empty tuple.")
290+
elif x1.ndim == 1 or x2.ndim == 1:
291+
axes[2] = _validate_internal(axes[2], 2, 1)
292+
else:
293+
axes[2] = _validate_internal(axes[2], 2, 2)
294+
295+
return axes
296+
297+
223298
def dpnp_dot(a, b, /, out=None, *, conjugate=False):
224299
"""
225300
Return the dot product of two arrays.
@@ -302,6 +377,7 @@ def dpnp_matmul(
302377
casting="same_kind",
303378
order="K",
304379
dtype=None,
380+
axes=None,
305381
):
306382
"""
307383
Return the matrix product of two arrays.
@@ -327,6 +403,22 @@ def dpnp_matmul(
327403

328404
res_usm_type, exec_q = get_usm_allocations([x1, x2])
329405

406+
if axes is not None:
407+
axes = _validate_axes(x1, x2, axes)
408+
409+
axes_x1, axes_x2, axes_res = axes
410+
axes_x1 = normalize_axis_tuple(axes_x1, x1.ndim, "axis")
411+
axes_x2 = normalize_axis_tuple(axes_x2, x2.ndim, "axis")
412+
# Move the axes that are going to be used in matrix product,
413+
# to the end of "x1" and "x2"
414+
x1 = dpnp.moveaxis(x1, axes_x1, (-2, -1)) if x1.ndim != 1 else x1
415+
x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) if x2.ndim != 1 else x2
416+
out_orig = out
417+
if out is not None:
418+
dpnp.check_supported_arrays_type(out)
419+
# out that is passed to the backend should have the correct shape
420+
out = dpnp.moveaxis(out, axes_res, (-2, -1))
421+
330422
appended_axes = []
331423
if x1_ndim == 1:
332424
x1 = x1[dpnp.newaxis, :]
@@ -397,9 +489,15 @@ def dpnp_matmul(
397489
x2_shape = x2.shape
398490
res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])
399491

492+
# handling a special case to provide a similar result to NumPy
493+
if out is not None and x1.shape == (1, 0) and x2.shape == (0, 1):
494+
res_shape = (0,)
495+
appended_axes = []
496+
400497
result = _create_result_array(
401498
x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q
402499
)
500+
403501
# calculate result
404502
if result.size == 0:
405503
pass
@@ -471,12 +569,25 @@ def dpnp_matmul(
471569

472570
if gemm_dtype != res_dtype:
473571
result = dpnp.astype(result, res_dtype, copy=False)
572+
474573
if out is None:
574+
if axes is not None:
575+
# Move the result to the appropriate axes of out array
576+
if len(axes_res) == 2:
577+
result = dpnp.moveaxis(result, (-2, -1), axes_res)
578+
elif len(axes_res) == 1:
579+
result = dpnp.moveaxis(result, (-1,), axes_res)
580+
return result
475581
# If `order` was not passed as default
476582
# we need to update it to match the passed `order`.
477-
if order not in ["k", "K"]:
583+
elif order not in ["k", "K"]:
478584
return dpnp.array(result, copy=False, order=order)
479585
else:
480586
return result
481587
else:
482-
return dpnp.get_result_array(result, out, casting=casting)
588+
result = dpnp.get_result_array(result, out, casting=casting)
589+
if axes is not None:
590+
if out is result:
591+
# out and out_orig contain the same data but they have different shape
592+
return out_orig
593+
return result

0 commit comments

Comments
 (0)