Skip to content

Commit 51fd051

Browse files
Factored out device-capabilities into dedicated file
1 parent f498829 commit 51fd051

File tree

4 files changed

+205
-83
lines changed

4 files changed

+205
-83
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pybind11_add_module(${python_module_name} MODULE
2626
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
2727
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
2930
)
3031
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
3132
target_include_directories(${python_module_name}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#include <string>
26+
27+
#include "dpctl4pybind11.hpp"
28+
#include <CL/sycl.hpp>
29+
#include <pybind11/pybind11.h>
30+
#include <pybind11/stl.h>
31+
32+
namespace dpctl
33+
{
34+
namespace tensor
35+
{
36+
namespace py_internal
37+
{
38+
39+
namespace
40+
{
41+
42+
std::string _default_device_fp_type(sycl::device d)
43+
{
44+
if (d.has(sycl::aspect::fp64)) {
45+
return "f8";
46+
}
47+
else {
48+
return "f4";
49+
}
50+
}
51+
52+
std::string _default_device_int_type(sycl::device)
53+
{
54+
return "i8";
55+
}
56+
57+
std::string _default_device_complex_type(sycl::device d)
58+
{
59+
if (d.has(sycl::aspect::fp64)) {
60+
return "c16";
61+
}
62+
else {
63+
return "c8";
64+
}
65+
}
66+
67+
std::string _default_device_bool_type(sycl::device)
68+
{
69+
return "b1";
70+
}
71+
72+
sycl::device _extract_device(py::object arg)
73+
{
74+
auto &api = dpctl::detail::dpctl_capi::get();
75+
76+
PyObject *source = arg.ptr();
77+
if (api.PySyclQueue_Check_(source)) {
78+
sycl::queue q = py::cast<sycl::queue>(arg);
79+
return q.get_device();
80+
}
81+
else if (api.PySyclDevice_Check_(source)) {
82+
return py::cast<sycl::device>(arg);
83+
}
84+
else {
85+
throw py::type_error(
86+
"Expected type `dpctl.SyclQueue` or `dpctl.SyclDevice`.");
87+
}
88+
}
89+
90+
} // namespace
91+
92+
std::string default_device_fp_type(py::object arg)
93+
{
94+
sycl::device d = _extract_device(arg);
95+
return _default_device_fp_type(d);
96+
}
97+
98+
std::string default_device_int_type(py::object arg)
99+
{
100+
sycl::device d = _extract_device(arg);
101+
return _default_device_int_type(d);
102+
}
103+
104+
std::string default_device_bool_type(py::object arg)
105+
{
106+
sycl::device d = _extract_device(arg);
107+
return _default_device_bool_type(d);
108+
}
109+
110+
std::string default_device_complex_type(py::object arg)
111+
{
112+
sycl::device d = _extract_device(arg);
113+
return _default_device_complex_type(d);
114+
}
115+
116+
} // namespace py_internal
117+
} // namespace tensor
118+
} // namespace dpctl
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <string>
27+
28+
#include "dpctl4pybind11.hpp"
29+
#include <CL/sycl.hpp>
30+
#include <pybind11/pybind11.h>
31+
#include <pybind11/stl.h>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace py_internal
38+
{
39+
40+
extern std::string default_device_fp_type(py::object);
41+
extern std::string default_device_int_type(py::object);
42+
extern std::string default_device_bool_type(py::object);
43+
extern std::string default_device_complex_type(py::object);
44+
45+
} // namespace py_internal
46+
} // namespace tensor
47+
} // namespace dpctl

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 39 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include <algorithm>
2727
#include <complex>
2828
#include <cstdint>
29-
#include <pybind11/complex.h>
3029
#include <pybind11/pybind11.h>
3130
#include <pybind11/stl.h>
3231
#include <thread>
@@ -37,6 +36,7 @@
3736
#include "copy_and_cast_usm_to_usm.hpp"
3837
#include "copy_for_reshape.hpp"
3938
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
39+
#include "device_support_queries.hpp"
4040
#include "eye_ctor.hpp"
4141
#include "full_ctor.hpp"
4242
#include "linear_sequences.hpp"
@@ -102,36 +102,6 @@ void init_dispatch_vectors(void)
102102
return;
103103
}
104104

105-
std::string get_default_device_fp_type(sycl::device d)
106-
{
107-
if (d.has(sycl::aspect::fp64)) {
108-
return "f8";
109-
}
110-
else {
111-
return "f4";
112-
}
113-
}
114-
115-
std::string get_default_device_int_type(sycl::device)
116-
{
117-
return "i8";
118-
}
119-
120-
std::string get_default_device_complex_type(sycl::device d)
121-
{
122-
if (d.has(sycl::aspect::fp64)) {
123-
return "c16";
124-
}
125-
else {
126-
return "c8";
127-
}
128-
}
129-
130-
std::string get_default_device_bool_type(sycl::device)
131-
{
132-
return "b1";
133-
}
134-
135105
} // namespace
136106

137107
PYBIND11_MODULE(_tensor_impl, m)
@@ -209,57 +179,43 @@ PYBIND11_MODULE(_tensor_impl, m)
209179
py::arg("k"), py::arg("dst"), py::arg("sycl_queue"),
210180
py::arg("depends") = py::list());
211181

212-
m.def("default_device_fp_type", [](sycl::queue q) -> std::string {
213-
return get_default_device_fp_type(q.get_device());
214-
});
215-
m.def("default_device_fp_type_device", [](sycl::device dev) -> std::string {
216-
return get_default_device_fp_type(dev);
217-
});
218-
219-
m.def("default_device_int_type", [](sycl::queue q) -> std::string {
220-
return get_default_device_int_type(q.get_device());
221-
});
222-
m.def("default_device_int_type_device",
223-
[](sycl::device dev) -> std::string {
224-
return get_default_device_int_type(dev);
225-
});
226-
227-
m.def("default_device_bool_type", [](sycl::queue q) -> std::string {
228-
return get_default_device_bool_type(q.get_device());
229-
});
230-
m.def("default_device_bool_type_device",
231-
[](sycl::device dev) -> std::string {
232-
return get_default_device_bool_type(dev);
233-
});
234-
235-
m.def("default_device_complex_type", [](sycl::queue q) -> std::string {
236-
return get_default_device_complex_type(q.get_device());
237-
});
238-
m.def("default_device_complex_type_device",
239-
[](sycl::device dev) -> std::string {
240-
return get_default_device_complex_type(dev);
241-
});
242-
m.def(
243-
"_tril",
244-
[](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst,
245-
py::ssize_t k, sycl::queue exec_q,
246-
const std::vector<sycl::event> depends)
247-
-> std::pair<sycl::event, sycl::event> {
248-
return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends);
249-
},
250-
"Tril helper function.", py::arg("src"), py::arg("dst"),
251-
py::arg("k") = 0, py::arg("sycl_queue"),
252-
py::arg("depends") = py::list());
182+
m.def("default_device_fp_type",
183+
dpctl::tensor::py_internal::default_device_fp_type,
184+
"Gives default floating point type supported by device.",
185+
py::arg("dev"));
186+
187+
m.def("default_device_int_type",
188+
dpctl::tensor::py_internal::default_device_int_type,
189+
"Gives default integer type supported by device.", py::arg("dev"));
190+
191+
m.def("default_device_bool_type",
192+
dpctl::tensor::py_internal::default_device_bool_type,
193+
"Gives default boolean type supported by device.", py::arg("dev"));
194+
195+
m.def("default_device_complex_type",
196+
dpctl::tensor::py_internal::default_device_complex_type,
197+
"Gives default complex floating point type support by device.",
198+
py::arg("dev"));
199+
200+
auto tril_fn = [](dpctl::tensor::usm_ndarray src,
201+
dpctl::tensor::usm_ndarray dst, py::ssize_t k,
202+
sycl::queue exec_q,
203+
const std::vector<sycl::event> depends)
204+
-> std::pair<sycl::event, sycl::event> {
205+
return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends);
206+
};
207+
m.def("_tril", tril_fn, "Tril helper function.", py::arg("src"),
208+
py::arg("dst"), py::arg("k") = 0, py::arg("sycl_queue"),
209+
py::arg("depends") = py::list());
253210

254-
m.def(
255-
"_triu",
256-
[](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst,
257-
py::ssize_t k, sycl::queue exec_q,
258-
const std::vector<sycl::event> depends)
259-
-> std::pair<sycl::event, sycl::event> {
260-
return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends);
261-
},
262-
"Triu helper function.", py::arg("src"), py::arg("dst"),
263-
py::arg("k") = 0, py::arg("sycl_queue"),
264-
py::arg("depends") = py::list());
211+
auto triu_fn = [](dpctl::tensor::usm_ndarray src,
212+
dpctl::tensor::usm_ndarray dst, py::ssize_t k,
213+
sycl::queue exec_q,
214+
const std::vector<sycl::event> depends)
215+
-> std::pair<sycl::event, sycl::event> {
216+
return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends);
217+
};
218+
m.def("_triu", triu_fn, "Triu helper function.", py::arg("src"),
219+
py::arg("dst"), py::arg("k") = 0, py::arg("sycl_queue"),
220+
py::arg("depends") = py::list());
265221
}

0 commit comments

Comments
 (0)