@@ -1592,3 +1592,91 @@ def test_take_along_axis_validation():
1592
1592
ind2 = dpt .zeros (1 , dtype = ind_dt , sycl_queue = q2 )
1593
1593
with pytest .raises (ExecutionPlacementError ):
1594
1594
dpt .take_along_axis (x , ind2 )
1595
+
1596
+
1597
+ def check__extract_impl_validation (fn ):
1598
+ x = dpt .ones (10 )
1599
+ ind = dpt .ones (10 , dtype = "?" )
1600
+ with pytest .raises (TypeError ):
1601
+ fn (list (), ind )
1602
+ with pytest .raises (TypeError ):
1603
+ fn (x , list ())
1604
+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1605
+ ind2 = dpt .ones (10 , dtype = "?" , sycl_queue = q2 )
1606
+ with pytest .raises (ExecutionPlacementError ):
1607
+ fn (x , ind2 )
1608
+ with pytest .raises (ValueError ):
1609
+ fn (x , ind , 1 )
1610
+
1611
+
1612
+ def check__nonzero_impl_validation (fn ):
1613
+ with pytest .raises (TypeError ):
1614
+ fn (list ())
1615
+
1616
+
1617
+ def check__take_multi_index (fn ):
1618
+ x = dpt .ones (10 )
1619
+ x_dev = x .sycl_device
1620
+ info_ = dpt .__array_namespace_info__ ()
1621
+ def_dtypes = info_ .default_dtypes (device = x_dev )
1622
+ ind_dt = def_dtypes ["indexing" ]
1623
+ ind = dpt .arange (10 , dtype = ind_dt )
1624
+ with pytest .raises (TypeError ):
1625
+ fn (list (), tuple (), 1 )
1626
+ with pytest .raises (ValueError ):
1627
+ fn (x , (ind ,), 0 , mode = 2 )
1628
+ with pytest .raises (ValueError ):
1629
+ fn (x , (None ,), 1 )
1630
+ with pytest .raises (IndexError ):
1631
+ fn (x , (x ,), 1 )
1632
+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1633
+ ind2 = dpt .arange (10 , dtype = ind_dt , sycl_queue = q2 )
1634
+ with pytest .raises (ExecutionPlacementError ):
1635
+ fn (x , (ind2 ,), 0 )
1636
+ m = dpt .ones ((10 , 10 ))
1637
+ ind_1 = dpt .arange (10 , dtype = "i8" )
1638
+ ind_2 = dpt .arange (10 , dtype = "u8" )
1639
+ with pytest .raises (ValueError ):
1640
+ fn (m , (ind_1 , ind_2 ), 0 )
1641
+
1642
+
1643
+ def check__place_impl_validation (fn ):
1644
+ with pytest .raises (TypeError ):
1645
+ fn (list (), list (), list ())
1646
+ x = dpt .ones (10 )
1647
+ with pytest .raises (TypeError ):
1648
+ fn (x , list (), list ())
1649
+ q2 = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1650
+ mask2 = dpt .ones (10 , dtype = "?" , sycl_queue = q2 )
1651
+ with pytest .raises (ExecutionPlacementError ):
1652
+ fn (x , mask2 , 1 )
1653
+ mask = dpt .ones (x .shape , dtype = "?" )
1654
+ with pytest .raises (ValueError ):
1655
+ fn (x , mask , x , 1 )
1656
+
1657
+
1658
+ def check__put_multi_index_validation (fn ):
1659
+ with pytest .raises (TypeError ):
1660
+ fn (list (), list (), 0 , list ())
1661
+ x = dpt .ones (10 )
1662
+ inds = dpt .arange (10 , dtype = "i8" )
1663
+ vals = dpt .zeros (10 )
1664
+ # test inds which is not a tuple/list
1665
+ fn (x , inds , 0 , vals )
1666
+ x2 = dpt .ones ((5 , 5 ))
1667
+ ind1 = dpt .arange (5 , dtype = "i8" )
1668
+ ind2 = dpt .arange (5 , dtype = "u8" )
1669
+ with pytest .raises (ValueError ):
1670
+ fn (x2 , (ind1 , ind2 ), 0 , x2 )
1671
+
1672
+
1673
+ def test__copy_utils ():
1674
+ import dpctl .tensor ._copy_utils as cu
1675
+
1676
+ get_queue_or_skip ()
1677
+
1678
+ check__extract_impl_validation (cu ._extract_impl )
1679
+ check__nonzero_impl_validation (cu ._nonzero_impl )
1680
+ check__take_multi_index (cu ._take_multi_index )
1681
+ check__place_impl_validation (cu ._place_impl )
1682
+ check__put_multi_index_validation (cu ._put_multi_index )
0 commit comments