Skip to content

Commit 20847f2

Browse files
Add test for setting shape by a scalar
Replaced uses of np.prod with math.prod.
1 parent 726432d commit 20847f2

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import ctypes
1818
import numbers
19+
from math import prod
1920

2021
import numpy as np
2122
import pytest
@@ -1102,7 +1103,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
11021103
skip_if_dtype_not_supported(dtype, q)
11031104
shape = (2, 4, 3)
11041105
Xnp = (
1105-
np.random.randint(-10, 10, size=np.prod(shape))
1106+
np.random.randint(-10, 10, size=prod(shape))
11061107
.astype(dtype)
11071108
.reshape(shape)
11081109
)
@@ -1307,6 +1308,10 @@ def relaxed_strides_equal(st1, st2, sh):
13071308
X = dpt.usm_ndarray(sh_s, dtype="?")
13081309
X.shape = sh_f
13091310
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
1311+
sz = X.size
1312+
X.shape = sz
1313+
assert X.shape == (sz,)
1314+
assert relaxed_strides_equal(X.strides, (1,), (sz,))
13101315

13111316
X = dpt.usm_ndarray(sh_s, dtype="u4")
13121317
with pytest.raises(TypeError):
@@ -2077,11 +2082,9 @@ def test_tril(dtype):
20772082
skip_if_dtype_not_supported(dtype, q)
20782083

20792084
shape = (2, 3, 4, 5, 5)
2080-
X = dpt.reshape(
2081-
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
2082-
)
2085+
X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape)
20832086
Y = dpt.tril(X)
2084-
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
2087+
Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape)
20852088
Ynp = np.tril(Xnp)
20862089
assert Y.dtype == Ynp.dtype
20872090
assert np.array_equal(Ynp, dpt.asnumpy(Y))
@@ -2093,11 +2096,9 @@ def test_triu(dtype):
20932096
skip_if_dtype_not_supported(dtype, q)
20942097

20952098
shape = (4, 5)
2096-
X = dpt.reshape(
2097-
dpt.arange(np.prod(shape), dtype=dtype, sycl_queue=q), shape
2098-
)
2099+
X = dpt.reshape(dpt.arange(prod(shape), dtype=dtype, sycl_queue=q), shape)
20992100
Y = dpt.triu(X, k=1)
2100-
Xnp = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
2101+
Xnp = np.arange(prod(shape), dtype=dtype).reshape(shape)
21012102
Ynp = np.triu(Xnp, k=1)
21022103
assert Y.dtype == Ynp.dtype
21032104
assert np.array_equal(Ynp, dpt.asnumpy(Y))
@@ -2110,7 +2111,7 @@ def test_tri_usm_type(tri_fn, usm_type):
21102111
dtype = dpt.uint16
21112112

21122113
shape = (2, 3, 4, 5, 5)
2113-
size = np.prod(shape)
2114+
size = prod(shape)
21142115
X = dpt.reshape(
21152116
dpt.arange(size, dtype=dtype, usm_type=usm_type, sycl_queue=q), shape
21162117
)
@@ -2129,11 +2130,11 @@ def test_tril_slice():
21292130
q = get_queue_or_skip()
21302131

21312132
shape = (6, 10)
2132-
X = dpt.reshape(
2133-
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
2134-
)[1:, ::-2]
2133+
X = dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape)[
2134+
1:, ::-2
2135+
]
21352136
Y = dpt.tril(X)
2136-
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape)[1:, ::-2]
2137+
Xnp = np.arange(prod(shape), dtype="int").reshape(shape)[1:, ::-2]
21372138
Ynp = np.tril(Xnp)
21382139
assert Y.dtype == Ynp.dtype
21392140
assert np.array_equal(Ynp, dpt.asnumpy(Y))
@@ -2144,14 +2145,12 @@ def test_triu_permute_dims():
21442145

21452146
shape = (2, 3, 4, 5)
21462147
X = dpt.permute_dims(
2147-
dpt.reshape(
2148-
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q), shape
2149-
),
2148+
dpt.reshape(dpt.arange(prod(shape), dtype="int", sycl_queue=q), shape),
21502149
(3, 2, 1, 0),
21512150
)
21522151
Y = dpt.triu(X)
21532152
Xnp = np.transpose(
2154-
np.arange(np.prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0)
2153+
np.arange(prod(shape), dtype="int").reshape(shape), (3, 2, 1, 0)
21552154
)
21562155
Ynp = np.triu(Xnp)
21572156
assert Y.dtype == Ynp.dtype
@@ -2189,12 +2188,12 @@ def test_triu_order_k(order, k):
21892188

21902189
shape = (3, 3)
21912190
X = dpt.reshape(
2192-
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
2191+
dpt.arange(prod(shape), dtype="int", sycl_queue=q),
21932192
shape,
21942193
order=order,
21952194
)
21962195
Y = dpt.triu(X, k=k)
2197-
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
2196+
Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order)
21982197
Ynp = np.triu(Xnp, k=k)
21992198
assert Y.dtype == Ynp.dtype
22002199
assert X.flags == Y.flags
@@ -2210,12 +2209,12 @@ def test_tril_order_k(order, k):
22102209
pytest.skip("Queue could not be created")
22112210
shape = (3, 3)
22122211
X = dpt.reshape(
2213-
dpt.arange(np.prod(shape), dtype="int", sycl_queue=q),
2212+
dpt.arange(prod(shape), dtype="int", sycl_queue=q),
22142213
shape,
22152214
order=order,
22162215
)
22172216
Y = dpt.tril(X, k=k)
2218-
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
2217+
Xnp = np.arange(prod(shape), dtype="int").reshape(shape, order=order)
22192218
Ynp = np.tril(Xnp, k=k)
22202219
assert Y.dtype == Ynp.dtype
22212220
assert X.flags == Y.flags

0 commit comments

Comments
 (0)