Skip to content

Commit ce507b1

Browse files
authored
Merge branch 'master' into impl-ldexp
2 parents 4acb885 + 86daf27 commit ce507b1

File tree

4 files changed

+29
-39
lines changed

4 files changed

+29
-39
lines changed

dpnp/dpnp_utils/dpnp_utils_einsum.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -109,29 +109,6 @@ def _compute_size_by_dict(indices, idx_dict):
109109
return ret
110110

111111

112-
def _compute_size(start, shape):
113-
"""
114-
Compute the total size of a multi-dimensional array starting from a given index.
115-
116-
Parameters
117-
----------
118-
start : int
119-
The starting index from which to compute the size.
120-
shape : tuple
121-
The shape of the multi-dimensional array.
122-
123-
Returns
124-
-------
125-
out : int
126-
The total size of the array.
127-
128-
"""
129-
ret = 1
130-
for i in range(start, len(shape)):
131-
ret *= shape[i]
132-
return ret
133-
134-
135112
def _einsum_diagonals(input_subscripts, operands):
136113
"""
137114
Adopted from _einsum_diagonals in cupy/core/_einsum.py
@@ -818,11 +795,11 @@ def _parse_int_subscript(list_subscript):
818795
"For this input type lists must contain "
819796
"either int or Ellipsis"
820797
) from e
821-
if isinstance(s, int):
822-
if not 0 <= s < len(_einsum_symbols):
823-
raise ValueError(
824-
f"subscript is not within the valid range [0, {len(_einsum_symbols)})."
825-
)
798+
799+
if not 0 <= s < len(_einsum_symbols):
800+
raise ValueError(
801+
f"subscript is not within the valid range [0, {len(_einsum_symbols)})."
802+
)
826803
str_subscript += _einsum_symbols[s]
827804
return str_subscript
828805

@@ -1116,12 +1093,14 @@ def dpnp_einsum(
11161093
f"'{_chr(label)}' which never appeared in an input."
11171094
)
11181095
if len(output_subscript) != len(set(output_subscript)):
1096+
repeated_subscript = []
11191097
for label in output_subscript:
11201098
if output_subscript.count(label) >= 2:
1121-
raise ValueError(
1122-
"einstein sum subscripts string includes output "
1123-
f"subscript '{_chr(label)}' multiple times."
1124-
)
1099+
repeated_subscript.append(_chr(label))
1100+
raise ValueError(
1101+
"einstein sum subscripts string includes output "
1102+
f"subscript {set(repeated_subscript)} multiple times."
1103+
)
11251104

11261105
_einsum_diagonals(input_subscripts, operands)
11271106

dpnp/dpnp_utils/dpnp_utils_statistics.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626

2727
import dpnp
28-
from dpnp.dpnp_utils import get_usm_allocations
28+
from dpnp.dpnp_utils import get_usm_allocations, map_dtype_to_device
2929

3030
__all__ = ["dpnp_cov"]
3131

@@ -73,12 +73,7 @@ def _get_2dmin_array(x, dtype):
7373
dtypes.append(y.dtype)
7474
dtype = dpnp.result_type(*dtypes)
7575
# TODO: remove when dpctl.result_type() is returned dtype based on fp64
76-
fp64 = queue.sycl_device.has_aspect_fp64
77-
if not fp64:
78-
if dtype == dpnp.float64:
79-
dtype = dpnp.float32
80-
elif dtype == dpnp.complex128:
81-
dtype = dpnp.complex64
76+
dtype = map_dtype_to_device(dtype, queue.sycl_device)
8277

8378
X = _get_2dmin_array(m, dtype)
8479
if y is not None:

tests/test_linalg.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,10 @@ def test_einsum_error1(self):
719719
# different size for same label 5 != 4
720720
assert_raises(ValueError, inp.einsum, "ii", a)
721721

722+
a = inp.arange(25).reshape(5, 5)
723+
# subscript is not within the valid range [0, 52)
724+
assert_raises(ValueError, inp.einsum, a, [53, 53])
725+
722726
@pytest.mark.parametrize("do_opt", [True, False])
723727
@pytest.mark.parametrize("xp", [numpy, inp])
724728
def test_einsum_error2(self, do_opt, xp):
@@ -1740,6 +1744,17 @@ def test_output_order(self):
17401744
tmp = inp.einsum("...ft,mf->...mt", d, c, order="a", optimize=opt)
17411745
assert tmp.flags.c_contiguous
17421746

1747+
def test_einsum_path(self):
1748+
# Test einsum path for covergae
1749+
a = numpy.random.rand(1, 2, 3, 4)
1750+
b = numpy.random.rand(4, 3, 2, 1)
1751+
a_dp = inp.array(a)
1752+
b_dp = inp.array(b)
1753+
expected = numpy.einsum_path("ijkl,dcba->dcba", a, b)
1754+
result = inp.einsum_path("ijkl,dcba->dcba", a_dp, b_dp)
1755+
assert expected[0] == result[0]
1756+
assert expected[1] == result[1]
1757+
17431758

17441759
class TestInv:
17451760
@pytest.mark.parametrize(

tests/third_party/cupy/linalg_tests/test_product.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def test_dot_with_out(self, xp, dtype_a, dtype_b, dtype_c):
102102
)
103103
)
104104
class TestCrossProduct(unittest.TestCase):
105+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
105106
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
106107
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
107108
def test_cross(self, xp, dtype_a, dtype_b):

0 commit comments

Comments
 (0)