@@ -1433,3 +1433,107 @@ def test_nonzero_dtype():
1433
1433
index_dt = dpt .dtype (ti .default_device_index_type (x .sycl_queue ))
1434
1434
assert idx .dtype == index_dt
1435
1435
assert idy .dtype == index_dt
1436
+
1437
+
1438
+ def test_take_empty_axes ():
1439
+ get_queue_or_skip ()
1440
+
1441
+ x = dpt .ones ((3 , 0 , 4 , 5 , 6 ), dtype = "f4" )
1442
+ inds = dpt .ones (1 , dtype = "i4" )
1443
+
1444
+ with pytest .raises (IndexError ):
1445
+ dpt .take (x , inds , axis = 1 )
1446
+
1447
+ inds = dpt .ones (0 , dtype = "i4" )
1448
+ r = dpt .take (x , inds , axis = 1 )
1449
+ assert r .shape == x .shape
1450
+
1451
+
1452
+ def test_put_empty_axes ():
1453
+ get_queue_or_skip ()
1454
+
1455
+ x = dpt .ones ((3 , 0 , 4 , 5 , 6 ), dtype = "f4" )
1456
+ inds = dpt .ones (1 , dtype = "i4" )
1457
+ vals = dpt .zeros ((3 , 1 , 4 , 5 , 6 ), dtype = "f4" )
1458
+
1459
+ with pytest .raises (IndexError ):
1460
+ dpt .put (x , inds , vals , axis = 1 )
1461
+
1462
+ inds = dpt .ones (0 , dtype = "i4" )
1463
+ vals = dpt .zeros_like (x )
1464
+
1465
+ with pytest .raises (ValueError ):
1466
+ dpt .put (x , inds , vals , axis = 1 )
1467
+
1468
+
1469
+ def test_put_cast_vals ():
1470
+ get_queue_or_skip ()
1471
+
1472
+ x = dpt .arange (10 , dtype = "i4" )
1473
+ inds = dpt .arange (7 , 10 , dtype = "i4" )
1474
+ vals = dpt .zeros_like (inds , dtype = "f4" )
1475
+
1476
+ dpt .put (x , inds , vals )
1477
+ assert dpt .all (x [7 :10 ] == 0 )
1478
+
1479
+
1480
+ def test_advanced_integer_indexing_cast_vals ():
1481
+ get_queue_or_skip ()
1482
+
1483
+ x = dpt .arange (10 , dtype = "i4" )
1484
+ inds = dpt .arange (7 , 10 , dtype = "i4" )
1485
+ vals = dpt .zeros_like (inds , dtype = "f4" )
1486
+
1487
+ x [inds ] = vals
1488
+ assert dpt .all (x [7 :10 ] == 0 )
1489
+
1490
+
1491
+ def test_advanced_integer_indexing_empty_axis ():
1492
+ get_queue_or_skip ()
1493
+
1494
+ # getting
1495
+ x = dpt .ones ((3 , 0 , 4 , 5 , 6 ), dtype = "f4" )
1496
+ inds = dpt .ones (1 , dtype = "i4" )
1497
+ with pytest .raises (IndexError ):
1498
+ x [:, inds , ...]
1499
+ with pytest .raises (IndexError ):
1500
+ x [inds , inds , inds , ...]
1501
+
1502
+ # setting
1503
+ with pytest .raises (IndexError ):
1504
+ x [:, inds , ...] = 2
1505
+ with pytest .raises (IndexError ):
1506
+ x [inds , inds , inds , ...] = 2
1507
+
1508
+ # empty inds
1509
+ inds = dpt .ones (0 , dtype = "i4" )
1510
+ assert x [:, inds , ...].shape == x .shape
1511
+ assert x [inds , inds , inds , ...].shape == (0 , 5 , 6 )
1512
+
1513
+ vals = dpt .zeros_like (x )
1514
+ x [:, inds , ...] = vals
1515
+ vals = dpt .zeros ((0 , 5 , 6 ), dtype = "f4" )
1516
+ x [inds , inds , inds , ...] = vals
1517
+
1518
+
1519
+ def test_advanced_integer_indexing_cast_indices ():
1520
+ get_queue_or_skip ()
1521
+
1522
+ inds0 = dpt .asarray ([0 , 1 ], dtype = "i1" )
1523
+ for ind_dts in (("i1" , "i2" , "i4" ), ("i1" , "u4" , "i4" ), ("u1" , "u2" , "u8" )):
1524
+ x = dpt .ones ((3 , 4 , 5 , 6 ), dtype = "i4" )
1525
+ inds0 = dpt .asarray ([0 , 1 ], dtype = ind_dts [0 ])
1526
+ inds1 = dpt .astype (inds0 , ind_dts [1 ])
1527
+ x [inds0 , inds1 , ...] = 2
1528
+ assert dpt .all (x [inds0 , inds1 , ...] == 2 )
1529
+ inds2 = dpt .astype (inds0 , ind_dts [2 ])
1530
+ x [inds0 , inds1 , ...] = 2
1531
+ assert dpt .all (x [inds0 , inds1 , inds2 , ...] == 2 )
1532
+
1533
+ # fail when float would be required per type promotion
1534
+ inds0 = dpt .asarray ([0 , 1 ], dtype = "i1" )
1535
+ inds1 = dpt .astype (inds0 , "u4" )
1536
+ inds2 = dpt .astype (inds0 , "u8" )
1537
+ x = dpt .ones ((3 , 4 , 5 , 6 ), dtype = "i4" )
1538
+ with pytest .raises (ValueError ):
1539
+ x [inds0 , inds1 , inds2 , ...]
0 commit comments