Skip to content

Commit fa4924a

Browse files
Moved inline functions into separate translation units
Instead of using the inline keyword to allow multiple definitions of the same function in different translation units, introduced elementwise_functions_type_utils.cpp, which defines these functions, and a header file to use in other translation units. This should reduce the binary size of the produced object files and simplify the linker's job, reducing link time.
1 parent 22b04e4 commit fa4924a

File tree

4 files changed

+116
-52
lines changed

4 files changed

+116
-52
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ endif()
3232

3333
set(_elementwise_sources
3434
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_common.cpp
35+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
3536
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp
3637
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
3738
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp

dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp

Lines changed: 12 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <pybind11/stl.h>
3232
#include <vector>
3333

34+
#include "elementwise_functions_type_utils.hpp"
3435
#include "simplify_iteration_space.hpp"
3536
#include "utils/memory_overlap.hpp"
3637
#include "utils/offset_utils.hpp"
@@ -46,56 +47,7 @@ namespace tensor
4647
namespace py_internal
4748
{
4849

49-
namespace
50-
{
51-
inline py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t)
52-
{
53-
switch (dst_typenum_t) {
54-
case td_ns::typenum_t::BOOL:
55-
return py::dtype("?");
56-
case td_ns::typenum_t::INT8:
57-
return py::dtype("i1");
58-
case td_ns::typenum_t::UINT8:
59-
return py::dtype("u1");
60-
case td_ns::typenum_t::INT16:
61-
return py::dtype("i2");
62-
case td_ns::typenum_t::UINT16:
63-
return py::dtype("u2");
64-
case td_ns::typenum_t::INT32:
65-
return py::dtype("i4");
66-
case td_ns::typenum_t::UINT32:
67-
return py::dtype("u4");
68-
case td_ns::typenum_t::INT64:
69-
return py::dtype("i8");
70-
case td_ns::typenum_t::UINT64:
71-
return py::dtype("u8");
72-
case td_ns::typenum_t::HALF:
73-
return py::dtype("f2");
74-
case td_ns::typenum_t::FLOAT:
75-
return py::dtype("f4");
76-
case td_ns::typenum_t::DOUBLE:
77-
return py::dtype("f8");
78-
case td_ns::typenum_t::CFLOAT:
79-
return py::dtype("c8");
80-
case td_ns::typenum_t::CDOUBLE:
81-
return py::dtype("c16");
82-
default:
83-
throw py::value_error("Unrecognized dst_typeid");
84-
}
85-
}
86-
87-
inline int _result_typeid(int arg_typeid, const int *fn_output_id)
88-
{
89-
if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) {
90-
throw py::value_error("Input typeid " + std::to_string(arg_typeid) +
91-
" is outside of expected bounds.");
92-
}
93-
94-
return fn_output_id[arg_typeid];
95-
}
96-
97-
} // end of anonymous namespace
98-
50+
/*! @brief Template implementing Python API for unary elementwise functions */
9951
template <typename output_typesT,
10052
typename contig_dispatchT,
10153
typename strided_dispatchT>
@@ -297,6 +249,8 @@ py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
297249
strided_fn_ev);
298250
}
299251

252+
/*! @brief Template implementing Python API for querying of type support by
253+
* unary elementwise functions */
300254
template <typename output_typesT>
301255
py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
302256
const output_typesT &output_types)
@@ -312,15 +266,17 @@ py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
312266
throw py::value_error(e.what());
313267
}
314268

269+
using dpctl::tensor::py_internal::type_utils::_result_typeid;
315270
int dst_typeid = _result_typeid(src_typeid, output_types);
316271

317272
if (dst_typeid < 0) {
318273
auto res = py::none();
319274
return py::cast<py::object>(res);
320275
}
321276
else {
322-
auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
277+
using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum;
323278

279+
auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
324280
auto dt = _dtype_from_typenum(dst_typenum_t);
325281

326282
return py::cast<py::object>(dt);
@@ -338,6 +294,8 @@ bool isEqual(Container const &c, std::initializer_list<T> const &l)
338294
}
339295
} // namespace
340296

297+
/*! @brief Template implementing Python API for binary elementwise
298+
* functions */
341299
template <typename output_typesT,
342300
typename contig_dispatchT,
343301
typename strided_dispatchT,
@@ -605,6 +563,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
605563
strided_fn_ev);
606564
}
607565

566+
/*! @brief Type querying for binary elementwise functions */
608567
template <typename output_typesT>
609568
py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
610569
const py::dtype &input2_dtype,
@@ -636,8 +595,9 @@ py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
636595
return py::cast<py::object>(res);
637596
}
638597
else {
639-
auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
598+
using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum;
640599

600+
auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
641601
auto dt = _dtype_from_typenum(dst_typenum_t);
642602

643603
return py::cast<py::object>(dt);
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#include "dpctl4pybind11.hpp"
2+
#include <CL/sycl.hpp>
3+
#include <pybind11/numpy.h>
4+
#include <pybind11/pybind11.h>
5+
6+
#include "elementwise_functions_type_utils.hpp"
7+
#include "utils/type_dispatch.hpp"
8+
9+
namespace py = pybind11;
10+
namespace td_ns = dpctl::tensor::type_dispatch;
11+
12+
namespace dpctl
13+
{
14+
namespace tensor
15+
{
16+
namespace py_internal
17+
{
18+
namespace type_utils
19+
{
20+
21+
/*! @brief Translate a type_dispatch type number into a pybind11 dtype.
 *
 * Each typenum_t enumerator handled below is mapped to a py::dtype
 * constructed from a type string ("?", "i1", "u1", ..., "c8", "c16").
 *
 * @param dst_typenum_t type number to translate
 * @return py::dtype built from the type string of the matching case
 * @throws py::value_error when dst_typenum_t matches none of the cases
 */
py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t)
22+
{
23+
switch (dst_typenum_t) {
24+
case td_ns::typenum_t::BOOL:
25+
return py::dtype("?");
26+
case td_ns::typenum_t::INT8:
27+
return py::dtype("i1");
28+
case td_ns::typenum_t::UINT8:
29+
return py::dtype("u1");
30+
case td_ns::typenum_t::INT16:
31+
return py::dtype("i2");
32+
case td_ns::typenum_t::UINT16:
33+
return py::dtype("u2");
34+
case td_ns::typenum_t::INT32:
35+
return py::dtype("i4");
36+
case td_ns::typenum_t::UINT32:
37+
return py::dtype("u4");
38+
case td_ns::typenum_t::INT64:
39+
return py::dtype("i8");
40+
case td_ns::typenum_t::UINT64:
41+
return py::dtype("u8");
42+
case td_ns::typenum_t::HALF:
43+
return py::dtype("f2");
44+
case td_ns::typenum_t::FLOAT:
45+
return py::dtype("f4");
46+
case td_ns::typenum_t::DOUBLE:
47+
return py::dtype("f8");
48+
case td_ns::typenum_t::CFLOAT:
49+
return py::dtype("c8");
50+
case td_ns::typenum_t::CDOUBLE:
51+
return py::dtype("c16");
52+
default:
53+
// Any enumerator without an explicit case above is rejected.
throw py::value_error("Unrecognized dst_typeid");
54+
}
55+
}
56+
57+
/*! @brief Look up the result type id for a given argument type id.
 *
 * @param arg_typeid   type id of the argument; must lie in
 *                     [0, td_ns::num_types)
 * @param fn_output_id table indexed by argument type id
 * @return the table entry fn_output_id[arg_typeid], returned unchanged
 * @throws py::value_error when arg_typeid is outside the accepted range
 *
 * NOTE(review): assumes fn_output_id holds at least td_ns::num_types
 * entries; the table length is not checked here -- confirm at call sites.
 */
int _result_typeid(int arg_typeid, const int *fn_output_id)
58+
{
59+
// Validate the index before it is used to subscript the table.
if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) {
60+
throw py::value_error("Input typeid " + std::to_string(arg_typeid) +
61+
" is outside of expected bounds.");
62+
}
63+
64+
return fn_output_id[arg_typeid];
65+
}
66+
67+
} // namespace type_utils
68+
} // namespace py_internal
69+
} // namespace tensor
70+
} // namespace dpctl
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#pragma once
2+
3+
#pragma once
4+
#include "dpctl4pybind11.hpp"
5+
#include <CL/sycl.hpp>
6+
#include <pybind11/numpy.h>
7+
#include <pybind11/pybind11.h>
8+
9+
#include "utils/type_dispatch.hpp"
10+
11+
namespace py = pybind11;
12+
namespace td_ns = dpctl::tensor::type_dispatch;
13+
14+
namespace dpctl
15+
{
16+
namespace tensor
17+
{
18+
namespace py_internal
19+
{
20+
namespace type_utils
21+
{
22+
23+
/*! @brief Produce dtype from a type number */
24+
extern py::dtype _dtype_from_typenum(td_ns::typenum_t);
25+
26+
/*! @brief Lookup typeid of the result from typeid of
27+
* argument and the mapping table */
28+
extern int _result_typeid(int, const int *);
29+
30+
} // namespace type_utils
31+
} // namespace py_internal
32+
} // namespace tensor
33+
} // namespace dpctl

0 commit comments

Comments
 (0)