@@ -1560,17 +1560,31 @@ def test_take_along_axis():
1560
1560
1561
1561
1562
1562
def test_take_along_axis_validation ():
1563
+ # type check on the first argument
1563
1564
with pytest .raises (TypeError ):
1564
1565
dpt .take_along_axis (tuple (), list ())
1565
1566
get_queue_or_skip ()
1566
- x = dpt .ones (10 )
1567
+ n1 , n2 = 2 , 5
1568
+ x = dpt .ones (n1 * n2 )
1569
+ # type check on the second argument
1567
1570
with pytest .raises (TypeError ):
1568
1571
dpt .take_along_axis (x , list ())
1569
- ind_dt = dpt .__array_namespace_info__ ().default_dtypes (
1570
- device = x .sycl_device
1571
- )["indexing" ]
1572
+ x_dev = x .sycl_device
1573
+ info_ = dpt .__array_namespace_info__ ()
1574
+ def_dtypes = info_ .default_dtypes (device = x_dev )
1575
+ ind_dt = def_dtypes ["indexing" ]
1572
1576
ind = dpt .zeros (1 , dtype = ind_dt )
1577
+ # axis valudation
1573
1578
with pytest .raises (ValueError ):
1574
1579
dpt .take_along_axis (x , ind , axis = 1 )
1580
+ # mode validation
1575
1581
with pytest .raises (ValueError ):
1576
1582
dpt .take_along_axis (x , ind , axis = 0 , mode = "invalid" )
1583
+ # same array-ranks validation
1584
+ with pytest .raises (ValueError ):
1585
+ dpt .take_along_axis (dpt .reshape (x , (n1 , n2 )), ind )
1586
+ # check compute-follows-data
1587
+ q2 = dpctl .SyclQueue (x_dev , property = "enable_profiling" )
1588
+ ind2 = dpt .zeros (1 , dtype = ind_dt , sycl_queue = q2 )
1589
+ with pytest .raises (ExecutionPlacementError ):
1590
+ dpt .take_along_axis (x , ind2 )
0 commit comments