33
33
import dpctl .tensor ._tensor_elementwise_impl as tei
34
34
import dpctl .tensor ._tensor_impl as ti
35
35
import numpy
36
+ from dpctl .utils import ExecutionPlacementError
36
37
from numpy .core .numeric import normalize_axis_tuple
37
38
38
39
import dpnp
@@ -218,7 +219,9 @@ def _compute_size(start, shape):
218
219
return ret
219
220
220
221
221
- def _copy_array (x , dep_events , host_events , copy_flag = False , dtype = None ):
222
+ def _copy_array (
223
+ x , dep_events , host_events , copy_flag = False , dtype = None , order = "C"
224
+ ):
222
225
"""
223
226
Creating a copy of input array if needed.
224
227
@@ -236,7 +239,7 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
236
239
copy = x .dtype != dtype if dtype is not None else False
237
240
238
241
if copy :
239
- x_copy = dpnp .empty_like (x , dtype = dtype , order = "C" )
242
+ x_copy = dpnp .empty_like (x , dtype = dtype , order = order )
240
243
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
241
244
src = dpnp .get_usm_ndarray (x ),
242
245
dst = x_copy .get_array (),
@@ -248,7 +251,9 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
248
251
return x
249
252
250
253
251
- def _create_result_array (x1 , x2 , out , shape , dtype , usm_type , sycl_queue ):
254
+ def _create_result_array (
255
+ x1 , x2 , out , shape , dtype , usm_type , sycl_queue , order = "C"
256
+ ):
252
257
"""
253
258
Create the result array.
254
259
@@ -263,13 +268,12 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
263
268
x1_usm = dpnp .get_usm_ndarray (x1 )
264
269
x2_usm = dpnp .get_usm_ndarray (x2 )
265
270
out_usm = dpnp .get_usm_ndarray (out )
266
- contig_flag = _define_contig_flag (out )
271
+ contig_flag , _ , _ = _define_contig_flag (out )
267
272
268
273
if (
269
274
out .dtype == dtype
270
275
and out .shape == shape
271
276
and out .usm_type == usm_type
272
- and out .sycl_queue == sycl_queue
273
277
and contig_flag
274
278
and not ti ._array_overlap (x1_usm , out_usm )
275
279
and not ti ._array_overlap (x2_usm , out_usm )
@@ -279,6 +283,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
279
283
return dpnp .empty (
280
284
shape ,
281
285
dtype = dtype ,
286
+ order = order ,
282
287
usm_type = usm_type ,
283
288
sycl_queue = sycl_queue ,
284
289
)
@@ -295,14 +300,14 @@ def _define_contig_flag(x):
295
300
x_strides = x .strides
296
301
x_shape = x .shape
297
302
if x .ndim < 2 :
298
- return True
303
+ return True , True , True
299
304
300
305
x_strides = _standardize_strides_to_nonzero (x_strides , x_shape )
301
306
x_is_c_contiguous = x_strides [- 1 ] == 1 and x_strides [- 2 ] == x_shape [- 1 ]
302
307
x_is_f_contiguous = x_strides [- 2 ] == 1 and x_strides [- 1 ] == x_shape [- 2 ]
303
308
if x_is_c_contiguous or x_is_f_contiguous :
304
309
flag = True
305
- return flag
310
+ return flag , x_is_c_contiguous , x_is_f_contiguous
306
311
307
312
308
313
def _define_dim_flags (x , pos ):
@@ -746,17 +751,26 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, dev_tasks_list):
746
751
)
747
752
ht_tasks_list .append (ht_blas_ev )
748
753
dpctl .SyclEvent .wait_for (ht_tasks_list )
754
+
749
755
res_shape = res .shape
750
- if not row_major :
751
- res = dpnp .reshape (
752
- res .ravel (), (batch_size , res_shape [2 ], res_shape [1 ])
753
- ).transpose (0 , 2 , 1 )
756
+ _ , res_is_c_contig , res_is_f_contig = _define_contig_flag (res )
757
+ if row_major :
758
+ if res_is_f_contig :
759
+ res = dpnp .reshape (
760
+ dpnp .ravel (res , order = "F" ),
761
+ (res_shape [1 ], res_shape [2 ], batch_size ),
762
+ ).transpose (2 , 0 , 1 )
763
+ else :
764
+ if res_is_c_contig :
765
+ res = dpnp .reshape (
766
+ dpnp .ravel (res , order = "C" ),
767
+ (batch_size , res_shape [2 ], res_shape [1 ]),
768
+ ).transpose (0 , 2 , 1 )
754
769
755
770
if res_shape != orig_shape :
756
771
res = res .reshape (orig_shape )
757
772
758
- res = dpnp .ascontiguousarray (res )
759
- return res
773
+ return dpnp .ascontiguousarray (res )
760
774
761
775
762
776
def _gemm_matmul (exec_q , x1 , x2 , res , dev_tasks_list ):
@@ -769,14 +783,16 @@ def _gemm_matmul(exec_q, x1, x2, res, dev_tasks_list):
769
783
)
770
784
ht_blas_ev .wait ()
771
785
772
- if not row_major :
773
- # TODO: investigate the possibility of defining result
774
- # array with "F" order for this case
775
- res = dpnp .ascontiguousarray (
776
- dpnp .reshape (res .ravel (), res .shape , order = "F" )
777
- )
786
+ if row_major :
787
+ if res .flags .f_contiguous is True :
788
+ # read data in "F" order and write it in "C" order
789
+ res = dpnp .reshape (dpnp .ravel (res , order = "F" ), res .shape , order = "C" )
790
+ else :
791
+ if res .flags .c_contiguous is True :
792
+ # read data in "C" order and write it in "F" order
793
+ res = dpnp .reshape (dpnp .ravel (res , order = "C" ), res .shape , order = "F" )
778
794
779
- return res
795
+ return dpnp . ascontiguousarray ( res )
780
796
781
797
782
798
def _greedy_path (input_sets , output_set , idx_dict , memory_limit ):
@@ -1746,6 +1762,13 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
1746
1762
)
1747
1763
1748
1764
res_usm_type , exec_q = get_usm_allocations ([a , b ])
1765
+ if (
1766
+ out is not None
1767
+ and dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None
1768
+ ):
1769
+ raise ExecutionPlacementError (
1770
+ "Input and output allocation queues are not compatible"
1771
+ )
1749
1772
1750
1773
# Determine the appropriate data types
1751
1774
dot_dtype , res_dtype = _compute_res_dtype (a , b , sycl_queue = exec_q )
@@ -1812,6 +1835,12 @@ def dpnp_einsum(
1812
1835
arrays .append (a )
1813
1836
1814
1837
res_usm_type , exec_q = get_usm_allocations (arrays )
1838
+ if out is not None :
1839
+ dpnp .check_supported_arrays_type (out )
1840
+ if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
1841
+ raise ExecutionPlacementError (
1842
+ "Input and output allocation queues are not compatible"
1843
+ )
1815
1844
result_dtype = dpnp .result_type (* arrays ) if dtype is None else dtype
1816
1845
for id , a in enumerate (operands ):
1817
1846
if dpnp .isscalar (a ):
@@ -2056,10 +2085,17 @@ def dpnp_matmul(
2056
2085
2057
2086
"""
2058
2087
2059
- x1_ndim = x1 .ndim
2060
- x2_ndim = x2 .ndim
2088
+ dpnp .check_supported_arrays_type (x1 , x2 )
2061
2089
res_usm_type , exec_q = get_usm_allocations ([x1 , x2 ])
2090
+ if out is not None :
2091
+ dpnp .check_supported_arrays_type (out )
2092
+ if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
2093
+ raise ExecutionPlacementError (
2094
+ "Input and output allocation queues are not compatible"
2095
+ )
2062
2096
2097
+ x1_ndim = x1 .ndim
2098
+ x2_ndim = x2 .ndim
2063
2099
if axes is not None :
2064
2100
axes = _validate_axes (x1 , x2 , axes )
2065
2101
@@ -2072,7 +2108,6 @@ def dpnp_matmul(
2072
2108
x2 = dpnp .moveaxis (x2 , axes_x2 , (- 2 , - 1 )) if x2_ndim != 1 else x2
2073
2109
out_orig = out
2074
2110
if out is not None :
2075
- dpnp .check_supported_arrays_type (out )
2076
2111
# out that is passed to the backend should have the correct shape
2077
2112
if len (axes_res ) == 2 :
2078
2113
out = dpnp .moveaxis (out , axes_res , (- 2 , - 1 ))
@@ -2161,8 +2196,18 @@ def dpnp_matmul(
2161
2196
res = dpnp_dot (x1 , x2 , out = out )
2162
2197
res_shape = res .shape
2163
2198
else :
2199
+ x1_contig_flag , _ , x1_f = _define_contig_flag (x1 )
2200
+ x2_contig_flag , _ , x2_f = _define_contig_flag (x2 )
2201
+ res_order = "F" if (x1_f and x2_f and call_flag == "gemm" ) else "C"
2164
2202
res = _create_result_array (
2165
- x1 , x2 , out , res_shape , compute_dtype , res_usm_type , exec_q
2203
+ x1 ,
2204
+ x2 ,
2205
+ out ,
2206
+ res_shape ,
2207
+ compute_dtype ,
2208
+ res_usm_type ,
2209
+ exec_q ,
2210
+ res_order ,
2166
2211
)
2167
2212
2168
2213
# calculate result
@@ -2175,21 +2220,21 @@ def dpnp_matmul(
2175
2220
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
2176
2221
dep_events_list = []
2177
2222
host_tasks_list = []
2178
- contig_flag = _define_contig_flag (x1 )
2179
2223
x1 = _copy_array (
2180
2224
x1 ,
2181
2225
dep_events_list ,
2182
2226
host_tasks_list ,
2183
- copy_flag = not contig_flag ,
2227
+ copy_flag = not x1_contig_flag ,
2184
2228
dtype = compute_dtype ,
2229
+ order = res_order ,
2185
2230
)
2186
- contig_flag = _define_contig_flag (x2 )
2187
2231
x2 = _copy_array (
2188
2232
x2 ,
2189
2233
dep_events_list ,
2190
2234
host_tasks_list ,
2191
- copy_flag = not contig_flag ,
2235
+ copy_flag = not x2_contig_flag ,
2192
2236
dtype = compute_dtype ,
2237
+ order = res_order ,
2193
2238
)
2194
2239
2195
2240
if call_flag == "gemv" :
0 commit comments