Skip to content

Commit 7a94856

Browse files
Moved linear seq. functions to dedicated files
1 parent 6e82889 commit 7a94856

File tree

5 files changed

+243
-119
lines changed

5 files changed

+243
-119
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pybind11_add_module(${python_module_name} MODULE
2222
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
2323
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
2424
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
25+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
2526
)
2627
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
2728
target_include_directories(${python_module_name}

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ template <typename T> T unbox_py_scalar(py::object o)
5252
return py::cast<T>(o);
5353
}
5454

55-
template <> sycl::half unbox_py_scalar<sycl::half>(py::object o)
55+
template <> inline sycl::half unbox_py_scalar<sycl::half>(py::object o)
5656
{
5757
float tmp = py::cast<float>(o);
5858
return static_cast<sycl::half>(tmp);
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines linear sequence functions of dpctl.tensor._tensor_impl
/// extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#include "dpctl4pybind11.hpp"
26+
#include <CL/sycl.hpp>
27+
#include <complex>
28+
#include <pybind11/complex.h>
29+
#include <pybind11/pybind11.h>
30+
#include <utility>
31+
#include <vector>
32+
33+
#include "kernels/constructors.hpp"
34+
#include "utils/strided_iters.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
38+
#include "linear_sequences.hpp"
39+
40+
namespace py = pybind11;
41+
namespace _ns = dpctl::tensor::detail;
42+
43+
namespace dpctl
44+
{
45+
namespace tensor
46+
{
47+
namespace py_internal
48+
{
49+
50+
using dpctl::utils::keep_args_alive;
51+
52+
using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t;
53+
54+
static lin_space_step_fn_ptr_t lin_space_step_dispatch_vector[_ns::num_types];
55+
56+
using dpctl::tensor::kernels::constructors::lin_space_affine_fn_ptr_t;
57+
58+
static lin_space_affine_fn_ptr_t
59+
lin_space_affine_dispatch_vector[_ns::num_types];
60+
61+
std::pair<sycl::event, sycl::event>
62+
usm_ndarray_linear_sequence_step(py::object start,
63+
py::object dt,
64+
dpctl::tensor::usm_ndarray dst,
65+
sycl::queue exec_q,
66+
const std::vector<sycl::event> &depends)
67+
{
68+
// dst must be 1D and C-contiguous
69+
// start, end should be coercible into data type of dst
70+
71+
if (dst.get_ndim() != 1) {
72+
throw py::value_error(
73+
"usm_ndarray_linspace: Expecting 1D array to populate");
74+
}
75+
76+
if (!dst.is_c_contiguous()) {
77+
throw py::value_error(
78+
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
79+
}
80+
81+
sycl::queue dst_q = dst.get_queue();
82+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
83+
throw py::value_error(
84+
"Execution queue is not compatible with the allocation queue");
85+
}
86+
87+
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
88+
int dst_typenum = dst.get_typenum();
89+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
90+
91+
py::ssize_t len = dst.get_shape(0);
92+
if (len == 0) {
93+
// nothing to do
94+
return std::make_pair(sycl::event{}, sycl::event{});
95+
}
96+
97+
char *dst_data = dst.get_data();
98+
sycl::event linspace_step_event;
99+
100+
auto fn = lin_space_step_dispatch_vector[dst_typeid];
101+
102+
linspace_step_event =
103+
fn(exec_q, static_cast<size_t>(len), start, dt, dst_data, depends);
104+
105+
return std::make_pair(keep_args_alive(exec_q, {dst}, {linspace_step_event}),
106+
linspace_step_event);
107+
}
108+
109+
std::pair<sycl::event, sycl::event>
110+
usm_ndarray_linear_sequence_affine(py::object start,
111+
py::object end,
112+
dpctl::tensor::usm_ndarray dst,
113+
bool include_endpoint,
114+
sycl::queue exec_q,
115+
const std::vector<sycl::event> &depends)
116+
{
117+
// dst must be 1D and C-contiguous
118+
// start, end should be coercible into data type of dst
119+
120+
if (dst.get_ndim() != 1) {
121+
throw py::value_error(
122+
"usm_ndarray_linspace: Expecting 1D array to populate");
123+
}
124+
125+
if (!dst.is_c_contiguous()) {
126+
throw py::value_error(
127+
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
128+
}
129+
130+
sycl::queue dst_q = dst.get_queue();
131+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
132+
throw py::value_error(
133+
"Execution queue context is not the same as allocation context");
134+
}
135+
136+
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
137+
int dst_typenum = dst.get_typenum();
138+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
139+
140+
py::ssize_t len = dst.get_shape(0);
141+
if (len == 0) {
142+
// nothing to do
143+
return std::make_pair(sycl::event{}, sycl::event{});
144+
}
145+
146+
char *dst_data = dst.get_data();
147+
sycl::event linspace_affine_event;
148+
149+
auto fn = lin_space_affine_dispatch_vector[dst_typeid];
150+
151+
linspace_affine_event = fn(exec_q, static_cast<size_t>(len), start, end,
152+
include_endpoint, dst_data, depends);
153+
154+
return std::make_pair(
155+
keep_args_alive(exec_q, {dst}, {linspace_affine_event}),
156+
linspace_affine_event);
157+
}
158+
159+
void init_linear_sequences_dispatch_vectors(void)
160+
{
161+
using namespace dpctl::tensor::detail;
162+
using dpctl::tensor::kernels::constructors::LinSpaceAffineFactory;
163+
using dpctl::tensor::kernels::constructors::LinSpaceStepFactory;
164+
165+
DispatchVectorBuilder<lin_space_step_fn_ptr_t, LinSpaceStepFactory,
166+
num_types>
167+
dvb1;
168+
dvb1.populate_dispatch_vector(lin_space_step_dispatch_vector);
169+
170+
DispatchVectorBuilder<lin_space_affine_fn_ptr_t, LinSpaceAffineFactory,
171+
num_types>
172+
dvb2;
173+
dvb2.populate_dispatch_vector(lin_space_affine_dispatch_vector);
174+
}
175+
176+
} // namespace py_internal
177+
} // namespace tensor
178+
} // namespace dpctl
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file declares linear sequence functions of dpctl.tensor._tensor_impl
/// extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <utility>
28+
#include <vector>
29+
30+
#include "dpctl4pybind11.hpp"
31+
#include <pybind11/pybind11.h>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace py_internal
38+
{
39+
40+
/// Populates the 1D, C-contiguous array `dst` using the function selected
/// for the data type of `dst`, passing it `start` and `dt`.
/// The definition throws py::value_error when `dst` is not 1D, is not
/// C-contiguous, or when `exec_q` is not compatible with the queue of `dst`.
/// Returns a pair of events: the first comes from keep_args_alive for `dst`,
/// the second is the event of the populating function. `depends` defaults to
/// an empty list of events.
extern std::pair<sycl::event, sycl::event>
41+
usm_ndarray_linear_sequence_step(py::object start,
42+
py::object dt,
43+
dpctl::tensor::usm_ndarray dst,
44+
sycl::queue exec_q,
45+
const std::vector<sycl::event> &depends = {});
46+
47+
/// Populates the 1D, C-contiguous array `dst` using the function selected
/// for the data type of `dst`, passing it `start`, `end` and
/// `include_endpoint`.
/// The definition throws py::value_error when `dst` is not 1D, is not
/// C-contiguous, or when `exec_q` is not compatible with the queue of `dst`.
/// Returns a pair of events: the first comes from keep_args_alive for `dst`,
/// the second is the event of the populating function. `depends` defaults to
/// an empty list of events.
extern std::pair<sycl::event, sycl::event> usm_ndarray_linear_sequence_affine(
48+
py::object start,
49+
py::object end,
50+
dpctl::tensor::usm_ndarray dst,
51+
bool include_endpoint,
52+
sycl::queue exec_q,
53+
const std::vector<sycl::event> &depends = {});
54+
55+
/// Populates the dispatch vectors that usm_ndarray_linear_sequence_step and
/// usm_ndarray_linear_sequence_affine index by the type lookup id of `dst`.
extern void init_linear_sequences_dispatch_vectors(void);
56+
57+
} // namespace py_internal
58+
} // namespace tensor
59+
} // namespace dpctl

0 commit comments

Comments
 (0)