21
21
22
22
import dpctl
23
23
import dpctl .tensor as dpt
24
+ from dpctl .tensor ._numpy_helper import AxisError
24
25
from dpctl .tests .helper import get_queue_or_skip
25
26
from dpctl .utils import ExecutionPlacementError
26
27
@@ -59,7 +60,7 @@ def test_permute_dims_0d_1d():
59
60
assert_array_equal (dpt .asnumpy (Y_1d ), dpt .asnumpy (X_1d ))
60
61
61
62
pytest .raises (ValueError , dpt .permute_dims , X_1d , ())
62
- pytest .raises (np . AxisError , dpt .permute_dims , X_1d , (1 ))
63
+ pytest .raises (AxisError , dpt .permute_dims , X_1d , (1 ))
63
64
pytest .raises (ValueError , dpt .permute_dims , X_1d , (1 , 0 ))
64
65
pytest .raises (
65
66
ValueError , dpt .permute_dims , dpt .reshape (X_1d , (2 , 3 )), (1 , 1 )
@@ -105,8 +106,8 @@ def test_expand_dims_0d():
105
106
Ynp = np .expand_dims (Xnp , axis = - 1 )
106
107
assert_array_equal (Ynp , dpt .asnumpy (Y ))
107
108
108
- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = 1 )
109
- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = - 2 )
109
+ pytest .raises (AxisError , dpt .expand_dims , X , axis = 1 )
110
+ pytest .raises (AxisError , dpt .expand_dims , X , axis = - 2 )
110
111
111
112
112
113
@pytest .mark .parametrize ("shapes" , [(3 ,), (3 , 3 ), (3 , 3 , 3 )])
@@ -123,8 +124,8 @@ def test_expand_dims_1d_3d(shapes):
123
124
Ynp = np .expand_dims (Xnp , axis = axis )
124
125
assert_array_equal (Ynp , dpt .asnumpy (Y ))
125
126
126
- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = shape_len + 1 )
127
- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = - shape_len - 2 )
127
+ pytest .raises (AxisError , dpt .expand_dims , X , axis = shape_len + 1 )
128
+ pytest .raises (AxisError , dpt .expand_dims , X , axis = - shape_len - 2 )
128
129
129
130
130
131
@pytest .mark .parametrize (
@@ -145,9 +146,9 @@ def test_expand_dims_incorrect_tuple():
145
146
X = dpt .empty ((3 , 3 , 3 ), dtype = "i4" )
146
147
except dpctl .SyclDeviceCreationError :
147
148
pytest .skip ("No SYCL devices available" )
148
- with pytest .raises (np . AxisError ):
149
+ with pytest .raises (AxisError ):
149
150
dpt .expand_dims (X , axis = (0 , - 6 ))
150
- with pytest .raises (np . AxisError ):
151
+ with pytest .raises (AxisError ):
151
152
dpt .expand_dims (X , axis = (0 , 5 ))
152
153
153
154
with pytest .raises (ValueError ):
@@ -181,10 +182,10 @@ def test_squeeze_0d():
181
182
Ynp = Xnp .squeeze (- 1 )
182
183
assert_array_equal (Ynp , dpt .asnumpy (Y ))
183
184
184
- pytest .raises (np . AxisError , dpt .squeeze , X , 1 )
185
- pytest .raises (np . AxisError , dpt .squeeze , X , - 2 )
186
- pytest .raises (np . AxisError , dpt .squeeze , X , (1 ))
187
- pytest .raises (np . AxisError , dpt .squeeze , X , (- 2 ))
185
+ pytest .raises (AxisError , dpt .squeeze , X , 1 )
186
+ pytest .raises (AxisError , dpt .squeeze , X , - 2 )
187
+ pytest .raises (AxisError , dpt .squeeze , X , (1 ))
188
+ pytest .raises (AxisError , dpt .squeeze , X , (- 2 ))
188
189
pytest .raises (ValueError , dpt .squeeze , X , (0 , 0 ))
189
190
190
191
@@ -446,10 +447,10 @@ def test_flip_axis_incorrect():
446
447
X_np = np .ones ((4 , 4 ))
447
448
X = dpt .asarray (X_np , sycl_queue = q )
448
449
449
- pytest .raises (np . AxisError , dpt .flip , dpt .asarray (np .ones (4 )), axis = 1 )
450
- pytest .raises (np . AxisError , dpt .flip , X , axis = 2 )
451
- pytest .raises (np . AxisError , dpt .flip , X , axis = - 3 )
452
- pytest .raises (np . AxisError , dpt .flip , X , axis = (0 , 3 ))
450
+ pytest .raises (AxisError , dpt .flip , dpt .asarray (np .ones (4 )), axis = 1 )
451
+ pytest .raises (AxisError , dpt .flip , X , axis = 2 )
452
+ pytest .raises (AxisError , dpt .flip , X , axis = - 3 )
453
+ pytest .raises (AxisError , dpt .flip , X , axis = (0 , 3 ))
453
454
454
455
455
456
def test_flip_0d ():
@@ -461,9 +462,9 @@ def test_flip_0d():
461
462
Y = dpt .flip (X )
462
463
assert_array_equal (Ynp , dpt .asnumpy (Y ))
463
464
464
- pytest .raises (np . AxisError , dpt .flip , X , axis = 0 )
465
- pytest .raises (np . AxisError , dpt .flip , X , axis = 1 )
466
- pytest .raises (np . AxisError , dpt .flip , X , axis = - 1 )
465
+ pytest .raises (AxisError , dpt .flip , X , axis = 0 )
466
+ pytest .raises (AxisError , dpt .flip , X , axis = 1 )
467
+ pytest .raises (AxisError , dpt .flip , X , axis = - 1 )
467
468
468
469
469
470
def test_flip_1d ():
@@ -588,9 +589,9 @@ def test_roll_empty():
588
589
Y = dpt .roll (X , 1 )
589
590
Ynp = np .roll (Xnp , 1 )
590
591
assert_array_equal (Ynp , dpt .asnumpy (Y ))
591
- with pytest .raises (np . AxisError ):
592
+ with pytest .raises (AxisError ):
592
593
dpt .roll (X , 1 , axis = 0 )
593
- with pytest .raises (np . AxisError ):
594
+ with pytest .raises (AxisError ):
594
595
dpt .roll (X , 1 , axis = 1 )
595
596
596
597
@@ -1086,13 +1087,13 @@ def test_moveaxis_errors():
1086
1087
pytest .skip ("No SYCL devices available" )
1087
1088
x = dpt .reshape (x_flat , (1 , 2 , 3 ))
1088
1089
assert_raises_regex (
1089
- np . AxisError , "source.*out of bounds" , dpt .moveaxis , x , 3 , 0
1090
+ AxisError , "source.*out of bounds" , dpt .moveaxis , x , 3 , 0
1090
1091
)
1091
1092
assert_raises_regex (
1092
- np . AxisError , "source.*out of bounds" , dpt .moveaxis , x , - 4 , 0
1093
+ AxisError , "source.*out of bounds" , dpt .moveaxis , x , - 4 , 0
1093
1094
)
1094
1095
assert_raises_regex (
1095
- np . AxisError , "destination.*out of bounds" , dpt .moveaxis , x , 0 , 5
1096
+ AxisError , "destination.*out of bounds" , dpt .moveaxis , x , 0 , 5
1096
1097
)
1097
1098
assert_raises_regex (
1098
1099
ValueError , "repeated axis in `source`" , dpt .moveaxis , x , [0 , 0 ], [0 , 1 ]
0 commit comments