Skip to content

Commit 785670e

Browse files
Moved eye constructor out to dedicated implementation file
1 parent 40739b4 commit 785670e

File tree

4 files changed

+188
-77
lines changed

4 files changed

+188
-77
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pybind11_add_module(${python_module_name} MODULE
2323
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
2424
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
2525
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
26+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
2627
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
2728
)
2829
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#include <CL/sycl.hpp>
26+
#include <utility>
27+
#include <vector>
28+
29+
#include "dpctl4pybind11.hpp"
30+
#include <pybind11/pybind11.h>
31+
32+
#include "eye_ctor.hpp"
33+
#include "kernels/constructors.hpp"
34+
#include "utils/type_dispatch.hpp"
35+
36+
namespace py = pybind11;
37+
namespace _ns = dpctl::tensor::detail;
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace py_internal
44+
{
45+
46+
using dpctl::utils::keep_args_alive;
47+
48+
using dpctl::tensor::kernels::constructors::eye_fn_ptr_t;
49+
static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types];
50+
51+
std::pair<sycl::event, sycl::event>
52+
usm_ndarray_eye(py::ssize_t k,
53+
dpctl::tensor::usm_ndarray dst,
54+
sycl::queue exec_q,
55+
const std::vector<sycl::event> &depends)
56+
{
57+
// dst must be 2D
58+
59+
if (dst.get_ndim() != 2) {
60+
throw py::value_error(
61+
"usm_ndarray_eye: Expecting 2D array to populate");
62+
}
63+
64+
sycl::queue dst_q = dst.get_queue();
65+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
66+
throw py::value_error("Execution queue is not compatible with the "
67+
"allocation queue");
68+
}
69+
70+
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
71+
int dst_typenum = dst.get_typenum();
72+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
73+
74+
const py::ssize_t nelem = dst.get_size();
75+
const py::ssize_t rows = dst.get_shape(0);
76+
const py::ssize_t cols = dst.get_shape(1);
77+
if (rows == 0 || cols == 0) {
78+
// nothing to do
79+
return std::make_pair(sycl::event{}, sycl::event{});
80+
}
81+
82+
bool is_dst_c_contig = dst.is_c_contiguous();
83+
bool is_dst_f_contig = dst.is_f_contiguous();
84+
if (!is_dst_c_contig && !is_dst_f_contig) {
85+
throw py::value_error("USM array is not contiguous");
86+
}
87+
88+
py::ssize_t start;
89+
if (is_dst_c_contig) {
90+
start = (k < 0) ? -k * cols : k;
91+
}
92+
else {
93+
start = (k < 0) ? -k : k * rows;
94+
}
95+
96+
const py::ssize_t *strides = dst.get_strides_raw();
97+
py::ssize_t step;
98+
if (strides == nullptr) {
99+
step = (is_dst_c_contig) ? cols + 1 : rows + 1;
100+
}
101+
else {
102+
step = strides[0] + strides[1];
103+
}
104+
105+
const py::ssize_t length = std::min({rows, cols, rows + k, cols - k});
106+
const py::ssize_t end = start + step * (length - 1);
107+
108+
char *dst_data = dst.get_data();
109+
sycl::event eye_event;
110+
111+
auto fn = eye_dispatch_vector[dst_typeid];
112+
113+
eye_event = fn(exec_q, static_cast<size_t>(nelem), start, end, step,
114+
dst_data, depends);
115+
116+
return std::make_pair(keep_args_alive(exec_q, {dst}, {eye_event}),
117+
eye_event);
118+
}
119+
120+
void init_eye_ctor_dispatch_vectors(void)
121+
{
122+
using namespace dpctl::tensor::detail;
123+
using dpctl::tensor::kernels::constructors::EyeFactory;
124+
125+
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb;
126+
dvb.populate_dispatch_vector(eye_dispatch_vector);
127+
128+
return;
129+
}
130+
131+
} // namespace py_internal
132+
} // namespace tensor
133+
} // namespace dpctl
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <utility>
28+
#include <vector>
29+
30+
#include "dpctl4pybind11.hpp"
31+
#include <pybind11/pybind11.h>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace py_internal
38+
{
39+
40+
extern std::pair<sycl::event, sycl::event>
41+
usm_ndarray_eye(py::ssize_t k,
42+
dpctl::tensor::usm_ndarray dst,
43+
sycl::queue exec_q,
44+
const std::vector<sycl::event> &depends = {});
45+
46+
extern void init_eye_ctor_dispatch_vectors(void);
47+
48+
} // namespace py_internal
49+
} // namespace tensor
50+
} // namespace dpctl

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 4 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "copy_and_cast_usm_to_usm.hpp"
4343
#include "copy_for_reshape.hpp"
4444
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
45+
#include "eye_ctor.hpp"
4546
#include "full_ctor.hpp"
4647
#include "linear_sequences.hpp"
4748
#include "simplify_iteration_space.hpp"
@@ -79,78 +80,7 @@ using dpctl::tensor::py_internal::usm_ndarray_full;
7980

8081
/* ================ Eye ================== */
8182

82-
using dpctl::tensor::kernels::constructors::eye_fn_ptr_t;
83-
84-
static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types];
85-
86-
std::pair<sycl::event, sycl::event>
87-
eye(py::ssize_t k,
88-
dpctl::tensor::usm_ndarray dst,
89-
sycl::queue exec_q,
90-
const std::vector<sycl::event> &depends = {})
91-
{
92-
// dst must be 2D
93-
94-
if (dst.get_ndim() != 2) {
95-
throw py::value_error(
96-
"usm_ndarray_eye: Expecting 2D array to populate");
97-
}
98-
99-
sycl::queue dst_q = dst.get_queue();
100-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
101-
throw py::value_error("Execution queue is not compatible with the "
102-
"allocation queue");
103-
}
104-
105-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
106-
int dst_typenum = dst.get_typenum();
107-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
108-
109-
const py::ssize_t nelem = dst.get_size();
110-
const py::ssize_t rows = dst.get_shape(0);
111-
const py::ssize_t cols = dst.get_shape(1);
112-
if (rows == 0 || cols == 0) {
113-
// nothing to do
114-
return std::make_pair(sycl::event{}, sycl::event{});
115-
}
116-
117-
bool is_dst_c_contig = dst.is_c_contiguous();
118-
bool is_dst_f_contig = dst.is_f_contiguous();
119-
if (!is_dst_c_contig && !is_dst_f_contig) {
120-
throw py::value_error("USM array is not contiguous");
121-
}
122-
123-
py::ssize_t start;
124-
if (is_dst_c_contig) {
125-
start = (k < 0) ? -k * cols : k;
126-
}
127-
else {
128-
start = (k < 0) ? -k : k * rows;
129-
}
130-
131-
const py::ssize_t *strides = dst.get_strides_raw();
132-
py::ssize_t step;
133-
if (strides == nullptr) {
134-
step = (is_dst_c_contig) ? cols + 1 : rows + 1;
135-
}
136-
else {
137-
step = strides[0] + strides[1];
138-
}
139-
140-
const py::ssize_t length = std::min({rows, cols, rows + k, cols - k});
141-
const py::ssize_t end = start + step * (length - 1);
142-
143-
char *dst_data = dst.get_data();
144-
sycl::event eye_event;
145-
146-
auto fn = eye_dispatch_vector[dst_typeid];
147-
148-
eye_event = fn(exec_q, static_cast<size_t>(nelem), start, end, step,
149-
dst_data, depends);
150-
151-
return std::make_pair(keep_args_alive(exec_q, {dst}, {eye_event}),
152-
eye_event);
153-
}
83+
using dpctl::tensor::py_internal::usm_ndarray_eye;
15484

15585
/* =========================== Tril and triu ============================== */
15686

@@ -390,15 +320,12 @@ void init_dispatch_vectors(void)
390320
dpctl::tensor::py_internal::init_copy_for_reshape_dispatch_vectors();
391321
dpctl::tensor::py_internal::init_linear_sequences_dispatch_vectors();
392322
dpctl::tensor::py_internal::init_full_ctor_dispatch_vectors();
323+
dpctl::tensor::py_internal::init_eye_ctor_dispatch_vectors();
393324

394325
using namespace dpctl::tensor::detail;
395-
using dpctl::tensor::kernels::constructors::EyeFactory;
396326
using dpctl::tensor::kernels::constructors::TrilGenericFactory;
397327
using dpctl::tensor::kernels::constructors::TriuGenericFactory;
398328

399-
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb4;
400-
dvb4.populate_dispatch_vector(eye_dispatch_vector);
401-
402329
DispatchVectorBuilder<tri_fn_ptr_t, TrilGenericFactory, num_types> dvb5;
403330
dvb5.populate_dispatch_vector(tril_generic_dispatch_vector);
404331

@@ -505,7 +432,7 @@ PYBIND11_MODULE(_tensor_impl, m)
505432
py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"),
506433
py::arg("depends") = py::list());
507434

508-
m.def("_eye", &eye,
435+
m.def("_eye", &usm_ndarray_eye,
509436
"Fills input 2D contiguous usm_ndarray `dst` with "
510437
"zeros outside of the diagonal "
511438
"specified by "

0 commit comments

Comments
 (0)