@@ -1890,3 +1890,75 @@ def test_put_along_axis_uint64_indices():
1890
1890
dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ), axis = 1 )
1891
1891
expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), (2 , 5 ))
1892
1892
assert dpt .all (expected == x )
1893
+
1894
+
1895
+ @pytest .mark .parametrize (
1896
+ "data_dt" ,
1897
+ _all_dtypes ,
1898
+ )
1899
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1900
+ def test_take_out (data_dt , order ):
1901
+ q = get_queue_or_skip ()
1902
+ skip_if_dtype_not_supported (data_dt , q )
1903
+
1904
+ axis = 0
1905
+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1906
+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1907
+ out_sh = x .shape [:axis ] + ind .shape + x .shape [axis + 1 :]
1908
+ out = dpt .empty (out_sh , dtype = data_dt , sycl_queue = q )
1909
+
1910
+ expected = dpt .take (x , ind , axis = axis )
1911
+
1912
+ dpt .take (x , ind , axis = axis , out = out )
1913
+
1914
+ assert dpt .all (out == expected )
1915
+
1916
+
1917
+ @pytest .mark .parametrize (
1918
+ "data_dt" ,
1919
+ _all_dtypes ,
1920
+ )
1921
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1922
+ def test_take_out_overlap (data_dt , order ):
1923
+ q = get_queue_or_skip ()
1924
+ skip_if_dtype_not_supported (data_dt , q )
1925
+
1926
+ axis = 0
1927
+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1928
+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1929
+ out = x [x .shape [axis ] - ind .shape [axis ] : x .shape [axis ], :]
1930
+
1931
+ expected = dpt .take (x , ind , axis = axis )
1932
+
1933
+ dpt .take (x , ind , axis = axis , out = out )
1934
+
1935
+ assert dpt .all (out == expected )
1936
+ assert dpt .all (x [x .shape [0 ] - ind .shape [0 ] : x .shape [0 ], :] == out )
1937
+
1938
+
1939
+ def test_take_out_errors ():
1940
+ q1 = get_queue_or_skip ()
1941
+ q2 = get_queue_or_skip ()
1942
+
1943
+ x = dpt .arange (10 , dtype = "i4" , sycl_queue = q1 )
1944
+ ind = dpt .arange (2 , dtype = "i4" , sycl_queue = q1 )
1945
+
1946
+ with pytest .raises (TypeError ):
1947
+ dpt .take (x , ind , out = dict ())
1948
+
1949
+ out_read_only = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q1 )
1950
+ out_read_only .flags ["W" ] = False
1951
+ with pytest .raises (ValueError ):
1952
+ dpt .take (x , ind , out = out_read_only )
1953
+
1954
+ out_bad_shape = dpt .empty (0 , dtype = x .dtype , sycl_queue = q1 )
1955
+ with pytest .raises (ValueError ):
1956
+ dpt .take (x , ind , out = out_bad_shape )
1957
+
1958
+ out_bad_dt = dpt .empty (ind .shape , dtype = "i8" , sycl_queue = q1 )
1959
+ with pytest .raises (ValueError ):
1960
+ dpt .take (x , ind , out = out_bad_dt )
1961
+
1962
+ out_bad_q = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q2 )
1963
+ with pytest .raises (dpctl .utils .ExecutionPlacementError ):
1964
+ dpt .take (x , ind , out = out_bad_q )
0 commit comments