Skip to content

Removes redundant output types from logical and comparison functions #1309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,28 +135,9 @@ sycl::event expm1_contig_impl(sycl::queue exec_q,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
sycl::event expm1_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
constexpr size_t lws = 64;
constexpr unsigned int vec_sz = 4;
constexpr unsigned int n_vecs = 2;
static_assert(lws % vec_sz == 0);
auto gws_range = sycl::range<1>(
((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
lws);
auto lws_range = sycl::range<1>(lws);

using resTy = typename Expm1OutputType<argTy>::value_type;
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
class expm1_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
sycl::nd_range<1>(gws_range, lws_range),
Expm1ContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
nelems));
});
return expm1_ev;
return elementwise_common::unary_contig_impl<
argTy, Expm1OutputType, Expm1ContigFunctor, expm1_contig_kernel>(
exec_q, nelems, arg_p, res_p, depends);
}

template <typename fnT, typename T> struct Expm1ContigFactory
Expand Down Expand Up @@ -213,26 +194,10 @@ expm1_strided_impl(sycl::queue exec_q,
const std::vector<sycl::event> &depends,
const std::vector<sycl::event> &additional_depends)
{
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.depends_on(additional_depends);

using resTy = typename Expm1OutputType<argTy>::value_type;
using IndexerT =
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;

IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);

const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

sycl::range<1> gRange{nelems};

cgh.parallel_for<expm1_strided_kernel<argTy, resTy, IndexerT>>(
gRange, Expm1StridedFunctor<argTy, resTy, IndexerT>(
arg_tp, res_tp, arg_res_indexer));
});
return comp_ev;
return elementwise_common::unary_strided_impl<
argTy, Expm1OutputType, Expm1StridedFunctor, expm1_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T> struct Expm1StridedFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,8 @@ template <typename argT1, typename argT2, typename resT> struct GreaterFunctor

resT operator()(const argT1 &in1, const argT2 &in2)
{
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
std::is_same_v<argT2, float>)
{
float real1 = std::real(in1);
return (real1 == in2) ? (std::imag(in1) > 0.0f) : real1 > in2;
}
else if constexpr (std::is_same_v<argT1, float> &&
std::is_same_v<argT2, std::complex<float>>)
{
float real2 = std::real(in2);
return (in1 == real2) ? (0.0f > std::imag(in2)) : in1 > real2;
}
else if constexpr (tu_ns::is_complex<argT1>::value ||
tu_ns::is_complex<argT2>::value)
if constexpr (tu_ns::is_complex<argT1>::value ||
tu_ns::is_complex<argT2>::value)
{
static_assert(std::is_same_v<argT1, argT2>);
using realT = typename argT1::value_type;
Expand Down Expand Up @@ -174,10 +162,6 @@ template <typename T1, typename T2> struct GreaterOutputType
T2,
std::complex<double>,
bool>,
td_ns::
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -199,32 +183,10 @@ sycl::event greater_contig_impl(sycl::queue exec_q,
py::ssize_t res_offset,
const std::vector<sycl::event> &depends = {})
{
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

size_t lws = 64;
constexpr unsigned int vec_sz = 4;
constexpr unsigned int n_vecs = 2;
const size_t n_groups =
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
const auto gws_range = sycl::range<1>(n_groups * lws);
const auto lws_range = sycl::range<1>(lws);

using resTy = typename GreaterOutputType<argTy1, argTy2>::value_type;

const argTy1 *arg1_tp =
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
const argTy2 *arg2_tp =
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;

cgh.parallel_for<
greater_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
sycl::nd_range<1>(gws_range, lws_range),
GreaterContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
arg1_tp, arg2_tp, res_tp, nelems));
});
return comp_ev;
return elementwise_common::binary_contig_impl<
argTy1, argTy2, GreaterOutputType, GreaterContigFunctor,
greater_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends);
}

template <typename fnT, typename T1, typename T2> struct GreaterContigFactory
Expand Down Expand Up @@ -272,28 +234,11 @@ greater_strided_impl(sycl::queue exec_q,
const std::vector<sycl::event> &depends,
const std::vector<sycl::event> &additional_depends)
{
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.depends_on(additional_depends);

using resTy = typename GreaterOutputType<argTy1, argTy2>::value_type;

using IndexerT =
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;

IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
shape_and_strides};

const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
greater_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems}, GreaterStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});
return comp_ev;
return elementwise_common::binary_strided_impl<
argTy1, argTy2, GreaterOutputType, GreaterStridedFunctor,
greater_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
arg1_offset, arg2_p, arg2_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct GreaterStridedFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,8 @@ struct GreaterEqualFunctor

resT operator()(const argT1 &in1, const argT2 &in2)
{
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
std::is_same_v<argT2, float>)
{
float real1 = std::real(in1);
return (real1 == in2) ? (std::imag(in1) >= 0.0f) : real1 >= in2;
}
else if constexpr (std::is_same_v<argT1, float> &&
std::is_same_v<argT2, std::complex<float>>)
{
float real2 = std::real(in2);
return (in1 == real2) ? (0.0f >= std::imag(in2)) : in1 >= real2;
}
else if constexpr (tu_ns::is_complex<argT1>::value ||
tu_ns::is_complex<argT2>::value)
if constexpr (tu_ns::is_complex<argT1>::value ||
tu_ns::is_complex<argT2>::value)
{
static_assert(std::is_same_v<argT1, argT2>);
using realT = typename argT1::value_type;
Expand Down Expand Up @@ -175,10 +163,6 @@ template <typename T1, typename T2> struct GreaterEqualOutputType
T2,
std::complex<double>,
bool>,
td_ns::
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
td_ns::
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
td_ns::DefaultResultEntry<void>>::result_type;
};

Expand All @@ -201,33 +185,11 @@ greater_equal_contig_impl(sycl::queue exec_q,
py::ssize_t res_offset,
const std::vector<sycl::event> &depends = {})
{
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

size_t lws = 64;
constexpr unsigned int vec_sz = 4;
constexpr unsigned int n_vecs = 2;
const size_t n_groups =
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
const auto gws_range = sycl::range<1>(n_groups * lws);
const auto lws_range = sycl::range<1>(lws);

using resTy =
typename GreaterEqualOutputType<argTy1, argTy2>::value_type;

const argTy1 *arg1_tp =
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
const argTy2 *arg2_tp =
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;

cgh.parallel_for<
greater_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
sycl::nd_range<1>(gws_range, lws_range),
GreaterEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
arg1_tp, arg2_tp, res_tp, nelems));
});
return comp_ev;
return elementwise_common::binary_contig_impl<
argTy1, argTy2, GreaterEqualOutputType, GreaterEqualContigFunctor,
greater_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset,
arg2_p, arg2_offset, res_p, res_offset,
depends);
}

template <typename fnT, typename T1, typename T2>
Expand Down Expand Up @@ -278,30 +240,11 @@ greater_equal_strided_impl(sycl::queue exec_q,
const std::vector<sycl::event> &depends,
const std::vector<sycl::event> &additional_depends)
{
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.depends_on(additional_depends);

using resTy =
typename GreaterEqualOutputType<argTy1, argTy2>::value_type;

using IndexerT =
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;

IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
shape_and_strides};

const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
greater_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems},
GreaterEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});
return comp_ev;
return elementwise_common::binary_strided_impl<
argTy1, argTy2, GreaterEqualOutputType, GreaterEqualStridedFunctor,
greater_equal_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ template <typename fnT, typename T1, typename T2> struct HypotTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class hypot_strided_strided_kernel;
class hypot_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand All @@ -187,9 +187,9 @@ hypot_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, HypotOutputType, HypotStridedFunctor,
hypot_strided_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
hypot_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
arg1_offset, arg2_p, arg2_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct HypotStridedFactory
Expand Down
Loading