Skip to content

Commit 40739b4

Browse files
Moved full ctor to dedicated file
1 parent 7a94856 commit 40739b4

File tree

4 files changed

+170
-53
lines changed

4 files changed

+170
-53
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pybind11_add_module(${python_module_name} MODULE
2323
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
2424
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
2525
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
26+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
2627
)
2728
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
2829
target_include_directories(${python_module_name}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines the usm_ndarray_full function of the
/// dpctl.tensor._tensor_impl extension and the initializer of its
/// dispatch vector
23+
//===--------------------------------------------------------------------===//
24+
25+
#include "dpctl4pybind11.hpp"
26+
#include <CL/sycl.hpp>
27+
#include <complex>
28+
#include <pybind11/complex.h>
29+
#include <pybind11/pybind11.h>
30+
#include <utility>
31+
#include <vector>
32+
33+
#include "kernels/constructors.hpp"
34+
#include "utils/strided_iters.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
38+
#include "full_ctor.hpp"
39+
40+
namespace py = pybind11;
41+
namespace _ns = dpctl::tensor::detail;
42+
43+
namespace dpctl
44+
{
45+
namespace tensor
46+
{
47+
namespace py_internal
48+
{
49+
50+
using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t;
51+
using dpctl::utils::keep_args_alive;
52+
53+
using dpctl::tensor::kernels::constructors::full_contig_fn_ptr_t;
54+
55+
static full_contig_fn_ptr_t full_contig_dispatch_vector[_ns::num_types];
56+
57+
std::pair<sycl::event, sycl::event>
58+
usm_ndarray_full(py::object py_value,
59+
dpctl::tensor::usm_ndarray dst,
60+
sycl::queue exec_q,
61+
const std::vector<sycl::event> &depends)
62+
{
63+
// start, end should be coercible into data type of dst
64+
65+
py::ssize_t dst_nelems = dst.get_size();
66+
67+
if (dst_nelems == 0) {
68+
// nothing to do
69+
return std::make_pair(sycl::event(), sycl::event());
70+
}
71+
72+
sycl::queue dst_q = dst.get_queue();
73+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
74+
throw py::value_error(
75+
"Execution queue is not compatible with the allocation queue");
76+
}
77+
78+
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
79+
int dst_typenum = dst.get_typenum();
80+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
81+
82+
char *dst_data = dst.get_data();
83+
sycl::event full_event;
84+
85+
if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
86+
auto fn = full_contig_dispatch_vector[dst_typeid];
87+
88+
sycl::event full_contig_event =
89+
fn(exec_q, static_cast<size_t>(dst_nelems), py_value, dst_data,
90+
depends);
91+
92+
return std::make_pair(
93+
keep_args_alive(exec_q, {dst}, {full_contig_event}),
94+
full_contig_event);
95+
}
96+
else {
97+
throw std::runtime_error(
98+
"Only population of contiguous usm_ndarray objects is supported.");
99+
}
100+
}
101+
102+
void init_full_ctor_dispatch_vectors(void)
103+
{
104+
using namespace dpctl::tensor::detail;
105+
using dpctl::tensor::kernels::constructors::FullContigFactory;
106+
107+
DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
108+
dvb;
109+
dvb.populate_dispatch_vector(full_contig_dispatch_vector);
110+
111+
return;
112+
}
113+
114+
} // namespace py_internal
115+
} // namespace tensor
116+
} // namespace dpctl
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===--------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file declares the usm_ndarray_full function of the
/// dpctl.tensor._tensor_impl extension and the initializer of its
/// dispatch vector
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <utility>
28+
#include <vector>
29+
30+
#include "dpctl4pybind11.hpp"
31+
#include <pybind11/pybind11.h>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace py_internal
38+
{
39+
40+
/// Fills every element of `dst` with `py_value` by submitting work to
/// `exec_q`, passing `depends` on as dependencies. `depends` defaults to
/// an empty list.
///
/// Returns a pair of events: the event produced by `keep_args_alive` for
/// `dst`, followed by the event of the fill itself. For an empty `dst`
/// both are default-constructed events.
///
/// Throws py::value_error if `exec_q` is not compatible with the queue of
/// `dst`, and std::runtime_error if `dst` has more than one element and is
/// neither C- nor F-contiguous.
extern std::pair<sycl::event, sycl::event>
41+
usm_ndarray_full(py::object py_value,
42+
dpctl::tensor::usm_ndarray dst,
43+
sycl::queue exec_q,
44+
const std::vector<sycl::event> &depends = {});
45+
46+
/// Populates the dispatch vector that usm_ndarray_full reads its
/// type-specific fill function from.
extern void init_full_ctor_dispatch_vectors(void);
47+
48+
} // namespace py_internal
49+
} // namespace tensor
50+
} // namespace dpctl

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "copy_and_cast_usm_to_usm.hpp"
4343
#include "copy_for_reshape.hpp"
4444
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
45+
#include "full_ctor.hpp"
4546
#include "linear_sequences.hpp"
4647
#include "simplify_iteration_space.hpp"
4748

@@ -74,54 +75,7 @@ using dpctl::tensor::py_internal::usm_ndarray_linear_sequence_step;
7475

7576
/* ================ Full ================== */
7677

77-
// NOTE(review): these are the removed ("-") lines of this commit; the same
// definitions now live in full_ctor.cpp.
using dpctl::tensor::kernels::constructors::full_contig_fn_ptr_t;
78-
79-
// Table of "full" fill functions for contiguous arrays, indexed by the
// lookup id of the destination's element type.
static full_contig_fn_ptr_t full_contig_dispatch_vector[_ns::num_types];
80-
81-
// Fills every element of `dst` with `py_value`. Returns a pair of events:
// the one produced by keep_args_alive for `dst`, and the fill event itself.
// Throws py::value_error for an incompatible queue and std::runtime_error
// for a non-contiguous `dst` with more than one element.
std::pair<sycl::event, sycl::event>
82-
usm_ndarray_full(py::object py_value,
83-
dpctl::tensor::usm_ndarray dst,
84-
sycl::queue exec_q,
85-
const std::vector<sycl::event> &depends = {})
86-
{
87-
// py_value should be coercible into data type of dst
88-
89-
py::ssize_t dst_nelems = dst.get_size();
90-
91-
if (dst_nelems == 0) {
92-
// nothing to do
93-
return std::make_pair(sycl::event(), sycl::event());
94-
}
95-
96-
sycl::queue dst_q = dst.get_queue();
97-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
98-
throw py::value_error(
99-
"Execution queue is not compatible with the allocation queue");
100-
}
101-
102-
// Map the array's type number to an index into the dispatch vector.
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
103-
int dst_typenum = dst.get_typenum();
104-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
105-
106-
char *dst_data = dst.get_data();
107-
// NOTE(review): full_event is never read anywhere in this function.
sycl::event full_event;
108-
109-
if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
110-
auto fn = full_contig_dispatch_vector[dst_typeid];
111-
112-
sycl::event full_contig_event =
113-
fn(exec_q, static_cast<size_t>(dst_nelems), py_value, dst_data,
114-
depends);
115-
116-
return std::make_pair(
117-
keep_args_alive(exec_q, {dst}, {full_contig_event}),
118-
full_contig_event);
119-
}
120-
else {
121-
throw std::runtime_error(
122-
"Only population of contiguous usm_ndarray objects is supported.");
123-
}
124-
}
78+
using dpctl::tensor::py_internal::usm_ndarray_full;
12579

12680
/* ================ Eye ================== */
12781

@@ -435,17 +389,13 @@ void init_dispatch_vectors(void)
435389
{
436390
dpctl::tensor::py_internal::init_copy_for_reshape_dispatch_vectors();
437391
dpctl::tensor::py_internal::init_linear_sequences_dispatch_vectors();
392+
dpctl::tensor::py_internal::init_full_ctor_dispatch_vectors();
438393

439394
using namespace dpctl::tensor::detail;
440395
using dpctl::tensor::kernels::constructors::EyeFactory;
441-
using dpctl::tensor::kernels::constructors::FullContigFactory;
442396
using dpctl::tensor::kernels::constructors::TrilGenericFactory;
443397
using dpctl::tensor::kernels::constructors::TriuGenericFactory;
444398

445-
DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
446-
dvb3;
447-
dvb3.populate_dispatch_vector(full_contig_dispatch_vector);
448-
449399
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb4;
450400
dvb4.populate_dispatch_vector(eye_dispatch_vector);
451401

0 commit comments

Comments
 (0)