Skip to content

Commit 1a3866e

Browse files
authored
update dpnp.vdot implementation (#1692)
* update dpnp_vdot * address comments * address more comments
1 parent 666486f commit 1a3866e

13 files changed

+537
-43
lines changed

dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(python_module_name _blas_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/dot.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/dotc.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/dotu.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ namespace py = pybind11;
4040
void init_dispatch_tables(void)
4141
{
4242
blas_ext::init_dot_dispatch_table();
43+
blas_ext::init_dotc_dispatch_table();
4344
blas_ext::init_dotu_dispatch_table();
4445
blas_ext::init_gemm_batch_dispatch_table();
4546
blas_ext::init_gemm_dispatch_table();
@@ -57,6 +58,15 @@ PYBIND11_MODULE(_blas_impl, m)
5758
py::arg("result"), py::arg("depends") = py::list());
5859
}
5960

61+
{
    // Register the Python-visible `_dotc` entry point.
    // `dotc` is provided by the BLAS part of oneMKL (the implementation
    // calls into oneapi::mkl::blas), so the docstring names BLAS rather
    // than LAPACK.
    m.def("_dotc", &blas_ext::dotc,
          "Call `dotc` from OneMKL BLAS library to return "
          "the dot product of two complex vectors, "
          "conjugating the first vector.",
          py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
          py::arg("result"), py::arg("depends") = py::list());
}
69+
6070
{
6171
m.def("_dotu", &blas_ext::dotu,
6272
"Call `dotu` from OneMKL LAPACK library to return "

dpnp/backend/extensions/blas/dot.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ extern std::pair<sycl::event, sycl::event>
4545
dpctl::tensor::usm_ndarray result,
4646
const std::vector<sycl::event> &depends);
4747

48+
// Declaration of the `dotc` entry point implemented in dotc.cpp.
// Takes the execution queue, the two input vectors, the output array and
// the events the computation depends on; returns a pair of SYCL events.
// NOTE(review): the meaning of each event in the pair is defined by the
// implementation — confirm against dotc.cpp before relying on the order.
extern std::pair<sycl::event, sycl::event>
49+
dotc(sycl::queue &exec_q,
50+
dpctl::tensor::usm_ndarray vectorA,
51+
dpctl::tensor::usm_ndarray vectorB,
52+
dpctl::tensor::usm_ndarray result,
53+
const std::vector<sycl::event> &depends);
54+
4855
extern std::pair<sycl::event, sycl::event>
4956
dotu(sycl::queue &exec_q,
5057
dpctl::tensor::usm_ndarray vectorA,
@@ -53,6 +60,7 @@ extern std::pair<sycl::event, sycl::event>
5360
const std::vector<sycl::event> &depends);
5461

5562
// One-time initializers for the per-routine dispatch tables, one for each
// of `dot`, `dotc` and `dotu`. They take no arguments and return nothing.
// NOTE(review): presumably each must run before the matching routine is
// first called — confirm against the module initialization code.
extern void init_dot_dispatch_table(void);
63+
extern void init_dotc_dispatch_table(void);
5664
extern void init_dotu_dispatch_table(void);
5765
} // namespace blas
5866
} // namespace ext

dpnp/backend/extensions/blas/dotc.cpp

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "dot.hpp"
33+
#include "types_matrix.hpp"
34+
35+
#include "dpnp_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace blas
44+
{
45+
namespace mkl_blas = oneapi::mkl::blas;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
49+
typedef sycl::event (*dotc_impl_fn_ptr_t)(sycl::queue &,
50+
const std::int64_t,
51+
char *,
52+
const std::int64_t,
53+
char *,
54+
const std::int64_t,
55+
char *,
56+
const std::vector<sycl::event> &);
57+
58+
static dotc_impl_fn_ptr_t dotc_dispatch_table[dpctl_td_ns::num_types]
59+
[dpctl_td_ns::num_types];
60+
61+
template <typename Tab, typename Tc>
62+
static sycl::event dotc_impl(sycl::queue &exec_q,
63+
const std::int64_t n,
64+
char *vectorA,
65+
const std::int64_t stride_a,
66+
char *vectorB,
67+
const std::int64_t stride_b,
68+
char *result,
69+
const std::vector<sycl::event> &depends)
70+
{
71+
type_utils::validate_type_for_device<Tab>(exec_q);
72+
type_utils::validate_type_for_device<Tc>(exec_q);
73+
74+
Tab *a = reinterpret_cast<Tab *>(vectorA);
75+
Tab *b = reinterpret_cast<Tab *>(vectorB);
76+
Tc *res = reinterpret_cast<Tc *>(result);
77+
78+
std::stringstream error_msg;
79+
bool is_exception_caught = false;
80+
81+
sycl::event dotc_event;
82+
try {
83+
dotc_event = mkl_blas::row_major::dotc(exec_q,
84+
n, // size of the input vectors
85+
a, // Pointer to vector a.
86+
stride_a, // Stride of vector a.
87+
b, // Pointer to vector b.
88+
stride_b, // Stride of vector b.
89+
res, // Pointer to result.
90+
depends);
91+
} catch (oneapi::mkl::exception const &e) {
92+
error_msg
93+
<< "Unexpected MKL exception caught during dotc() call:\nreason: "
94+
<< e.what();
95+
is_exception_caught = true;
96+
} catch (sycl::exception const &e) {
97+
error_msg << "Unexpected SYCL exception caught during dotc() call:\n"
98+
<< e.what();
99+
is_exception_caught = true;
100+
}
101+
102+
if (is_exception_caught) // an unexpected error occurs
103+
{
104+
throw std::runtime_error(error_msg.str());
105+
}
106+
107+
return dotc_event;
108+
}
109+
110+
std::pair<sycl::event, sycl::event>
111+
dotc(sycl::queue &exec_q,
112+
dpctl::tensor::usm_ndarray vectorA,
113+
dpctl::tensor::usm_ndarray vectorB,
114+
dpctl::tensor::usm_ndarray result,
115+
const std::vector<sycl::event> &depends)
116+
{
117+
const int vectorA_nd = vectorA.get_ndim();
118+
const int vectorB_nd = vectorB.get_ndim();
119+
const int result_nd = result.get_ndim();
120+
121+
if ((vectorA_nd != 1)) {
122+
throw py::value_error(
123+
"The first input array has ndim=" + std::to_string(vectorA_nd) +
124+
", but a 1-dimensional array is expected.");
125+
}
126+
127+
if ((vectorB_nd != 1)) {
128+
throw py::value_error(
129+
"The second input array has ndim=" + std::to_string(vectorB_nd) +
130+
", but a 1-dimensional array is expected.");
131+
}
132+
133+
if ((result_nd != 0)) {
134+
throw py::value_error(
135+
"The output array has ndim=" + std::to_string(result_nd) +
136+
", but a 0-dimensional array is expected.");
137+
}
138+
139+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
140+
if (overlap(vectorA, result)) {
141+
throw py::value_error(
142+
"The first input array and output array are overlapping "
143+
"segments of memory");
144+
}
145+
if (overlap(vectorB, result)) {
146+
throw py::value_error(
147+
"The second input array and output array are overlapping "
148+
"segments of memory");
149+
}
150+
151+
// check compatibility of execution queue and allocation queue
152+
if (!dpctl::utils::queues_are_compatible(
153+
exec_q,
154+
{vectorA.get_queue(), vectorB.get_queue(), result.get_queue()}))
155+
{
156+
throw py::value_error(
157+
"USM allocations are not compatible with the execution queue.");
158+
}
159+
160+
py::ssize_t a_size = vectorA.get_size();
161+
py::ssize_t b_size = vectorB.get_size();
162+
if (a_size != b_size) {
163+
throw py::value_error("The size of the first input array must be "
164+
"equal to the size of the second input array.");
165+
}
166+
167+
std::vector<py::ssize_t> a_stride = vectorA.get_strides_vector();
168+
std::vector<py::ssize_t> b_stride = vectorB.get_strides_vector();
169+
170+
const std::int64_t n = a_size;
171+
const std::int64_t str_a = a_stride[0];
172+
const std::int64_t str_b = b_stride[0];
173+
174+
int vectorA_typenum = vectorA.get_typenum();
175+
int vectorB_typenum = vectorB.get_typenum();
176+
int result_typenum = result.get_typenum();
177+
178+
if (vectorA_typenum != vectorB_typenum) {
179+
throw py::value_error(
180+
"Input arrays must be of must be of the same type.");
181+
}
182+
183+
auto array_types = dpctl_td_ns::usm_ndarray_types();
184+
int vectorAB_type_id = array_types.typenum_to_lookup_id(vectorA_typenum);
185+
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
186+
187+
dotc_impl_fn_ptr_t dotc_fn =
188+
dotc_dispatch_table[vectorAB_type_id][result_type_id];
189+
if (dotc_fn == nullptr) {
190+
throw py::value_error(
191+
"Types of input vectors and result array are mismatched.");
192+
}
193+
194+
char *a_typeless_ptr = vectorA.get_data();
195+
char *b_typeless_ptr = vectorB.get_data();
196+
char *r_typeless_ptr = result.get_data();
197+
198+
const int a_elemsize = vectorA.get_elemsize();
199+
const int b_elemsize = vectorB.get_elemsize();
200+
if (str_a < 0) {
201+
a_typeless_ptr -= (n - 1) * std::abs(str_a) * a_elemsize;
202+
}
203+
if (str_b < 0) {
204+
b_typeless_ptr -= (n - 1) * std::abs(str_b) * b_elemsize;
205+
}
206+
207+
sycl::event dotc_ev =
208+
dotc_fn(exec_q, n, a_typeless_ptr, str_a, b_typeless_ptr, str_b,
209+
r_typeless_ptr, depends);
210+
211+
sycl::event args_ev = dpctl::utils::keep_args_alive(
212+
exec_q, {vectorA, vectorB, result}, {dotc_ev});
213+
214+
return std::make_pair(args_ev, dotc_ev);
215+
}
216+
217+
template <typename fnT, typename Tab, typename Tc>
218+
struct DotcContigFactory
219+
{
220+
fnT get()
221+
{
222+
if constexpr (types::DotcTypePairSupportFactory<Tab, Tc>::is_defined) {
223+
return dotc_impl<Tab, Tc>;
224+
}
225+
else {
226+
return nullptr;
227+
}
228+
}
229+
};
230+
231+
void init_dotc_dispatch_table(void)
232+
{
233+
dpctl_td_ns::DispatchTableBuilder<dotc_impl_fn_ptr_t, DotcContigFactory,
234+
dpctl_td_ns::num_types>
235+
contig;
236+
contig.populate_dispatch_table(dotc_dispatch_table);
237+
}
238+
} // namespace blas
239+
} // namespace ext
240+
} // namespace backend
241+
} // namespace dpnp

dpnp/backend/extensions/blas/types_matrix.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,30 @@ struct DotTypePairSupportFactory
6262
dpctl_td_ns::NotDefinedEntry>::is_defined;
6363
};
6464

65+
/**
66+
* @brief A factory to define pairs of supported types for which
67+
* MKL BLAS library provides support in oneapi::mkl::blas::dotc<Tab, Tc>
68+
* function.
69+
*
70+
* @tparam Tab Type of arrays containing input vectors A and B.
71+
* @tparam Tc Type of array containing output.
72+
*/
73+
template <typename Tab, typename Tc>
74+
struct DotcTypePairSupportFactory
75+
{
76+
static constexpr bool is_defined = std::disjunction<
77+
dpctl_td_ns::TypePairDefinedEntry<Tab,
78+
std::complex<float>,
79+
Tc,
80+
std::complex<float>>,
81+
dpctl_td_ns::TypePairDefinedEntry<Tab,
82+
std::complex<double>,
83+
Tc,
84+
std::complex<double>>,
85+
// fall-through
86+
dpctl_td_ns::NotDefinedEntry>::is_defined;
87+
};
88+
6589
/**
6690
* @brief A factory to define pairs of supported types for which
6791
* MKL BLAS library provides support in oneapi::mkl::blas::dotu<Tab, Tc>

dpnp/dpnp_iface.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True):
205205
return dpnp_array._create_from_usm_ndarray(array_obj)
206206

207207

208-
def check_supported_arrays_type(*arrays, scalar_type=False):
208+
def check_supported_arrays_type(*arrays, scalar_type=False, all_scalars=False):
209209
"""
210210
Return ``True`` if each array has either type of scalar,
211211
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
@@ -216,7 +216,9 @@ def check_supported_arrays_type(*arrays, scalar_type=False):
216216
arrays : {dpnp_array, usm_ndarray}
217217
Input arrays to check for supported types.
218218
scalar_type : {bool}, optional
219-
A scalar type is also considered as supported if flag is True.
219+
A scalar type is also considered as supported if flag is ``True``.
220+
all_scalars : {bool}, optional
221+
All the input arrays can be scalar if flag is ``True``.
220222
221223
Returns
222224
-------
@@ -231,13 +233,22 @@ def check_supported_arrays_type(*arrays, scalar_type=False):
231233
232234
"""
233235

236+
any_is_array = False
234237
for a in arrays:
235-
if scalar_type and dpnp.isscalar(a) or is_supported_array_type(a):
238+
if is_supported_array_type(a):
239+
any_is_array = True
240+
continue
241+
elif scalar_type and dpnp.isscalar(a):
236242
continue
237243

238244
raise TypeError(
239245
"An array must be any of supported type, but got {}".format(type(a))
240246
)
247+
248+
if len(arrays) > 1 and not (all_scalars or any_is_array):
249+
raise TypeError(
250+
"At least one input must be of supported array type, but got all scalars."
251+
)
241252
return True
242253

243254

0 commit comments

Comments
 (0)