Skip to content

Commit 34a78db

Browse files
committed
Removed unnecessary output types from logical and comparison elementwise functions
Logical and comparison functions were constructing kernels for mixed 32-bit floats and 64-bit complex numbers. To prevent binary size inflation, these have been removed. Logical and comparison operations now also elementwise_common templates Corrected various typos
1 parent 47f4bc9 commit 34a78db

File tree

13 files changed

+108
-554
lines changed

13 files changed

+108
-554
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 7 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -135,28 +135,9 @@ sycl::event expm1_contig_impl(sycl::queue exec_q,
135135
char *res_p,
136136
const std::vector<sycl::event> &depends = {})
137137
{
138-
sycl::event expm1_ev = exec_q.submit([&](sycl::handler &cgh) {
139-
cgh.depends_on(depends);
140-
constexpr size_t lws = 64;
141-
constexpr unsigned int vec_sz = 4;
142-
constexpr unsigned int n_vecs = 2;
143-
static_assert(lws % vec_sz == 0);
144-
auto gws_range = sycl::range<1>(
145-
((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
146-
lws);
147-
auto lws_range = sycl::range<1>(lws);
148-
149-
using resTy = typename Expm1OutputType<argTy>::value_type;
150-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
151-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
152-
153-
cgh.parallel_for<
154-
class expm1_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
155-
sycl::nd_range<1>(gws_range, lws_range),
156-
Expm1ContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
157-
nelems));
158-
});
159-
return expm1_ev;
138+
return elementwise_common::unary_contig_impl<
139+
argTy, Expm1OutputType, Expm1ContigFunctor, expm1_contig_kernel>(
140+
exec_q, nelems, arg_p, res_p, depends);
160141
}
161142

162143
template <typename fnT, typename T> struct Expm1ContigFactory
@@ -213,26 +194,10 @@ expm1_strided_impl(sycl::queue exec_q,
213194
const std::vector<sycl::event> &depends,
214195
const std::vector<sycl::event> &additional_depends)
215196
{
216-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
217-
cgh.depends_on(depends);
218-
cgh.depends_on(additional_depends);
219-
220-
using resTy = typename Expm1OutputType<argTy>::value_type;
221-
using IndexerT =
222-
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
223-
224-
IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);
225-
226-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
227-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
228-
229-
sycl::range<1> gRange{nelems};
230-
231-
cgh.parallel_for<expm1_strided_kernel<argTy, resTy, IndexerT>>(
232-
gRange, Expm1StridedFunctor<argTy, resTy, IndexerT>(
233-
arg_tp, res_tp, arg_res_indexer));
234-
});
235-
return comp_ev;
197+
return elementwise_common::unary_strided_impl<
198+
argTy, Expm1OutputType, Expm1StridedFunctor, expm1_strided_kernel>(
199+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
200+
res_offset, depends, additional_depends);
236201
}
237202

238203
template <typename fnT, typename T> struct Expm1StridedFactory

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp

Lines changed: 11 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,8 @@ template <typename argT1, typename argT2, typename resT> struct GreaterFunctor
6363

6464
resT operator()(const argT1 &in1, const argT2 &in2)
6565
{
66-
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
67-
std::is_same_v<argT2, float>)
68-
{
69-
float real1 = std::real(in1);
70-
return (real1 == in2) ? (std::imag(in1) > 0.0f) : real1 > in2;
71-
}
72-
else if constexpr (std::is_same_v<argT1, float> &&
73-
std::is_same_v<argT2, std::complex<float>>)
74-
{
75-
float real2 = std::real(in2);
76-
return (in1 == real2) ? (0.0f > std::imag(in2)) : in1 > real2;
77-
}
78-
else if constexpr (tu_ns::is_complex<argT1>::value ||
79-
tu_ns::is_complex<argT2>::value)
66+
if constexpr (tu_ns::is_complex<argT1>::value ||
67+
tu_ns::is_complex<argT2>::value)
8068
{
8169
static_assert(std::is_same_v<argT1, argT2>);
8270
using realT = typename argT1::value_type;
@@ -174,10 +162,6 @@ template <typename T1, typename T2> struct GreaterOutputType
174162
T2,
175163
std::complex<double>,
176164
bool>,
177-
td_ns::
178-
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
179-
td_ns::
180-
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
181165
td_ns::DefaultResultEntry<void>>::result_type;
182166
};
183167

@@ -199,32 +183,10 @@ sycl::event greater_contig_impl(sycl::queue exec_q,
199183
py::ssize_t res_offset,
200184
const std::vector<sycl::event> &depends = {})
201185
{
202-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
203-
cgh.depends_on(depends);
204-
205-
size_t lws = 64;
206-
constexpr unsigned int vec_sz = 4;
207-
constexpr unsigned int n_vecs = 2;
208-
const size_t n_groups =
209-
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
210-
const auto gws_range = sycl::range<1>(n_groups * lws);
211-
const auto lws_range = sycl::range<1>(lws);
212-
213-
using resTy = typename GreaterOutputType<argTy1, argTy2>::value_type;
214-
215-
const argTy1 *arg1_tp =
216-
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
217-
const argTy2 *arg2_tp =
218-
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
219-
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
220-
221-
cgh.parallel_for<
222-
greater_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
223-
sycl::nd_range<1>(gws_range, lws_range),
224-
GreaterContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
225-
arg1_tp, arg2_tp, res_tp, nelems));
226-
});
227-
return comp_ev;
186+
return elementwise_common::binary_contig_impl<
187+
argTy1, argTy2, GreaterOutputType, GreaterContigFunctor,
188+
greater_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
189+
arg2_offset, res_p, res_offset, depends);
228190
}
229191

230192
template <typename fnT, typename T1, typename T2> struct GreaterContigFactory
@@ -272,28 +234,11 @@ greater_strided_impl(sycl::queue exec_q,
272234
const std::vector<sycl::event> &depends,
273235
const std::vector<sycl::event> &additional_depends)
274236
{
275-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
276-
cgh.depends_on(depends);
277-
cgh.depends_on(additional_depends);
278-
279-
using resTy = typename GreaterOutputType<argTy1, argTy2>::value_type;
280-
281-
using IndexerT =
282-
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
283-
284-
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
285-
shape_and_strides};
286-
287-
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
288-
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
289-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
290-
291-
cgh.parallel_for<
292-
greater_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
293-
{nelems}, GreaterStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
294-
arg1_tp, arg2_tp, res_tp, indexer));
295-
});
296-
return comp_ev;
237+
return elementwise_common::binary_strided_impl<
238+
argTy1, argTy2, GreaterOutputType, GreaterStridedFunctor,
239+
greater_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
240+
arg1_offset, arg2_p, arg2_offset, res_p,
241+
res_offset, depends, additional_depends);
297242
}
298243

299244
template <typename fnT, typename T1, typename T2> struct GreaterStridedFactory

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp

Lines changed: 12 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,8 @@ struct GreaterEqualFunctor
6464

6565
resT operator()(const argT1 &in1, const argT2 &in2)
6666
{
67-
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
68-
std::is_same_v<argT2, float>)
69-
{
70-
float real1 = std::real(in1);
71-
return (real1 == in2) ? (std::imag(in1) >= 0.0f) : real1 >= in2;
72-
}
73-
else if constexpr (std::is_same_v<argT1, float> &&
74-
std::is_same_v<argT2, std::complex<float>>)
75-
{
76-
float real2 = std::real(in2);
77-
return (in1 == real2) ? (0.0f >= std::imag(in2)) : in1 >= real2;
78-
}
79-
else if constexpr (tu_ns::is_complex<argT1>::value ||
80-
tu_ns::is_complex<argT2>::value)
67+
if constexpr (tu_ns::is_complex<argT1>::value ||
68+
tu_ns::is_complex<argT2>::value)
8169
{
8270
static_assert(std::is_same_v<argT1, argT2>);
8371
using realT = typename argT1::value_type;
@@ -175,10 +163,6 @@ template <typename T1, typename T2> struct GreaterEqualOutputType
175163
T2,
176164
std::complex<double>,
177165
bool>,
178-
td_ns::
179-
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
180-
td_ns::
181-
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
182166
td_ns::DefaultResultEntry<void>>::result_type;
183167
};
184168

@@ -201,33 +185,11 @@ greater_equal_contig_impl(sycl::queue exec_q,
201185
py::ssize_t res_offset,
202186
const std::vector<sycl::event> &depends = {})
203187
{
204-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
205-
cgh.depends_on(depends);
206-
207-
size_t lws = 64;
208-
constexpr unsigned int vec_sz = 4;
209-
constexpr unsigned int n_vecs = 2;
210-
const size_t n_groups =
211-
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
212-
const auto gws_range = sycl::range<1>(n_groups * lws);
213-
const auto lws_range = sycl::range<1>(lws);
214-
215-
using resTy =
216-
typename GreaterEqualOutputType<argTy1, argTy2>::value_type;
217-
218-
const argTy1 *arg1_tp =
219-
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
220-
const argTy2 *arg2_tp =
221-
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
222-
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
223-
224-
cgh.parallel_for<
225-
greater_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
226-
sycl::nd_range<1>(gws_range, lws_range),
227-
GreaterEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
228-
arg1_tp, arg2_tp, res_tp, nelems));
229-
});
230-
return comp_ev;
188+
return elementwise_common::binary_contig_impl<
189+
argTy1, argTy2, GreaterEqualOutputType, GreaterEqualContigFunctor,
190+
greater_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset,
191+
arg2_p, arg2_offset, res_p, res_offset,
192+
depends);
231193
}
232194

233195
template <typename fnT, typename T1, typename T2>
@@ -278,30 +240,11 @@ greater_equal_strided_impl(sycl::queue exec_q,
278240
const std::vector<sycl::event> &depends,
279241
const std::vector<sycl::event> &additional_depends)
280242
{
281-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
282-
cgh.depends_on(depends);
283-
cgh.depends_on(additional_depends);
284-
285-
using resTy =
286-
typename GreaterEqualOutputType<argTy1, argTy2>::value_type;
287-
288-
using IndexerT =
289-
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
290-
291-
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
292-
shape_and_strides};
293-
294-
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
295-
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
296-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
297-
298-
cgh.parallel_for<
299-
greater_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
300-
{nelems},
301-
GreaterEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
302-
arg1_tp, arg2_tp, res_tp, indexer));
303-
});
304-
return comp_ev;
243+
return elementwise_common::binary_strided_impl<
244+
argTy1, argTy2, GreaterEqualOutputType, GreaterEqualStridedFunctor,
245+
greater_equal_strided_kernel>(
246+
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
247+
arg2_offset, res_p, res_offset, depends, additional_depends);
305248
}
306249

307250
template <typename fnT, typename T1, typename T2>

dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ template <typename fnT, typename T1, typename T2> struct HypotTypeMapFactory
168168
};
169169

170170
template <typename T1, typename T2, typename resT, typename IndexerT>
171-
class hypot_strided_strided_kernel;
171+
class hypot_strided_kernel;
172172

173173
template <typename argTy1, typename argTy2>
174174
sycl::event
@@ -187,9 +187,9 @@ hypot_strided_impl(sycl::queue exec_q,
187187
{
188188
return elementwise_common::binary_strided_impl<
189189
argTy1, argTy2, HypotOutputType, HypotStridedFunctor,
190-
hypot_strided_strided_kernel>(
191-
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
192-
arg2_offset, res_p, res_offset, depends, additional_depends);
190+
hypot_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
191+
arg1_offset, arg2_p, arg2_offset, res_p,
192+
res_offset, depends, additional_depends);
193193
}
194194

195195
template <typename fnT, typename T1, typename T2> struct HypotStridedFactory

0 commit comments

Comments
 (0)