Skip to content

Commit 79660f4

Browse files
committed
Implements logical operators and, or, not, and xor
1 parent 7c1d147 commit 79660f4

File tree

4 files changed

+1009
-0
lines changed

4 files changed

+1009
-0
lines changed
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
//=== logical_and.hpp - Binary function LOGICAL_AND ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of the logical AND of
24+
/// tensor elements.
25+
//===---------------------------------------------------------------------===//
26+
27+
#pragma once
28+
#include <CL/sycl.hpp>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "utils/offset_utils.hpp"
34+
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
36+
37+
#include "kernels/elementwise_functions/common.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace logical_and
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace tu_ns = dpctl::tensor::type_utils;
52+
53+
template <typename argT1, typename argT2, typename resT>
54+
struct LogicalAndFunctor
55+
{
56+
static_assert(std::is_same_v<resT, bool>);
57+
58+
using supports_sg_loadstore = std::negation<
59+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
60+
using supports_vec = std::conjunction<
61+
std::is_same<argT1, argT2>,
62+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
63+
tu_ns::is_complex<argT2>>>>;
64+
65+
resT operator()(const argT1 &in1, const argT2 &in2)
66+
{
67+
if constexpr (tu_ns::is_complex<argT1> && tu_ns::is_complex<argT2>) {
68+
static_assert(std::is_same_v<argT1, argT2>);
69+
70+
return (std::real(in1) || std::imag(in1)) &&
71+
(std::real(in2) || std::imag(in2));
72+
}
73+
else {
74+
return (in1 && in2);
75+
}
76+
77+
template <int vec_sz>
78+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
79+
const sycl::vec<argT2, vec_sz> &in2)
80+
{
81+
82+
auto tmp = (in1 && in2);
83+
84+
if constexpr (std::is_same_v<resT,
85+
typename decltype(tmp)::element_type>)
86+
{
87+
return tmp;
88+
}
89+
else {
90+
using dpctl::tensor::type_utils::vec_cast;
91+
92+
return vec_cast<resT, typename decltype(tmp)::element_type,
93+
vec_sz>(tmp);
94+
}
95+
}
96+
};
97+
98+
template <typename argT1,
99+
typename argT2,
100+
typename resT,
101+
unsigned int vec_sz = 4,
102+
unsigned int n_vecs = 2>
103+
using LogicalAndContigFunctor = elementwise_common::BinaryContigFunctor<
104+
argT1,
105+
argT2,
106+
resT,
107+
LogicalAndFunctor<argT1, argT2, resT>,
108+
vec_sz,
109+
n_vecs>;
110+
111+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
112+
using LogicalAndStridedFunctor = elementwise_common::BinaryStridedFunctor<
113+
argT1,
114+
argT2,
115+
resT,
116+
IndexerT,
117+
LogicalAndFunctor<argT1, argT2, resT>>;
118+
119+
template <typename T1, typename T2> struct LogicalAndOutputType
120+
{
121+
using value_type = typename std::disjunction< // disjunction is C++17
122+
// feature, supported by
123+
// DPC++
124+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
125+
td_ns::BinaryTypeMapResultEntry<T1,
126+
std::uint8_t,
127+
T2,
128+
std::uint8_t,
129+
bool>,
130+
td_ns::BinaryTypeMapResultEntry<T1,
131+
std::int8_t,
132+
T2,
133+
std::int8_t,
134+
bool>,
135+
td_ns::BinaryTypeMapResultEntry<T1,
136+
std::uint16_t,
137+
T2,
138+
std::uint16_t,
139+
bool>,
140+
td_ns::BinaryTypeMapResultEntry<T1,
141+
std::int16_t,
142+
T2,
143+
std::int16_t,
144+
bool>,
145+
td_ns::BinaryTypeMapResultEntry<T1,
146+
std::uint32_t,
147+
T2,
148+
std::uint32_t,
149+
bool>,
150+
td_ns::BinaryTypeMapResultEntry<T1,
151+
std::int32_t,
152+
T2,
153+
std::int32_t,
154+
bool>,
155+
td_ns::BinaryTypeMapResultEntry<T1,
156+
std::uint64_t,
157+
T2,
158+
std::uint64_t,
159+
bool>,
160+
td_ns::BinaryTypeMapResultEntry<T1,
161+
std::int64_t,
162+
T2,
163+
std::int64_t,
164+
bool>,
165+
td_ns::
166+
BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
167+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
168+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
169+
td_ns::BinaryTypeMapResultEntry<T1,
170+
std::complex<float>,
171+
T2,
172+
std::complex<float>,
173+
bool>,
174+
td_ns::BinaryTypeMapResultEntry<T1,
175+
std::complex<double>,
176+
T2,
177+
std::complex<double>,
178+
bool>,
179+
td_ns::BinaryTypeMapResultEntry<T1,
180+
float,
181+
T2,
182+
std::complex<float>,
183+
bool>,
184+
td_ns::BinaryTypeMapResultEntry<T1,
185+
std::complex<float>,
186+
T2,
187+
float,
188+
bool>,
189+
td_ns::DefaultResultEntry<void>>::result_type;
190+
};
191+
192+
template <typename argT1,
193+
typename argT2,
194+
typename resT,
195+
unsigned int vec_sz,
196+
unsigned int n_vecs>
197+
class logical_and_contig_kernel;
198+
199+
template <typename argTy1, typename argTy2>
200+
sycl::event
201+
logical_and_contig_impl(sycl::queue exec_q,
202+
size_t nelems,
203+
const char *arg1_p,
204+
py::ssize_t arg1_offset,
205+
const char *arg2_p,
206+
py::ssize_t arg2_offset,
207+
char *res_p,
208+
py::ssize_t res_offset,
209+
const std::vector<sycl::event> &depends = {})
210+
{
211+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
212+
cgh.depends_on(depends);
213+
214+
size_t lws = 64;
215+
constexpr unsigned int vec_sz = 4;
216+
constexpr unsigned int n_vecs = 2;
217+
const size_t n_groups = ((nelems + lws * n_vecs * vec_sz - 1) /
218+
(lws * n_vecs * vec_sz));
219+
const auto gws_range = sycl::range<1>(n_groups * lws);
220+
const auto lws_range = sycl::range<1>(lws);
221+
222+
using resTy =
223+
typename LogicalAndOutputType<argTy1, argTy2>::value_type;
224+
225+
const argTy1 *arg1_tp =
226+
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
227+
const argTy2 *arg2_tp =
228+
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
229+
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
230+
231+
cgh.parallel_for<logical_and_contig_kernel<argTy1, argTy2, resTy,
232+
vec_sz, n_vecs>>(
233+
sycl::nd_range<1>(gws_range, lws_range),
234+
LogicalAndContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
235+
arg1_tp, arg2_tp, res_tp, nelems));
236+
});
237+
return comp_ev;
238+
}
239+
240+
template <typename fnT, typename T1, typename T2>
241+
struct LogicalAndContigFactory
242+
{
243+
fnT get()
244+
{
245+
if constexpr (std::is_same_v<
246+
typename LogicalAndOutputType<T1, T2>::value_type,
247+
void>)
248+
{
249+
fnT fn = nullptr;
250+
return fn;
251+
}
252+
else {
253+
fnT fn = logical_and_contig_impl<T1, T2>;
254+
return fn;
255+
}
256+
}
257+
};
258+
259+
template <typename fnT, typename T1, typename T2>
260+
struct LogicalAndTypeMapFactory
261+
{
262+
/*! @brief get typeid for output type of operator()>(x, y), always bool
263+
*/
264+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
265+
{
266+
using rT = typename LogicalAndOutputType<T1, T2>::value_type;
267+
return td_ns::GetTypeid<rT>{}.get();
268+
}
269+
};
270+
271+
template <typename T1, typename T2, typename resT, typename IndexerT>
272+
class logical_and_strided_kernel;
273+
274+
template <typename argTy1, typename argTy2>
275+
sycl::event
276+
logical_and_strided_impl(sycl::queue exec_q,
277+
size_t nelems,
278+
int nd,
279+
const py::ssize_t *shape_and_strides,
280+
const char *arg1_p,
281+
py::ssize_t arg1_offset,
282+
const char *arg2_p,
283+
py::ssize_t arg2_offset,
284+
char *res_p,
285+
py::ssize_t res_offset,
286+
const std::vector<sycl::event> &depends,
287+
const std::vector<sycl::event> &additional_depends)
288+
{
289+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
290+
cgh.depends_on(depends);
291+
cgh.depends_on(additional_depends);
292+
293+
using resTy =
294+
typename LogicalAndOutputType<argTy1, argTy2>::value_type;
295+
296+
using IndexerT = typename dpctl::tensor::offset_utils::
297+
ThreeOffsets_StridedIndexer;
298+
299+
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
300+
shape_and_strides};
301+
302+
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
303+
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
304+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
305+
306+
cgh.parallel_for<
307+
logical_and_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
308+
{nelems},
309+
LogicalAndStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
310+
arg1_tp, arg2_tp, res_tp, indexer));
311+
});
312+
return comp_ev;
313+
}
314+
315+
template <typename fnT, typename T1, typename T2>
316+
struct LogicalAndStridedFactory
317+
{
318+
fnT get()
319+
{
320+
if constexpr (std::is_same_v<
321+
typename LogicalAndOutputType<T1, T2>::value_type,
322+
void>)
323+
{
324+
fnT fn = nullptr;
325+
return fn;
326+
}
327+
else {
328+
fnT fn = logical_and_strided_impl<T1, T2>;
329+
return fn;
330+
}
331+
}
332+
};
333+
334+
} // namespace logical_and
335+
} // namespace kernels
336+
} // namespace tensor
337+
} // namespace dpctl

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp

Whitespace-only changes.

0 commit comments

Comments
 (0)