Skip to content

Commit 6592abc

Browse files
committed
Merge branch 'master' into extended_types_support
2 parents 98983bd + db97d59 commit 6592abc

File tree

8 files changed

+297
-261
lines changed

8 files changed

+297
-261
lines changed

.github/workflows/array-api-skips.txt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,10 @@ array_api_tests/test_linalg.py::test_svd
2323
array_api_tests/test_linalg.py::test_qr
2424
array_api_tests/test_operators_and_elementwise_functions.py::test_clip
2525

26-
# unexpected result is returned
26+
# unexpected result is returned - unmute when dpctl-1986 is resolved
2727
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
2828
array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
2929

3030
# missing 'correction' keyword argument
3131
array_api_tests/test_signatures.py::test_func_signature[std]
3232
array_api_tests/test_signatures.py::test_func_signature[var]
33-
34-
# arrays have different values
35-
array_api_tests/test_linalg.py::test_linalg_tensordot

.github/workflows/check-mkl-interfaces.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ jobs:
216216
id: run_tests
217217
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
218218
with:
219-
timeout_minutes: 12
219+
timeout_minutes: 15
220220
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
221221
retry_on: any
222222
command: |

.github/workflows/conda-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ jobs:
218218
id: run_tests_linux
219219
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
220220
with:
221-
timeout_minutes: 12
221+
timeout_minutes: 15
222222
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
223223
retry_on: any
224224
command: |
@@ -460,7 +460,7 @@ jobs:
460460
id: run_tests_win
461461
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
462462
with:
463-
timeout_minutes: 15
463+
timeout_minutes: 17
464464
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
465465
retry_on: any
466466
command: |

.github/workflows/cron-run-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ jobs:
126126
id: run_tests_linux
127127
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
128128
with:
129-
timeout_minutes: 12
129+
timeout_minutes: 15
130130
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
131131
retry_on: any
132132
command: |
@@ -143,7 +143,7 @@ jobs:
143143
id: run_tests_win
144144
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
145145
with:
146-
timeout_minutes: 15
146+
timeout_minutes: 17
147147
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
148148
retry_on: any
149149
command: |

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,18 @@ PYBIND11_MODULE(_blas_impl, m)
142142
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
143143
py::arg("vectorY"), py::arg("transpose"),
144144
py::arg("depends") = py::list());
145+
}
146+
147+
{
145148
m.def(
146-
"_row_major_is_available",
147-
[](void) {
148-
#if defined(USE_ONEMKL_CUBLAS)
149-
return false;
150-
#else
149+
"_using_onemkl_interfaces",
150+
[]() {
151+
#ifdef USE_ONEMKL_INTERFACES
151152
return true;
152-
#endif // USE_ONEMKL_CUBLAS
153+
#else
154+
return false;
155+
#endif
153156
},
154-
"Check if the onemkl::blas::row_major can be used.");
157+
"Check if the OneMKL interfaces are being used.");
155158
}
156159
}

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 121 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,23 @@
5050
]
5151

5252

53-
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
53+
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"):
5454
"""
55-
Determines the output array data type and an intermediate data type
56-
used in performing calculations related to a specific math function.
57-
If dtype is ``None``, the output array data type of the operation is
58-
determined based on the Promotion Type Rule and device capabilities.
59-
Otherwise, `dtype` is used as output array dtype, if input arrays
60-
can cast to it according to the casting rule determined. If casting
61-
cannot be done, a ``TypeError`` is raised.
62-
The intermediate data type is the data type used for performing the math
63-
function calculations. If output array dtype is a floating-point data type,
64-
it is also used for the intermediate data type. If output array dtype is an
65-
integral data type, the default floating point data type of the device where
66-
input arrays are allocated on are used for intermediate data type.
55+
Determines the output array data type.
56+
If `dtype` and `out` are ``None``, the output array data type of the
57+
operation is determined based on the Promotion Type Rule and device
58+
capabilities. If `out` is given, its data type is used as the output
59+
array dtype. Otherwise, `dtype` is used as the output array dtype.
60+
If input arrays cannot be cast to the determined output array dtype,
61+
a ``TypeError`` is raised.
6762
6863
Parameters
6964
----------
7065
arrays : {dpnp.ndarray, usm_ndarray}
7166
Input arrays.
7267
dtype : dtype
68+
If not ``None`` and `out` is not defined, data type of the output array.
69+
out : {dpnp.ndarray, usm_ndarray}
7370
If not ``None``, data type of the output array.
7471
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
7572
Controls what kind of data casting may occur.
@@ -78,17 +75,23 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
7875
7976
Returns
8077
-------
81-
compute_dtype, res_dtype :
82-
`compute_dtype` is the data type used in performing math function calculations.
83-
The input arrays of the math function are cast to `compute_dtype` and then
84-
the calculations are performed.
85-
`res_dtype` is the output data type. When the result is obtained, it is cast
86-
to `res_dtype`.
78+
res_dtype :
79+
`res_dtype` is the output data type. When the result is obtained,
80+
it is cast to `res_dtype`.
8781
8882
"""
8983

9084
res_dtype = dpnp.result_type(*arrays)
91-
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)
85+
86+
# If inputs are boolean and `out` is given and it is not boolean, the
87+
# calculation should be performed in boolean and at the end the result
88+
# is cast to out dtype. This differs from the general case where the inputs
89+
# are cast to out dtype and then calculation is performed. Even when inputs
90+
# are boolean and `dtype` is given, the casting is done first and then the
91+
# calculation is performed.
92+
if out is not None and res_dtype != dpnp.bool:
93+
# out dtype is prioritized over a given dtype
94+
dtype = out.dtype
9295

9396
if dtype is not None:
9497
if dpnp.can_cast(res_dtype, dtype, casting=casting):
@@ -98,11 +101,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
98101
f"Cannot cast from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
99102
)
100103

101-
compute_dtype = (
102-
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
103-
)
104-
105-
return compute_dtype, res_dtype
104+
return res_dtype
106105

107106

108107
def _copy_array(x, copy_flag=False, dtype=None, order="C"):
@@ -504,6 +503,23 @@ def _gemm_matmul(exec_q, x1, x2, res):
504503
return res
505504

506505

506+
def _gemm_special_case(x1, x2, res_dtype, call_flag):
507+
"""
508+
`gemm` and `gemm_batch` support these special cases of data types
509+
while `gemv` does not.
510+
511+
"""
512+
# TODO: replace with dpnp.int8 when it is added
513+
is_int8 = x1.dtype == numpy.int8 and x2.dtype == numpy.int8
514+
is_int32_or_f32 = res_dtype in [dpnp.int32, dpnp.float32]
515+
flag = is_int8 and is_int32_or_f32 and call_flag in ["gemm", "gemm_batch"]
516+
517+
# onemkl_interfaces does not support these data types
518+
onemkl_interfaces = bi._using_onemkl_interfaces()
519+
520+
return flag and not onemkl_interfaces
521+
522+
507523
def _shape_error(shape1, shape2, func, err_msg):
508524
"""Validate the shapes of input and output arrays."""
509525

@@ -749,17 +765,19 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
749765
_validate_out_array(out, exec_q)
750766

751767
# Determine the appropriate data types
752-
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
768+
res_dtype = _compute_res_dtype(
769+
a, b, out=out, casting=casting, sycl_queue=exec_q
770+
)
753771

754772
result = _create_result_array(
755-
a, b, out, (), dot_dtype, res_usm_type, exec_q
773+
a, b, out, (), res_dtype, res_usm_type, exec_q
756774
)
757775

758776
# input arrays should have the proper data type
759777
if dpnp.issubdtype(res_dtype, dpnp.inexact):
760778
# copying is needed if dtypes of input arrays are different
761-
a = _copy_array(a, dtype=dot_dtype)
762-
b = _copy_array(b, dtype=dot_dtype)
779+
a = _copy_array(a, dtype=res_dtype)
780+
b = _copy_array(b, dtype=res_dtype)
763781

764782
_manager = dpu.SequentialOrderManager[exec_q]
765783

@@ -777,14 +795,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
777795
)
778796
_manager.add_event_pair(ht_ev, dot_ev)
779797
else:
780-
# oneapi::mkl::blas::dot is slow for integer data type,
798+
# oneapi::mkl::blas::dot does not support integer dtypes,
781799
# so using dpctl.tensor.vecdot instead
782-
dpt_a = dpnp.get_usm_ndarray(a)
783-
dpt_b = dpnp.get_usm_ndarray(b)
784-
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b))
785-
786-
if dot_dtype != res_dtype:
787-
result = result.astype(res_dtype, copy=False)
800+
a_usm = dpnp.get_usm_ndarray(a)
801+
b_usm = dpnp.get_usm_ndarray(b)
802+
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(a_usm, b_usm))
788803

789804
return dpnp.get_result_array(result, out, casting=casting)
790805

@@ -902,8 +917,8 @@ def dpnp_multiplication(
902917
axes_res = normalize_axis_tuple(axes_res, len(result_shape), "axes")
903918

904919
# Determine the appropriate data types
905-
compute_dtype, res_dtype = _compute_res_dtype(
906-
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
920+
res_dtype = _compute_res_dtype(
921+
x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q
907922
)
908923

909924
call_flag = None
@@ -998,7 +1013,7 @@ def dpnp_multiplication(
9981013
x2,
9991014
out,
10001015
res_shape,
1001-
compute_dtype,
1016+
res_dtype,
10021017
res_usm_type,
10031018
exec_q,
10041019
res_order,
@@ -1010,64 +1025,82 @@ def dpnp_multiplication(
10101025
elif x1.size == 0 or x2.size == 0:
10111026
result.fill(0)
10121027
else:
1013-
# input arrays should have the proper data type and
1014-
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
1015-
x1 = _copy_array(
1016-
x1,
1017-
copy_flag=not x1_contig_flag,
1018-
dtype=compute_dtype,
1019-
order=res_order,
1020-
)
1021-
x2 = _copy_array(
1022-
x2,
1023-
copy_flag=not x2_contig_flag,
1024-
dtype=compute_dtype,
1025-
order=res_order,
1026-
)
1027-
1028-
if call_flag == "gemv":
1029-
if transpose:
1030-
a_usm = dpnp.get_usm_ndarray(x2)
1031-
x_usm = dpnp.get_usm_ndarray(x1)
1032-
else:
1033-
a_usm = dpnp.get_usm_ndarray(x1)
1034-
x_usm = dpnp.get_usm_ndarray(x2)
1035-
1036-
_manager = dpu.SequentialOrderManager[exec_q]
1037-
1038-
ht_ev, gemv_ev = bi._gemv(
1039-
exec_q,
1040-
a_usm,
1041-
x_usm,
1042-
dpnp.get_usm_ndarray(result),
1043-
transpose,
1044-
depends=_manager.submitted_events,
1028+
if _gemm_special_case(x1, x2, res_dtype, call_flag):
1029+
x1 = _copy_array(
1030+
x1, copy_flag=not x1_contig_flag, order=res_order
10451031
)
1046-
_manager.add_event_pair(ht_ev, gemv_ev)
1047-
elif call_flag == "gemm":
1048-
result = _gemm_matmul(
1049-
exec_q,
1050-
x1,
1051-
x2,
1052-
result,
1032+
x2 = _copy_array(
1033+
x2, copy_flag=not x2_contig_flag, order=res_order
10531034
)
1054-
else: # call_flag == "gemm_batch"
1055-
assert call_flag == "gemm_batch"
1056-
result = _gemm_batch_matmul(
1057-
exec_q,
1035+
if call_flag == "gemm":
1036+
result = _gemm_matmul(exec_q, x1, x2, result)
1037+
else:
1038+
assert call_flag == "gemm_batch"
1039+
result = _gemm_batch_matmul(exec_q, x1, x2, result)
1040+
elif dpnp.issubdtype(res_dtype, dpnp.inexact):
1041+
# copying is needed if dtypes of input arrays are different or
1042+
# their base (last 2-dimensions) is not c-contiguous or f-contiguous
1043+
x1 = _copy_array(
10581044
x1,
1045+
copy_flag=not x1_contig_flag,
1046+
dtype=res_dtype,
1047+
order=res_order,
1048+
)
1049+
x2 = _copy_array(
10591050
x2,
1060-
result,
1051+
copy_flag=not x2_contig_flag,
1052+
dtype=res_dtype,
1053+
order=res_order,
1054+
)
1055+
1056+
if call_flag == "gemv":
1057+
if transpose:
1058+
a_usm = dpnp.get_usm_ndarray(x2)
1059+
x_usm = dpnp.get_usm_ndarray(x1)
1060+
else:
1061+
a_usm = dpnp.get_usm_ndarray(x1)
1062+
x_usm = dpnp.get_usm_ndarray(x2)
1063+
1064+
_manager = dpu.SequentialOrderManager[exec_q]
1065+
1066+
ht_ev, gemv_ev = bi._gemv(
1067+
exec_q,
1068+
a_usm,
1069+
x_usm,
1070+
dpnp.get_usm_ndarray(result),
1071+
transpose,
1072+
depends=_manager.submitted_events,
1073+
)
1074+
_manager.add_event_pair(ht_ev, gemv_ev)
1075+
elif call_flag == "gemm":
1076+
result = _gemm_matmul(exec_q, x1, x2, result)
1077+
else:
1078+
assert call_flag == "gemm_batch"
1079+
result = _gemm_batch_matmul(exec_q, x1, x2, result)
1080+
else:
1081+
# oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
1082+
# except for special cases determined in `_gemm_special_case`,
1083+
# so use dpctl.tensor.matmul for unsupported cases
1084+
1085+
# `dpt.matmul` does not support `casting` kwarg.
1086+
# We may need to change input dtypes based on given `casting`.
1087+
# The possibility of casting is already validated in
1088+
# `_compute_res_dtype`.
1089+
x1 = _copy_array(x1, dtype=res_dtype, order=res_order)
1090+
x2 = _copy_array(x2, dtype=res_dtype, order=res_order)
1091+
1092+
x1_usm = dpnp.get_usm_ndarray(x1)
1093+
x2_usm = dpnp.get_usm_ndarray(x2)
1094+
out_usm = dpnp.get_usm_ndarray(result)
1095+
dpt.matmul(
1096+
x1_usm, x2_usm, out=out_usm, dtype=dtype, order=order
10611097
)
10621098

10631099
if NumPy_special_case:
10641100
result = dpnp.tile(result, out.shape)
10651101
elif res_shape != result_shape:
10661102
result = dpnp.reshape(result, result_shape)
10671103

1068-
if compute_dtype != res_dtype:
1069-
result = dpnp.astype(result, res_dtype, copy=False)
1070-
10711104
if out is None:
10721105
if axes is not None:
10731106
# Move the data back to the appropriate axes of the result array
@@ -1207,8 +1240,8 @@ def dpnp_vecdot(
12071240
)
12081241

12091242
# Determine the appropriate data types
1210-
_, res_dtype = _compute_res_dtype(
1211-
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
1243+
res_dtype = _compute_res_dtype(
1244+
x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q
12121245
)
12131246

12141247
_, x1_is_1D, _ = _define_dim_flags(x1, axis=-1)

0 commit comments

Comments
 (0)