Skip to content

Commit 0b49fd4

Browse files
Import AxisError from dpctl.tensor._numpy_helper
1 parent eb6330f commit 0b49fd4

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._numpy_helper import AxisError
2425
from dpctl.tests.helper import get_queue_or_skip
2526
from dpctl.utils import ExecutionPlacementError
2627

@@ -59,7 +60,7 @@ def test_permute_dims_0d_1d():
5960
assert_array_equal(dpt.asnumpy(Y_1d), dpt.asnumpy(X_1d))
6061

6162
pytest.raises(ValueError, dpt.permute_dims, X_1d, ())
62-
pytest.raises(np.AxisError, dpt.permute_dims, X_1d, (1))
63+
pytest.raises(AxisError, dpt.permute_dims, X_1d, (1))
6364
pytest.raises(ValueError, dpt.permute_dims, X_1d, (1, 0))
6465
pytest.raises(
6566
ValueError, dpt.permute_dims, dpt.reshape(X_1d, (2, 3)), (1, 1)
@@ -105,8 +106,8 @@ def test_expand_dims_0d():
105106
Ynp = np.expand_dims(Xnp, axis=-1)
106107
assert_array_equal(Ynp, dpt.asnumpy(Y))
107108

108-
pytest.raises(np.AxisError, dpt.expand_dims, X, axis=1)
109-
pytest.raises(np.AxisError, dpt.expand_dims, X, axis=-2)
109+
pytest.raises(AxisError, dpt.expand_dims, X, axis=1)
110+
pytest.raises(AxisError, dpt.expand_dims, X, axis=-2)
110111

111112

112113
@pytest.mark.parametrize("shapes", [(3,), (3, 3), (3, 3, 3)])
@@ -123,8 +124,8 @@ def test_expand_dims_1d_3d(shapes):
123124
Ynp = np.expand_dims(Xnp, axis=axis)
124125
assert_array_equal(Ynp, dpt.asnumpy(Y))
125126

126-
pytest.raises(np.AxisError, dpt.expand_dims, X, axis=shape_len + 1)
127-
pytest.raises(np.AxisError, dpt.expand_dims, X, axis=-shape_len - 2)
127+
pytest.raises(AxisError, dpt.expand_dims, X, axis=shape_len + 1)
128+
pytest.raises(AxisError, dpt.expand_dims, X, axis=-shape_len - 2)
128129

129130

130131
@pytest.mark.parametrize(
@@ -145,9 +146,9 @@ def test_expand_dims_incorrect_tuple():
145146
X = dpt.empty((3, 3, 3), dtype="i4")
146147
except dpctl.SyclDeviceCreationError:
147148
pytest.skip("No SYCL devices available")
148-
with pytest.raises(np.AxisError):
149+
with pytest.raises(AxisError):
149150
dpt.expand_dims(X, axis=(0, -6))
150-
with pytest.raises(np.AxisError):
151+
with pytest.raises(AxisError):
151152
dpt.expand_dims(X, axis=(0, 5))
152153

153154
with pytest.raises(ValueError):
@@ -181,10 +182,10 @@ def test_squeeze_0d():
181182
Ynp = Xnp.squeeze(-1)
182183
assert_array_equal(Ynp, dpt.asnumpy(Y))
183184

184-
pytest.raises(np.AxisError, dpt.squeeze, X, 1)
185-
pytest.raises(np.AxisError, dpt.squeeze, X, -2)
186-
pytest.raises(np.AxisError, dpt.squeeze, X, (1))
187-
pytest.raises(np.AxisError, dpt.squeeze, X, (-2))
185+
pytest.raises(AxisError, dpt.squeeze, X, 1)
186+
pytest.raises(AxisError, dpt.squeeze, X, -2)
187+
pytest.raises(AxisError, dpt.squeeze, X, (1))
188+
pytest.raises(AxisError, dpt.squeeze, X, (-2))
188189
pytest.raises(ValueError, dpt.squeeze, X, (0, 0))
189190

190191

@@ -446,10 +447,10 @@ def test_flip_axis_incorrect():
446447
X_np = np.ones((4, 4))
447448
X = dpt.asarray(X_np, sycl_queue=q)
448449

449-
pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axis=1)
450-
pytest.raises(np.AxisError, dpt.flip, X, axis=2)
451-
pytest.raises(np.AxisError, dpt.flip, X, axis=-3)
452-
pytest.raises(np.AxisError, dpt.flip, X, axis=(0, 3))
450+
pytest.raises(AxisError, dpt.flip, dpt.asarray(np.ones(4)), axis=1)
451+
pytest.raises(AxisError, dpt.flip, X, axis=2)
452+
pytest.raises(AxisError, dpt.flip, X, axis=-3)
453+
pytest.raises(AxisError, dpt.flip, X, axis=(0, 3))
453454

454455

455456
def test_flip_0d():
@@ -461,9 +462,9 @@ def test_flip_0d():
461462
Y = dpt.flip(X)
462463
assert_array_equal(Ynp, dpt.asnumpy(Y))
463464

464-
pytest.raises(np.AxisError, dpt.flip, X, axis=0)
465-
pytest.raises(np.AxisError, dpt.flip, X, axis=1)
466-
pytest.raises(np.AxisError, dpt.flip, X, axis=-1)
465+
pytest.raises(AxisError, dpt.flip, X, axis=0)
466+
pytest.raises(AxisError, dpt.flip, X, axis=1)
467+
pytest.raises(AxisError, dpt.flip, X, axis=-1)
467468

468469

469470
def test_flip_1d():
@@ -588,9 +589,9 @@ def test_roll_empty():
588589
Y = dpt.roll(X, 1)
589590
Ynp = np.roll(Xnp, 1)
590591
assert_array_equal(Ynp, dpt.asnumpy(Y))
591-
with pytest.raises(np.AxisError):
592+
with pytest.raises(AxisError):
592593
dpt.roll(X, 1, axis=0)
593-
with pytest.raises(np.AxisError):
594+
with pytest.raises(AxisError):
594595
dpt.roll(X, 1, axis=1)
595596

596597

@@ -1086,13 +1087,13 @@ def test_moveaxis_errors():
10861087
pytest.skip("No SYCL devices available")
10871088
x = dpt.reshape(x_flat, (1, 2, 3))
10881089
assert_raises_regex(
1089-
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0
1090+
AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0
10901091
)
10911092
assert_raises_regex(
1092-
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0
1093+
AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0
10931094
)
10941095
assert_raises_regex(
1095-
np.AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5
1096+
AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5
10961097
)
10971098
assert_raises_regex(
10981099
ValueError, "repeated axis in `source`", dpt.moveaxis, x, [0, 0], [0, 1]

0 commit comments

Comments
 (0)