Skip to content

Commit becd3b7

Browse files
Adds implementation of 6 bitwise elementwise functions
Implements bitwise_invert, bitwise_and, bitwise_or, bitwise_xor, bitwise_left_shift, and bitwise_right_shift Implements Python API in _tensor_impl for these functions.
1 parent 47f4bc9 commit becd3b7

File tree

8 files changed

+1967
-12
lines changed

8 files changed

+1967
-12
lines changed
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
//=== bitwise_and.hpp - Binary function BITWISE_AND -------- *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain in1 copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise bitwise_and(ar1, ar2) operation.
23+
//===---------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <cstddef>
28+
#include <cstdint>
29+
#include <type_traits>
30+
31+
#include "utils/offset_utils.hpp"
32+
#include "utils/type_dispatch.hpp"
33+
#include "utils/type_utils.hpp"
34+
35+
#include "kernels/elementwise_functions/common.hpp"
36+
#include <pybind11/pybind11.h>
37+
38+
namespace dpctl
39+
{
40+
namespace tensor
41+
{
42+
namespace kernels
43+
{
44+
namespace bitwise_and
45+
{
46+
47+
namespace py = pybind11;
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
namespace tu_ns = dpctl::tensor::type_utils;
50+
51+
template <typename argT1, typename argT2, typename resT>
52+
struct BitwiseAndFunctor
53+
{
54+
static_assert(std::is_same_v<resT, argT1>);
55+
static_assert(std::is_same_v<resT, argT2>);
56+
57+
using supports_sg_loadstore = typename std::true_type;
58+
using supports_vec = typename std::true_type;
59+
60+
resT operator()(const argT1 &in1, const argT2 &in2)
61+
{
62+
using tu_ns::convert_impl;
63+
64+
if constexpr (std::is_same_v<resT, bool>) {
65+
return in1 && in2;
66+
}
67+
else {
68+
return (in1 & in2);
69+
}
70+
}
71+
72+
template <int vec_sz>
73+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
74+
const sycl::vec<argT2, vec_sz> &in2)
75+
{
76+
77+
if constexpr (std::is_same_v<resT, bool>) {
78+
using dpctl::tensor::type_utils::vec_cast;
79+
80+
auto tmp = (in1 && in2);
81+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
82+
tmp);
83+
}
84+
else {
85+
return (in1 & in2);
86+
}
87+
}
88+
};
89+
90+
template <typename argT1,
91+
typename argT2,
92+
typename resT,
93+
unsigned int vec_sz = 4,
94+
unsigned int n_vecs = 2>
95+
using BitwiseAndContigFunctor = elementwise_common::BinaryContigFunctor<
96+
argT1,
97+
argT2,
98+
resT,
99+
BitwiseAndFunctor<argT1, argT2, resT>,
100+
vec_sz,
101+
n_vecs>;
102+
103+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
104+
using BitwiseAndStridedFunctor = elementwise_common::BinaryStridedFunctor<
105+
argT1,
106+
argT2,
107+
resT,
108+
IndexerT,
109+
BitwiseAndFunctor<argT1, argT2, resT>>;
110+
111+
template <typename T1, typename T2> struct BitwiseAndOutputType
112+
{
113+
using value_type = typename std::disjunction< // disjunction is C++17
114+
// feature, supported by
115+
// DPC++
116+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
117+
td_ns::BinaryTypeMapResultEntry<T1,
118+
std::uint8_t,
119+
T2,
120+
std::uint8_t,
121+
std::uint8_t>,
122+
td_ns::BinaryTypeMapResultEntry<T1,
123+
std::int8_t,
124+
T2,
125+
std::int8_t,
126+
std::int8_t>,
127+
td_ns::BinaryTypeMapResultEntry<T1,
128+
std::uint16_t,
129+
T2,
130+
std::uint16_t,
131+
std::uint16_t>,
132+
td_ns::BinaryTypeMapResultEntry<T1,
133+
std::int16_t,
134+
T2,
135+
std::int16_t,
136+
std::int16_t>,
137+
td_ns::BinaryTypeMapResultEntry<T1,
138+
std::uint32_t,
139+
T2,
140+
std::uint32_t,
141+
std::uint32_t>,
142+
td_ns::BinaryTypeMapResultEntry<T1,
143+
std::int32_t,
144+
T2,
145+
std::int32_t,
146+
std::int32_t>,
147+
td_ns::BinaryTypeMapResultEntry<T1,
148+
std::uint64_t,
149+
T2,
150+
std::uint64_t,
151+
std::uint64_t>,
152+
td_ns::BinaryTypeMapResultEntry<T1,
153+
std::int64_t,
154+
T2,
155+
std::int64_t,
156+
std::int64_t>,
157+
td_ns::DefaultResultEntry<void>>::result_type;
158+
};
159+
160+
template <typename argT1,
161+
typename argT2,
162+
typename resT,
163+
unsigned int vec_sz,
164+
unsigned int n_vecs>
165+
class bitwise_and_contig_kernel;
166+
167+
template <typename argTy1, typename argTy2>
168+
sycl::event
169+
bitwise_and_contig_impl(sycl::queue exec_q,
170+
size_t nelems,
171+
const char *arg1_p,
172+
py::ssize_t arg1_offset,
173+
const char *arg2_p,
174+
py::ssize_t arg2_offset,
175+
char *res_p,
176+
py::ssize_t res_offset,
177+
const std::vector<sycl::event> &depends = {})
178+
{
179+
return elementwise_common::binary_contig_impl<
180+
argTy1, argTy2, BitwiseAndOutputType, BitwiseAndContigFunctor,
181+
bitwise_and_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
182+
arg2_offset, res_p, res_offset, depends);
183+
}
184+
185+
template <typename fnT, typename T1, typename T2> struct BitwiseAndContigFactory
186+
{
187+
fnT get()
188+
{
189+
if constexpr (std::is_same_v<
190+
typename BitwiseAndOutputType<T1, T2>::value_type,
191+
void>)
192+
{
193+
fnT fn = nullptr;
194+
return fn;
195+
}
196+
else {
197+
fnT fn = bitwise_and_contig_impl<T1, T2>;
198+
return fn;
199+
}
200+
}
201+
};
202+
203+
template <typename fnT, typename T1, typename T2>
204+
struct BitwiseAndTypeMapFactory
205+
{
206+
/*! @brief get typeid for output type of operator()>(x, y), always bool
207+
*/
208+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
209+
{
210+
using rT = typename BitwiseAndOutputType<T1, T2>::value_type;
211+
return td_ns::GetTypeid<rT>{}.get();
212+
}
213+
};
214+
215+
template <typename T1, typename T2, typename resT, typename IndexerT>
216+
class bitwise_and_strided_kernel;
217+
218+
template <typename argTy1, typename argTy2>
219+
sycl::event
220+
bitwise_and_strided_impl(sycl::queue exec_q,
221+
size_t nelems,
222+
int nd,
223+
const py::ssize_t *shape_and_strides,
224+
const char *arg1_p,
225+
py::ssize_t arg1_offset,
226+
const char *arg2_p,
227+
py::ssize_t arg2_offset,
228+
char *res_p,
229+
py::ssize_t res_offset,
230+
const std::vector<sycl::event> &depends,
231+
const std::vector<sycl::event> &additional_depends)
232+
{
233+
return elementwise_common::binary_strided_impl<
234+
argTy1, argTy2, BitwiseAndOutputType, BitwiseAndStridedFunctor,
235+
bitwise_and_strided_kernel>(
236+
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
237+
arg2_offset, res_p, res_offset, depends, additional_depends);
238+
}
239+
240+
template <typename fnT, typename T1, typename T2>
241+
struct BitwiseAndStridedFactory
242+
{
243+
fnT get()
244+
{
245+
if constexpr (std::is_same_v<
246+
typename BitwiseAndOutputType<T1, T2>::value_type,
247+
void>)
248+
{
249+
fnT fn = nullptr;
250+
return fn;
251+
}
252+
else {
253+
fnT fn = bitwise_and_strided_impl<T1, T2>;
254+
return fn;
255+
}
256+
}
257+
};
258+
259+
} // namespace bitwise_and
260+
} // namespace kernels
261+
} // namespace tensor
262+
} // namespace dpctl

0 commit comments

Comments
 (0)