|
34 | 34 | __all__ = ["dpnp_dot", "dpnp_matmul"]
|
35 | 35 |
|
36 | 36 |
|
37 |
| -def _op_res_dtype(*arrays, dtype, casting, sycl_queue): |
| 37 | +def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): |
38 | 38 | """
|
39 |
| - _op_res_dtype(*arrays, dtype, casting, sycl_queue) |
40 |
| -
|
41 |
| - Determines the output array data type and an intermediate data type |
42 |
| - used in performing calculations related to a specific math function. |
43 |
| - If dtype is ``None``, the output array data type of the operation is |
44 |
| - determined based on the Promotion Type Rule and device capabilities. |
45 |
| - Otherwise, `dtype` is used as output array dtype, if input arrays |
46 |
| - can cast to it according to the casting rule determined. If casting |
47 |
| - cannot be done, a ``TypeError`` is raised. |
48 |
| - The intermediate data type is the data type used for performing the math |
49 |
| - function calculations. If output array dtype is a floating-point data type, |
50 |
| - it is also used for the intermediate data type. If output array dtype is an |
51 |
| - integral data type, the default floating point data type of the device where |
52 |
| - input arrays are allocated is used for the intermediate data type. |
53 |
| -
|
54 |
| - Parameters |
55 |
| - ---------- |
56 |
| - arrays : {dpnp.ndarray, usm_ndarray} |
57 |
| - Input arrays. |
58 |
| - dtype : dtype |
59 |
| - If not ``None``, data type of the output array. |
60 |
| - casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional |
61 |
| - Controls what kind of data casting may occur. |
62 |
| - sycl_queue : {SyclQueue} |
63 |
| - A SYCL queue to use for determining default floating point data type. |
| 39 | + Creating a copy of input array if needed. |
64 | 40 |
|
65 |
| - Returns |
66 |
| - ------- |
67 |
| - op_dtype, res_dtype : |
68 |
| - `op_dtype` is the data type used in performing math function calculations. |
69 |
| - The input arrays of the math function are cast to `op_dtype` and then |
70 |
| - the calculations are performed. |
71 |
| - `res_dtype` is the output data type. When the result is obtained, it is cast |
72 |
| - to `res_dtype`. |
| 41 | + If `contig_copy` is ``True``, a C-contiguous copy of input array is returned. |
| 42 | + In this case, the copy array has the input array data type unless `dtype` is |
| 43 | + specified. |
| 44 | + If `contig_copy` is ``False`` and input array data type is different than `dtype`, |
| 45 | + a C-contiguous copy of input array with specified `dtype` is returned. |
73 | 46 |
|
74 | 47 | """
|
75 | 48 |
|
76 |
| - res_dtype = dpnp.result_type(*arrays) |
77 |
| - default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) |
78 |
| - |
79 |
| - if dtype is not None: |
80 |
| - if dpnp.can_cast(res_dtype, dtype, casting=casting): |
81 |
| - res_dtype = dtype |
82 |
| - else: |
83 |
| - raise TypeError( |
84 |
| - f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" |
85 |
| - ) |
86 |
| - |
87 |
| - op_dtype = ( |
88 |
| - res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype |
89 |
| - ) |
| 49 | + if contig_copy: |
| 50 | + copy = contig_copy |
| 51 | + else: |
| 52 | + copy = x.dtype != dtype if dtype is not None else False |
90 | 53 |
|
91 |
| - return op_dtype, res_dtype |
| 54 | + if copy: |
| 55 | + x_copy = dpnp.empty_like(x, dtype=dtype, order="C") |
| 56 | + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( |
| 57 | + src=dpnp.get_usm_ndarray(x), |
| 58 | + dst=x_copy.get_array(), |
| 59 | + sycl_queue=x.sycl_queue, |
| 60 | + ) |
| 61 | + dep_events.append(copy_ev) |
| 62 | + host_events.append(ht_copy_ev) |
| 63 | + return x_copy |
| 64 | + return x |
92 | 65 |
|
93 | 66 |
|
94 | 67 | def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
|
@@ -153,34 +126,61 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
|
153 | 126 | return ht_blas_ev, ht_tasks_list, res
|
154 | 127 |
|
155 | 128 |
|
156 |
| -def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): |
| 129 | +def _op_res_dtype(*arrays, dtype, casting, sycl_queue): |
157 | 130 | """
|
158 |
| - Creating a copy of input array if needed. |
| 131 | + _op_res_dtype(*arrays, dtype, casting, sycl_queue) |
159 | 132 |
|
160 |
| - If `contig_copy` is ``True``, a C-contiguous copy of input array is returned. |
161 |
| - In this case, the copy array has the input array data type unless `dtype` is |
162 |
| - specified. |
163 |
| - If `contig_copy` is ``False`` and input array data type is different than `dtype`, |
164 |
| - a C-contiguous copy of input array with specified `dtype` is returned. |
| 133 | + Determines the output array data type and an intermediate data type |
| 134 | + used in performing calculations related to a specific math function. |
| 135 | + If dtype is ``None``, the output array data type of the operation is |
| 136 | + determined based on the Promotion Type Rule and device capabilities. |
| 137 | + Otherwise, `dtype` is used as output array dtype, if input arrays |
| 138 | + can cast to it according to the casting rule determined. If casting |
| 139 | + cannot be done, a ``TypeError`` is raised. |
| 140 | + The intermediate data type is the data type used for performing the math |
| 141 | + function calculations. If output array dtype is a floating-point data type, |
| 142 | + it is also used for the intermediate data type. If output array dtype is an |
| 143 | + integral data type, the default floating point data type of the device where |
| 144 | + input arrays are allocated is used for the intermediate data type. |
| 145 | +
|
| 146 | + Parameters |
| 147 | + ---------- |
| 148 | + arrays : {dpnp.ndarray, usm_ndarray} |
| 149 | + Input arrays. |
| 150 | + dtype : dtype |
| 151 | + If not ``None``, data type of the output array. |
| 152 | + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional |
| 153 | + Controls what kind of data casting may occur. |
| 154 | + sycl_queue : {SyclQueue} |
| 155 | + A SYCL queue to use for determining default floating point data type. |
| 156 | +
|
| 157 | + Returns |
| 158 | + ------- |
| 159 | + op_dtype, res_dtype : |
| 160 | + `op_dtype` is the data type used in performing math function calculations. |
| 161 | + The input arrays of the math function are cast to `op_dtype` and then |
| 162 | + the calculations are performed. |
| 163 | + `res_dtype` is the output data type. When the result is obtained, it is cast |
| 164 | + to `res_dtype`. |
165 | 165 |
|
166 | 166 | """
|
167 | 167 |
|
168 |
| - if contig_copy: |
169 |
| - copy = contig_copy |
170 |
| - else: |
171 |
| - copy = x.dtype != dtype if dtype is not None else False |
| 168 | + res_dtype = dpnp.result_type(*arrays) |
| 169 | + default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) |
172 | 170 |
|
173 |
| - if copy: |
174 |
| - x_copy = dpnp.empty_like(x, dtype=dtype, order="C") |
175 |
| - ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( |
176 |
| - src=dpnp.get_usm_ndarray(x), |
177 |
| - dst=x_copy.get_array(), |
178 |
| - sycl_queue=x.sycl_queue, |
179 |
| - ) |
180 |
| - dep_events.append(copy_ev) |
181 |
| - host_events.append(ht_copy_ev) |
182 |
| - return x_copy |
183 |
| - return x |
| 171 | + if dtype is not None: |
| 172 | + if dpnp.can_cast(res_dtype, dtype, casting=casting): |
| 173 | + res_dtype = dtype |
| 174 | + else: |
| 175 | + raise TypeError( |
| 176 | + f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" |
| 177 | + ) |
| 178 | + |
| 179 | + op_dtype = ( |
| 180 | + res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype |
| 181 | + ) |
| 182 | + |
| 183 | + return op_dtype, res_dtype |
184 | 184 |
|
185 | 185 |
|
186 | 186 | def dpnp_dot(
|
@@ -394,6 +394,11 @@ def dpnp_matmul(
|
394 | 394 | dtype=gemm_dtype,
|
395 | 395 | )
|
396 | 396 |
|
| 397 | + # TODO: investigate usage of gemv (gemv_batch) function |
| 398 | + # from BLAS when one of the inputs is a vector to |
| 399 | + # gain performance. |
| 400 | + # TODO: investigate usage of syrk function from BLAS in |
| 401 | + # case of a.T @ a and a @ a.T to gain performance. |
397 | 402 | if x1_is_2D and x2_is_2D:
|
398 | 403 | ht_blas_ev, _ = bi._gemm(
|
399 | 404 | exec_q,
|
|
0 commit comments