Skip to content

Commit ff245a0

Browse files
committed
Added tests and fixed typo in logical_not
1 parent d3fc623 commit ff245a0

File tree

6 files changed

+1089
-24
lines changed

6 files changed

+1089
-24
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@
602602

603603
# U24: ==== LOGICAL_NOT (x)
604604
_logical_not_docstring = """
605-
log(x, out=None, order='K')
605+
logical_not(x, out=None, order='K')
606606
Computes the logical NOT for each element `x_i` of input array `x`.
607607
Args:
608608
x (usm_ndarray):

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ namespace py = pybind11;
4848
namespace td_ns = dpctl::tensor::type_dispatch;
4949
namespace tu_ns = dpctl::tensor::type_utils;
5050

51-
template <typename argT, typename resT> struct LogicalNotFunctor
51+
template <typename argT, typename resT>
52+
struct LogicalNotFunctor
5253
{
5354
static_assert(std::is_same_v<resT, bool>);
5455

55-
using is_constant = typename std::disjunction<std::is_same<argT, bool>,
56-
std::is_integral<argT>>;
57-
static constexpr resT constant_value = false;
56+
using is_constant = typename std::false_type;
57+
// constexpr resT constant_value = resT{};
5858
using supports_vec = typename std::false_type;
5959
using supports_sg_loadstore = typename std::negation<
6060
std::disjunction<tu_ns::is_complex<resT>, tu_ns::is_complex<argT>>>;
@@ -84,7 +84,8 @@ using LogicalNotStridedFunctor =
8484
IndexerT,
8585
LogicalNotFunctor<argTy, resTy>>;
8686

87-
template <typename argTy> struct LogicalNotOutputType
87+
template <typename argTy>
88+
struct LogicalNotOutputType
8889
{
8990
using value_type = bool;
9091
};
@@ -94,19 +95,20 @@ class logical_not_contig_kernel;
9495

9596
template <typename argTy>
9697
sycl::event
97-
logical_not_contig_impl(sycl::queue exec_q,
98-
size_t nelems,
99-
const char *arg_p,
100-
char *res_p,
101-
const std::vector<sycl::event> &depends = {})
98+
logical_not_contig_impl(sycl::queue exec_q,
99+
size_t nelems,
100+
const char *arg_p,
101+
char *res_p,
102+
const std::vector<sycl::event> &depends = {})
102103
{
103104
return elementwise_common::unary_contig_impl<argTy, LogicalNotOutputType,
104105
LogicalNotContigFunctor,
105106
logical_not_contig_kernel>(
106107
exec_q, nelems, arg_p, res_p, depends);
107108
}
108109

109-
template <typename fnT, typename T> struct LogicalNotContigFactory
110+
template <typename fnT, typename T>
111+
struct LogicalNotContigFactory
110112
{
111113
fnT get()
112114
{
@@ -115,7 +117,8 @@ template <typename fnT, typename T> struct LogicalNotContigFactory
115117
}
116118
};
117119

118-
template <typename fnT, typename T> struct LogicalNotTypeMapFactory
120+
template <typename fnT, typename T>
121+
struct LogicalNotTypeMapFactory
119122
{
120123
/*! @brief get typeid for output type of sycl::logical_not(T x) */
121124
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
@@ -130,16 +133,16 @@ class logical_not_strided_kernel;
130133

131134
template <typename argTy>
132135
sycl::event
133-
logical_not_strided_impl(sycl::queue exec_q,
134-
size_t nelems,
135-
int nd,
136-
const py::ssize_t *shape_and_strides,
137-
const char *arg_p,
138-
py::ssize_t arg_offset,
139-
char *res_p,
140-
py::ssize_t res_offset,
141-
const std::vector<sycl::event> &depends,
142-
const std::vector<sycl::event> &additional_depends)
136+
logical_not_strided_impl(sycl::queue exec_q,
137+
size_t nelems,
138+
int nd,
139+
const py::ssize_t *shape_and_strides,
140+
const char *arg_p,
141+
py::ssize_t arg_offset,
142+
char *res_p,
143+
py::ssize_t res_offset,
144+
const std::vector<sycl::event> &depends,
145+
const std::vector<sycl::event> &additional_depends)
143146
{
144147
return elementwise_common::unary_strided_impl<argTy, LogicalNotOutputType,
145148
LogicalNotStridedFunctor,
@@ -148,7 +151,8 @@ logical_not_strided_impl(sycl::queue exec_q,
148151
res_offset, depends, additional_depends);
149152
}
150153

151-
template <typename fnT, typename T> struct LogicalNotStridedFactory
154+
template <typename fnT, typename T>
155+
struct LogicalNotStridedFactory
152156
{
153157
fnT get()
154158
{

0 commit comments

Comments
 (0)