Skip to content

Commit ac1fca7

Browse files
vtavanaantonwolfy
andauthored
update dpnp.dot implementation (#1669)
* dot_func * using mkl::dotu instead mkl::dotc for complex * fix a test * fix negative strides * add a temporary workaround * address comments * add a TODO comment * call dpt.vecdot for integer data types * update doc string * pass argument by reference * update doc to add boolean dtype --------- Co-authored-by: Anton <[email protected]>
1 parent 554bcdd commit ac1fca7

25 files changed

+1329
-425
lines changed

dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# *****************************************************************************
2-
# Copyright (c) 2016-2023, Intel Corporation
2+
# Copyright (c) 2024, Intel Corporation
33
# All rights reserved.
44
#
55
# Redistribution and use in source and binary forms, with or without
@@ -27,6 +27,8 @@
2727
set(python_module_name _blas_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/dot.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/dotu.cpp
3032
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
3133
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
3234
)

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2023, Intel Corporation
2+
// Copyright (c) 2024, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without
@@ -30,6 +30,7 @@
3030
#include <pybind11/pybind11.h>
3131
#include <pybind11/stl.h>
3232

33+
#include "dot.hpp"
3334
#include "gemm.hpp"
3435

3536
namespace blas_ext = dpnp::backend::ext::blas;
@@ -38,6 +39,8 @@ namespace py = pybind11;
3839
// populate dispatch tables
3940
void init_dispatch_tables(void)
4041
{
42+
blas_ext::init_dot_dispatch_table();
43+
blas_ext::init_dotu_dispatch_table();
4144
blas_ext::init_gemm_batch_dispatch_table();
4245
blas_ext::init_gemm_dispatch_table();
4346
}
@@ -46,6 +49,22 @@ PYBIND11_MODULE(_blas_impl, m)
4649
{
4750
init_dispatch_tables();
4851

52+
{
53+
m.def("_dot", &blas_ext::dot,
54+
"Call `dot` from OneMKL LAPACK library to return "
55+
"the dot product of two real-valued vectors.",
56+
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
57+
py::arg("result"), py::arg("depends") = py::list());
58+
}
59+
60+
{
61+
m.def("_dotu", &blas_ext::dotu,
62+
"Call `dotu` from OneMKL LAPACK library to return "
63+
"the dot product of two complex vectors.",
64+
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
65+
py::arg("result"), py::arg("depends") = py::list());
66+
}
67+
4968
{
5069
m.def("_gemm", &blas_ext::gemm,
5170
"Call `gemm` from OneMKL LAPACK library to return "

dpnp/backend/extensions/blas/dot.cpp

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "dot.hpp"
33+
#include "types_matrix.hpp"
34+
35+
#include "dpnp_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace blas
44+
{
45+
namespace mkl_blas = oneapi::mkl::blas;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
49+
typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
50+
const std::int64_t,
51+
char *,
52+
const std::int64_t,
53+
char *,
54+
const std::int64_t,
55+
char *,
56+
const std::vector<sycl::event> &);
57+
58+
static dot_impl_fn_ptr_t dot_dispatch_table[dpctl_td_ns::num_types]
59+
[dpctl_td_ns::num_types];
60+
61+
template <typename Tab, typename Tc>
62+
static sycl::event dot_impl(sycl::queue &exec_q,
63+
const std::int64_t n,
64+
char *vectorA,
65+
const std::int64_t stride_a,
66+
char *vectorB,
67+
const std::int64_t stride_b,
68+
char *result,
69+
const std::vector<sycl::event> &depends)
70+
{
71+
type_utils::validate_type_for_device<Tab>(exec_q);
72+
type_utils::validate_type_for_device<Tc>(exec_q);
73+
74+
Tab *a = reinterpret_cast<Tab *>(vectorA);
75+
Tab *b = reinterpret_cast<Tab *>(vectorB);
76+
Tc *res = reinterpret_cast<Tc *>(result);
77+
78+
std::stringstream error_msg;
79+
bool is_exception_caught = false;
80+
81+
sycl::event dot_event;
82+
try {
83+
dot_event = mkl_blas::row_major::dot(exec_q,
84+
n, // size of the input vectors
85+
a, // Pointer to vector a.
86+
stride_a, // Stride of vector a.
87+
b, // Pointer to vector b.
88+
stride_b, // Stride of vector b.
89+
res, // Pointer to result.
90+
depends);
91+
} catch (oneapi::mkl::exception const &e) {
92+
error_msg
93+
<< "Unexpected MKL exception caught during dot() call:\nreason: "
94+
<< e.what();
95+
is_exception_caught = true;
96+
} catch (sycl::exception const &e) {
97+
error_msg << "Unexpected SYCL exception caught during dot() call:\n"
98+
<< e.what();
99+
is_exception_caught = true;
100+
}
101+
102+
if (is_exception_caught) // an unexpected error occurs
103+
{
104+
throw std::runtime_error(error_msg.str());
105+
}
106+
107+
return dot_event;
108+
}
109+
110+
std::pair<sycl::event, sycl::event> dot(sycl::queue &exec_q,
111+
dpctl::tensor::usm_ndarray vectorA,
112+
dpctl::tensor::usm_ndarray vectorB,
113+
dpctl::tensor::usm_ndarray result,
114+
const std::vector<sycl::event> &depends)
115+
{
116+
const int vectorA_nd = vectorA.get_ndim();
117+
const int vectorB_nd = vectorB.get_ndim();
118+
const int result_nd = result.get_ndim();
119+
120+
if ((vectorA_nd != 1)) {
121+
throw py::value_error(
122+
"The first input array has ndim=" + std::to_string(vectorA_nd) +
123+
", but a 1-dimensional array is expected.");
124+
}
125+
126+
if ((vectorB_nd != 1)) {
127+
throw py::value_error(
128+
"The second input array has ndim=" + std::to_string(vectorB_nd) +
129+
", but a 1-dimensional array is expected.");
130+
}
131+
132+
if ((result_nd != 0)) {
133+
throw py::value_error(
134+
"The output array has ndim=" + std::to_string(result_nd) +
135+
", but a 0-dimensional array is expected.");
136+
}
137+
138+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
139+
if (overlap(vectorA, result)) {
140+
throw py::value_error(
141+
"The first input array and output array are overlapping "
142+
"segments of memory");
143+
}
144+
if (overlap(vectorB, result)) {
145+
throw py::value_error(
146+
"The second input array and output array are overlapping "
147+
"segments of memory");
148+
}
149+
150+
// check compatibility of execution queue and allocation queue
151+
if (!dpctl::utils::queues_are_compatible(
152+
exec_q,
153+
{vectorA.get_queue(), vectorB.get_queue(), result.get_queue()}))
154+
{
155+
throw py::value_error(
156+
"USM allocations are not compatible with the execution queue.");
157+
}
158+
159+
py::ssize_t a_size = vectorA.get_size();
160+
py::ssize_t b_size = vectorB.get_size();
161+
if (a_size != b_size) {
162+
throw py::value_error("The size of the first input array must be "
163+
"equal to the size of the second input array.");
164+
}
165+
166+
std::vector<py::ssize_t> a_stride = vectorA.get_strides_vector();
167+
std::vector<py::ssize_t> b_stride = vectorB.get_strides_vector();
168+
169+
const std::int64_t n = a_size;
170+
const std::int64_t str_a = a_stride[0];
171+
const std::int64_t str_b = b_stride[0];
172+
173+
int vectorA_typenum = vectorA.get_typenum();
174+
int vectorB_typenum = vectorB.get_typenum();
175+
int result_typenum = result.get_typenum();
176+
177+
if (vectorA_typenum != vectorB_typenum) {
178+
throw py::value_error("vectorA and vectorB must be of the same type.");
179+
}
180+
181+
auto array_types = dpctl_td_ns::usm_ndarray_types();
182+
int vectorAB_type_id = array_types.typenum_to_lookup_id(vectorA_typenum);
183+
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
184+
185+
dot_impl_fn_ptr_t dot_fn =
186+
dot_dispatch_table[vectorAB_type_id][result_type_id];
187+
if (dot_fn == nullptr) {
188+
throw py::value_error(
189+
"Types of input vectors and result array are mismatched.");
190+
}
191+
192+
char *a_typeless_ptr = vectorA.get_data();
193+
char *b_typeless_ptr = vectorB.get_data();
194+
char *r_typeless_ptr = result.get_data();
195+
196+
const int a_elemsize = vectorA.get_elemsize();
197+
const int b_elemsize = vectorB.get_elemsize();
198+
if (str_a < 0) {
199+
a_typeless_ptr -= (n - 1) * std::abs(str_a) * a_elemsize;
200+
}
201+
if (str_b < 0) {
202+
b_typeless_ptr -= (n - 1) * std::abs(str_b) * b_elemsize;
203+
}
204+
205+
sycl::event dot_ev = dot_fn(exec_q, n, a_typeless_ptr, str_a,
206+
b_typeless_ptr, str_b, r_typeless_ptr, depends);
207+
208+
sycl::event args_ev = dpctl::utils::keep_args_alive(
209+
exec_q, {vectorA, vectorB, result}, {dot_ev});
210+
211+
return std::make_pair(args_ev, dot_ev);
212+
}
213+
214+
template <typename fnT, typename Tab, typename Tc>
215+
struct DotContigFactory
216+
{
217+
fnT get()
218+
{
219+
if constexpr (types::DotTypePairSupportFactory<Tab, Tc>::is_defined) {
220+
return dot_impl<Tab, Tc>;
221+
}
222+
else {
223+
return nullptr;
224+
}
225+
}
226+
};
227+
228+
void init_dot_dispatch_table(void)
229+
{
230+
dpctl_td_ns::DispatchTableBuilder<dot_impl_fn_ptr_t, DotContigFactory,
231+
dpctl_td_ns::num_types>
232+
contig;
233+
contig.populate_dispatch_table(dot_dispatch_table);
234+
}
235+
} // namespace blas
236+
} // namespace ext
237+
} // namespace backend
238+
} // namespace dpnp

dpnp/backend/extensions/blas/dot.hpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
#include <oneapi/mkl.hpp>
30+
31+
#include <dpctl4pybind11.hpp>
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace blas
40+
{
41+
extern std::pair<sycl::event, sycl::event>
42+
dot(sycl::queue &exec_q,
43+
dpctl::tensor::usm_ndarray vectorA,
44+
dpctl::tensor::usm_ndarray vectorB,
45+
dpctl::tensor::usm_ndarray result,
46+
const std::vector<sycl::event> &depends);
47+
48+
extern std::pair<sycl::event, sycl::event>
49+
dotu(sycl::queue &exec_q,
50+
dpctl::tensor::usm_ndarray vectorA,
51+
dpctl::tensor::usm_ndarray vectorB,
52+
dpctl::tensor::usm_ndarray result,
53+
const std::vector<sycl::event> &depends);
54+
55+
extern void init_dot_dispatch_table(void);
56+
extern void init_dotu_dispatch_table(void);
57+
} // namespace blas
58+
} // namespace ext
59+
} // namespace backend
60+
} // namespace dpnp

0 commit comments

Comments
 (0)