Skip to content

Overloads for array-scalar and scalar-array broadcast element-wise functions #1815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,9 @@ def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):


def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
# if the kind of result is different from
# the kind of input, use the default data
# we use default dtype for the resulting kind.
# This guarantees alignment of reciprocal and
# divide output types.
# if the kind of result is different from the kind of input, we use the
# default floating-point dtype for the resulting kind. This guarantees
# alignment of reciprocal and divide output types.
if buf_dt.kind != arg_dtype.kind:
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
if res_dt == default_dt:
Expand Down
272 changes: 264 additions & 8 deletions dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,42 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
tmp);
}
}

/*! @brief Adds the single value `in2` to the vector `in1` and returns the
 *  sum as a vector whose element type is `resT`.
 *
 *  NOTE(review): the element type of the intermediate sum is whatever
 *  `sycl::vec<argT1, vec_sz> + argT2` yields; presumably argT1 and argT2
 *  agree for every kernel that takes this path -- confirm.
 */
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
                                   const argT2 &in2) const
{
    auto sum = in1 + in2;
    using sumElemT = typename decltype(sum)::element_type;

    if constexpr (!std::is_same_v<resT, sumElemT>) {
        // The sum's element type is not resT: convert it with vec_cast.
        using dpctl::tensor::type_utils::vec_cast;

        return vec_cast<resT, sumElemT, vec_sz>(sum);
    }
    else {
        // The sum already has the requested element type.
        return sum;
    }
}

/*! @brief Adds the vector `in2` to the single value `in1` and returns the
 *  sum as a vector whose element type is `resT`.
 *
 *  NOTE(review): the element type of the intermediate sum is whatever
 *  `argT1 + sycl::vec<argT2, vec_sz>` yields; presumably argT1 and argT2
 *  agree for every kernel that takes this path -- confirm.
 */
template <int vec_sz>
sycl::vec<resT, vec_sz>
operator()(const argT1 &in1, const sycl::vec<argT2, vec_sz> &in2) const
{
    auto sum = in1 + in2;
    using sumElemT = typename decltype(sum)::element_type;

    if constexpr (!std::is_same_v<resT, sumElemT>) {
        // The sum's element type is not resT: convert it with vec_cast.
        using dpctl::tensor::type_utils::vec_cast;

        return vec_cast<resT, sumElemT, vec_sz>(sum);
    }
    else {
        // The sum already has the requested element type.
        return sum;
    }
}
};

template <typename argT1,
Expand Down Expand Up @@ -393,6 +429,126 @@ struct AddContigRowContigMatrixBroadcastFactory
}
};

/*! @brief Alias of elementwise_common::BinaryScalarContigArrayFunctor
 *  parameterized with AddFunctor<T1, T2, outT>.
 *
 *  The trailing parameters default to vsz = 4, nvecs = 2 and
 *  sg_loadstore = true and are forwarded unchanged.
 */
template <typename T1,
          typename T2,
          typename outT,
          unsigned int vsz = 4,
          unsigned int nvecs = 2,
          bool sg_loadstore = true>
using AddScalarContigArrayFunctor =
    elementwise_common::BinaryScalarContigArrayFunctor<
        T1, T2, outT, AddFunctor<T1, T2, outT>, vsz, nvecs, sg_loadstore>;

/*! @brief Alias of elementwise_common::BinaryContigArrayScalarFunctor
 *  parameterized with AddFunctor<T1, T2, outT>.
 *
 *  The trailing parameters default to vsz = 4, nvecs = 2 and
 *  sg_loadstore = true and are forwarded unchanged.
 */
template <typename T1,
          typename T2,
          typename outT,
          unsigned int vsz = 4,
          unsigned int nvecs = 2,
          bool sg_loadstore = true>
using AddContigArrayScalarFunctor =
    elementwise_common::BinaryContigArrayScalarFunctor<
        T1, T2, outT, AddFunctor<T1, T2, outT>, vsz, nvecs, sg_loadstore>;

/*! @brief Declared-only class template handed to
 *  elementwise_common::binary_scalar_contig_array_impl by
 *  add_scalar_contig_array_impl (presumably as the kernel name type --
 *  confirm against elementwise_common).
 */
template <typename T1,
          typename T2,
          typename outT,
          unsigned int vsz,
          unsigned int nvecs>
class add_scalar_contig_array_kernel;

/*! @brief Submits the scalar/contiguous-array flavour of element-wise add.
 *
 *  Thin wrapper: every argument is passed through unchanged to
 *  elementwise_common::binary_scalar_contig_array_impl, instantiated with
 *  AddOutputType, AddScalarContigArrayFunctor and
 *  add_scalar_contig_array_kernel.
 *
 *  NOTE(review): judging by the name, arg1 is the scalar operand and arg2
 *  the contiguous array -- confirm against the elementwise_common helper.
 *
 *  @return the event returned by the elementwise_common helper
 */
template <typename argTy1, typename argTy2>
sycl::event
add_scalar_contig_array_impl(sycl::queue &exec_q,
                             size_t nelems,
                             const char *arg1_p,
                             ssize_t arg1_offset,
                             const char *arg2_p,
                             ssize_t arg2_offset,
                             char *res_p,
                             ssize_t res_offset,
                             const std::vector<sycl::event> &depends = {})
{
    sycl::event add_ev = elementwise_common::binary_scalar_contig_array_impl<
        argTy1, argTy2, AddOutputType, AddScalarContigArrayFunctor,
        add_scalar_contig_array_kernel>(exec_q, nelems, arg1_p, arg1_offset,
                                        arg2_p, arg2_offset, res_p,
                                        res_offset, depends);

    return add_ev;
}

/*! @brief Declared-only class template handed to
 *  elementwise_common::binary_contig_array_scalar_impl by
 *  add_contig_array_scalar_impl (presumably as the kernel name type --
 *  confirm against elementwise_common).
 */
template <typename T1,
          typename T2,
          typename outT,
          unsigned int vsz,
          unsigned int nvecs>
class add_contig_array_scalar_kernel;

/*! @brief Submits the contiguous-array/scalar flavour of element-wise add.
 *
 *  Thin wrapper: every argument is passed through unchanged to
 *  elementwise_common::binary_contig_array_scalar_impl, instantiated with
 *  AddOutputType, AddContigArrayScalarFunctor and
 *  add_contig_array_scalar_kernel.
 *
 *  NOTE(review): judging by the name, arg1 is the contiguous array and arg2
 *  the scalar operand -- confirm against the elementwise_common helper.
 *
 *  @return the event returned by the elementwise_common helper
 */
template <typename argTy1, typename argTy2>
sycl::event
add_contig_array_scalar_impl(sycl::queue &exec_q,
                             size_t nelems,
                             const char *arg1_p,
                             ssize_t arg1_offset,
                             const char *arg2_p,
                             ssize_t arg2_offset,
                             char *res_p,
                             ssize_t res_offset,
                             const std::vector<sycl::event> &depends = {})
{
    sycl::event add_ev = elementwise_common::binary_contig_array_scalar_impl<
        argTy1, argTy2, AddOutputType, AddContigArrayScalarFunctor,
        add_contig_array_scalar_kernel>(exec_q, nelems, arg1_p, arg1_offset,
                                        arg2_p, arg2_offset, res_p,
                                        res_offset, depends);

    return add_ev;
}

template <typename fnT, typename T1, typename T2>
struct AddScalarContigArrayFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = add_scalar_contig_array_impl<T1, T2>;
return fn;
}
}
};

template <typename fnT, typename T1, typename T2>
struct AddContigArrayScalarFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = add_contig_array_scalar_impl<T1, T2>;
return fn;
}
}
};

template <typename argT, typename resT> struct AddInplaceFunctor
{

Expand All @@ -409,6 +565,12 @@ template <typename argT, typename resT> struct AddInplaceFunctor
{
res += in;
}

/*! @brief In-place add of the single value `in` to the vector `res`,
 *  performed with the vector's compound assignment `res += in`.
 *
 *  @param res vector that is updated in place
 *  @param in  value added to `res`
 */
template <int vec_sz>
void operator()(sycl::vec<resT, vec_sz> &res, const argT &in)
{
res += in;
}
};

template <typename argT,
Expand Down Expand Up @@ -438,6 +600,53 @@ template <typename argT,
unsigned int n_vecs>
class add_inplace_contig_kernel;

/* @brief Types supported by in-place add */
template <typename argTy, typename resTy> struct AddInplaceTypePairSupport
{
/* value if true a kernel for <argTy, resTy> must be instantiated */
static constexpr bool is_defined = std::disjunction< // disjunction is
// C++17 feature,
// supported by
// DPC++ input bool
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
td_ns::TypePairDefinedEntry<argTy, sycl::half, resTy, sycl::half>,
td_ns::TypePairDefinedEntry<argTy, float, resTy, float>,
td_ns::TypePairDefinedEntry<argTy, double, resTy, double>,
td_ns::TypePairDefinedEntry<argTy,
std::complex<float>,
resTy,
std::complex<float>>,
td_ns::TypePairDefinedEntry<argTy,
std::complex<double>,
resTy,
std::complex<double>>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

template <typename fnT, typename argT, typename resT>
struct AddInplaceTypeMapFactory
{
/*! @brief get typeid for output type of x += y */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
if constexpr (AddInplaceTypePairSupport<argT, resT>::is_defined) {
return td_ns::GetTypeid<resT>{}.get();
}
else {
return td_ns::GetTypeid<void>{}.get();
}
}
};

template <typename argTy, typename resTy>
sycl::event
add_inplace_contig_impl(sycl::queue &exec_q,
Expand All @@ -457,9 +666,7 @@ template <typename fnT, typename T1, typename T2> struct AddInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -497,9 +704,7 @@ struct AddInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -544,8 +749,7 @@ struct AddInplaceRowMatrixBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (!std::is_same_v<resT, T2>) {
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand All @@ -564,6 +768,58 @@ struct AddInplaceRowMatrixBroadcastFactory
}
};

/*! @brief Alias of elementwise_common::BinaryInplaceScalarContigFunctor
 *  parameterized with AddInplaceFunctor<inT, outT>.
 *
 *  The trailing parameters default to vsz = 4, nvecs = 2 and
 *  sg_loadstore = true and are forwarded unchanged.
 */
template <typename inT,
          typename outT,
          unsigned int vsz = 4,
          unsigned int nvecs = 2,
          bool sg_loadstore = true>
using AddInplaceScalarContigFunctor =
    elementwise_common::BinaryInplaceScalarContigFunctor<
        inT, outT, AddInplaceFunctor<inT, outT>, vsz, nvecs, sg_loadstore>;

/*! @brief Declared-only class template handed to
 *  elementwise_common::binary_inplace_scalar_contig_impl by
 *  add_inplace_scalar_contig_impl (presumably as the kernel name type --
 *  confirm against elementwise_common).
 */
template <typename inT, typename outT, unsigned int vsz, unsigned int nvecs>
class add_inplace_scalar_contig_kernel;

/*! @brief Submits the scalar flavour of in-place element-wise add.
 *
 *  Thin wrapper: every argument is passed through unchanged to
 *  elementwise_common::binary_inplace_scalar_contig_impl, instantiated with
 *  AddInplaceScalarContigFunctor and add_inplace_scalar_contig_kernel.
 *
 *  NOTE(review): judging by the name, arg_p refers to the scalar operand and
 *  res_p to the contiguous array updated in place -- confirm against the
 *  elementwise_common helper.
 *
 *  @return the event returned by the elementwise_common helper
 */
template <typename argTy, typename resTy>
sycl::event
add_inplace_scalar_contig_impl(sycl::queue &exec_q,
                               size_t nelems,
                               const char *arg_p,
                               ssize_t arg_offset,
                               char *res_p,
                               ssize_t res_offset,
                               const std::vector<sycl::event> &depends = {})
{
    sycl::event add_ev =
        elementwise_common::binary_inplace_scalar_contig_impl<
            argTy, resTy, AddInplaceScalarContigFunctor,
            add_inplace_scalar_contig_kernel>(exec_q, nelems, arg_p,
                                              arg_offset, res_p, res_offset,
                                              depends);

    return add_ev;
}

template <typename fnT, typename T1, typename T2>
struct AddInplaceScalarContigFactory
{
fnT get()
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
else {
fnT fn = add_inplace_scalar_contig_impl<T1, T2>;
return fn;
}
}
};

} // namespace add
} // namespace kernels
} // namespace tensor
Expand Down
Loading
Loading