Skip to content

Commit f5b4393

Browse files
committed
update test_product.py and test_linalg.py
1 parent ed7bba2 commit f5b4393

File tree

5 files changed

+829
-1338
lines changed

5 files changed

+829
-1338
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@
3737
3838
"""
3939

40-
4140
import numpy
42-
from dpctl.tensor._numpy_helper import normalize_axis_tuple
4341

4442
import dpnp
4543

@@ -48,6 +46,7 @@
4846
dpnp_dot,
4947
dpnp_kron,
5048
dpnp_matmul,
49+
dpnp_tensordot,
5150
dpnp_vecdot,
5251
)
5352

@@ -1047,65 +1046,7 @@ def tensordot(a, b, axes=2):
10471046
# TODO: use specific scalar-vector kernel
10481047
return dpnp.multiply(a, b)
10491048

1050-
try:
1051-
iter(axes)
1052-
except Exception as e: # pylint: disable=broad-exception-caught
1053-
if not isinstance(axes, int):
1054-
raise TypeError("Axes must be an integer.") from e
1055-
if axes < 0:
1056-
raise ValueError("Axes must be a non-negative integer.") from e
1057-
axes_a = tuple(range(-axes, 0))
1058-
axes_b = tuple(range(0, axes))
1059-
else:
1060-
if len(axes) != 2:
1061-
raise ValueError("Axes must consist of two sequences.")
1062-
1063-
axes_a, axes_b = axes
1064-
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
1065-
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b
1066-
1067-
if len(axes_a) != len(axes_b):
1068-
raise ValueError("Axes length mismatch.")
1069-
1070-
# Make the axes non-negative
1071-
a_ndim = a.ndim
1072-
b_ndim = b.ndim
1073-
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
1074-
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")
1075-
1076-
if a.ndim == 0 or b.ndim == 0:
1077-
# TODO: use specific scalar-vector kernel
1078-
return dpnp.multiply(a, b)
1079-
1080-
a_shape = a.shape
1081-
b_shape = b.shape
1082-
for axis_a, axis_b in zip(axes_a, axes_b):
1083-
if a_shape[axis_a] != b_shape[axis_b]:
1084-
raise ValueError(
1085-
"shape of input arrays is not similar at requested axes."
1086-
)
1087-
1088-
# Move the axes to sum over, to the end of "a"
1089-
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
1090-
newaxes_a = not_in + axes_a
1091-
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
1092-
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
1093-
newshape_a = (n1, n2)
1094-
olda = [a_shape[axis] for axis in not_in]
1095-
1096-
# Move the axes to sum over, to the front of "b"
1097-
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
1098-
newaxes_b = tuple(axes_b + not_in)
1099-
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
1100-
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
1101-
newshape_b = (n1, n2)
1102-
oldb = [b_shape[axis] for axis in not_in]
1103-
1104-
at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
1105-
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
1106-
res = dpnp.matmul(at, bt)
1107-
1108-
return res.reshape(olda + oldb)
1049+
return dpnp_tensordot(a, b, axes=axes)
11091050

11101051

11111052
def vdot(a, b):

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@
3636
from dpnp.dpnp_array import dpnp_array
3737
from dpnp.dpnp_utils import get_usm_allocations
3838

39-
__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_kron", "dpnp_matmul", "dpnp_vecdot"]
39+
__all__ = [
40+
"dpnp_cross",
41+
"dpnp_dot",
42+
"dpnp_kron",
43+
"dpnp_matmul",
44+
"dpnp_tensordot",
45+
"dpnp_vecdot",
46+
]
4047

4148

4249
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
@@ -974,6 +981,70 @@ def dpnp_matmul(
974981
return result
975982

976983

984+
def dpnp_tensordot(a, b, axes=2):
985+
"""Tensor dot product of two arrays."""
986+
987+
try:
988+
iter(axes)
989+
except Exception as e: # pylint: disable=broad-exception-caught
990+
if not isinstance(axes, int):
991+
raise TypeError("Axes must be an integer.") from e
992+
if axes < 0:
993+
raise ValueError("Axes must be a non-negative integer.") from e
994+
axes_a = tuple(range(-axes, 0))
995+
axes_b = tuple(range(0, axes))
996+
else:
997+
if len(axes) != 2:
998+
raise ValueError("Axes must consist of two sequences.")
999+
1000+
axes_a, axes_b = axes
1001+
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
1002+
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b
1003+
1004+
if len(axes_a) != len(axes_b):
1005+
raise ValueError("Axes length mismatch.")
1006+
1007+
# Make the axes non-negative
1008+
a_ndim = a.ndim
1009+
b_ndim = b.ndim
1010+
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
1011+
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")
1012+
1013+
if a.ndim == 0 or b.ndim == 0:
1014+
# TODO: use specific scalar-vector kernel
1015+
return dpnp.multiply(a, b)
1016+
1017+
a_shape = a.shape
1018+
b_shape = b.shape
1019+
for axis_a, axis_b in zip(axes_a, axes_b):
1020+
if a_shape[axis_a] != b_shape[axis_b]:
1021+
raise ValueError(
1022+
"shape of input arrays is not similar at requested axes."
1023+
)
1024+
1025+
# Move the axes to sum over, to the end of "a"
1026+
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
1027+
newaxes_a = not_in + axes_a
1028+
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
1029+
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
1030+
newshape_a = (n1, n2)
1031+
olda = [a_shape[axis] for axis in not_in]
1032+
1033+
# Move the axes to sum over, to the front of "b"
1034+
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
1035+
newaxes_b = tuple(axes_b + not_in)
1036+
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
1037+
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
1038+
newshape_b = (n1, n2)
1039+
oldb = [b_shape[axis] for axis in not_in]
1040+
1041+
at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
1042+
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
1043+
res = dpnp.matmul(at, bt)
1044+
1045+
return res.reshape(olda + oldb)
1046+
1047+
9771048
def dpnp_vecdot(
9781049
x1,
9791050
x2,

dpnp/tests/helper.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def get_all_dtypes(
162162

163163

164164
def generate_random_numpy_array(
165-
shape, dtype=None, hermitian=False, seed_value=None
165+
shape, dtype=None, hermitian=False, seed_value=None, low=-1, high=1
166166
):
167167
"""
168168
Generate a random numpy array with the specified shape and dtype.
@@ -177,13 +177,19 @@ def generate_random_numpy_array(
177177
dtype : str or dtype, optional
178178
Desired data-type for the output array.
179179
If not specified, data type will be determined by numpy.
180-
Default : None
180+
Default : ``None``
181181
hermitian : bool, optional
182182
If True, generates a Hermitian (symmetric if `dtype` is real) matrix.
183-
Default : False
183+
Default : ``False``
184184
seed_value : int, optional
185185
The seed value to initialize the random number generator.
186-
Default : None
186+
Default : ``None``
187+
low : scalar, optional
188+
Lower boundary of the generated samples from a uniform distribution.
189+
Default : ``-1``.
190+
high : scalar, optional
191+
Upper boundary of the generated samples from a uniform distribution.
192+
Default : ``1``.
187193
188194
Returns
189195
-------
@@ -197,13 +203,15 @@ def generate_random_numpy_array(
197203
198204
"""
199205

200-
numpy.random.seed(seed_value)
206+
numpy.random.seed(seed_value) if seed_value else numpy.random.seed(42)
201207

202-
a = numpy.random.randn(*shape).astype(dtype)
208+
# dtype=int is needed for 0d arrays
209+
size = numpy.prod(shape, dtype=int)
210+
a = numpy.random.uniform(low, high, size).astype(dtype)
203211
if numpy.issubdtype(a.dtype, numpy.complexfloating):
204-
numpy.random.seed(seed_value)
205-
a += 1j * numpy.random.randn(*shape)
212+
a += 1j * numpy.random.uniform(low, high, size)
206213

214+
a = a.reshape(shape)
207215
if hermitian and a.size > 0:
208216
if a.ndim > 2:
209217
orig_shape = a.shape

0 commit comments

Comments
 (0)