Skip to content

Commit ed69be8

Browse files
committed
Implements overloads for binary element-wise functions between a broadcast scalar and an array
Implements these new kernels for addition
1 parent d065abb commit ed69be8

32 files changed

+1349
-54
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,42 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
105105
tmp);
106106
}
107107
}
108+
109+
template <int vec_sz>
110+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
111+
const argT2 &in2) const
112+
{
113+
auto tmp = in1 + in2;
114+
if constexpr (std::is_same_v<resT,
115+
typename decltype(tmp)::element_type>)
116+
{
117+
return tmp;
118+
}
119+
else {
120+
using dpctl::tensor::type_utils::vec_cast;
121+
122+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
123+
tmp);
124+
}
125+
}
126+
127+
template <int vec_sz>
128+
sycl::vec<resT, vec_sz>
129+
operator()(const argT1 &in1, const sycl::vec<argT2, vec_sz> &in2) const
130+
{
131+
auto tmp = in1 + in2;
132+
if constexpr (std::is_same_v<resT,
133+
typename decltype(tmp)::element_type>)
134+
{
135+
return tmp;
136+
}
137+
else {
138+
using dpctl::tensor::type_utils::vec_cast;
139+
140+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
141+
tmp);
142+
}
143+
}
108144
};
109145

110146
template <typename argT1,
@@ -393,6 +429,126 @@ struct AddContigRowContigMatrixBroadcastFactory
393429
}
394430
};
395431

432+
template <typename argT1,
433+
typename argT2,
434+
typename resT,
435+
unsigned int vec_sz = 4,
436+
unsigned int n_vecs = 2,
437+
bool enable_sg_loadstore = true>
438+
using AddScalarContigArrayFunctor =
439+
elementwise_common::BinaryScalarContigArrayFunctor<
440+
argT1,
441+
argT2,
442+
resT,
443+
AddFunctor<argT1, argT2, resT>,
444+
vec_sz,
445+
n_vecs,
446+
enable_sg_loadstore>;
447+
448+
template <typename argT1,
449+
typename argT2,
450+
typename resT,
451+
unsigned int vec_sz = 4,
452+
unsigned int n_vecs = 2,
453+
bool enable_sg_loadstore = true>
454+
using AddContigArrayScalarFunctor =
455+
elementwise_common::BinaryContigArrayScalarFunctor<
456+
argT1,
457+
argT2,
458+
resT,
459+
AddFunctor<argT1, argT2, resT>,
460+
vec_sz,
461+
n_vecs,
462+
enable_sg_loadstore>;
463+
464+
template <typename argT1,
465+
typename argT2,
466+
typename resT,
467+
unsigned int vec_sz,
468+
unsigned int n_vecs>
469+
class add_scalar_contig_array_kernel;
470+
471+
template <typename argTy1, typename argTy2>
472+
sycl::event
473+
add_scalar_contig_array_impl(sycl::queue &exec_q,
474+
size_t nelems,
475+
const char *arg1_p,
476+
ssize_t arg1_offset,
477+
const char *arg2_p,
478+
ssize_t arg2_offset,
479+
char *res_p,
480+
ssize_t res_offset,
481+
const std::vector<sycl::event> &depends = {})
482+
{
483+
return elementwise_common::binary_scalar_contig_array_impl<
484+
argTy1, argTy2, AddOutputType, AddScalarContigArrayFunctor,
485+
add_scalar_contig_array_kernel>(exec_q, nelems, arg1_p, arg1_offset,
486+
arg2_p, arg2_offset, res_p, res_offset,
487+
depends);
488+
}
489+
490+
template <typename argT1,
491+
typename argT2,
492+
typename resT,
493+
unsigned int vec_sz,
494+
unsigned int n_vecs>
495+
class add_contig_array_scalar_kernel;
496+
497+
template <typename argTy1, typename argTy2>
498+
sycl::event
499+
add_contig_array_scalar_impl(sycl::queue &exec_q,
500+
size_t nelems,
501+
const char *arg1_p,
502+
ssize_t arg1_offset,
503+
const char *arg2_p,
504+
ssize_t arg2_offset,
505+
char *res_p,
506+
ssize_t res_offset,
507+
const std::vector<sycl::event> &depends = {})
508+
{
509+
return elementwise_common::binary_contig_array_scalar_impl<
510+
argTy1, argTy2, AddOutputType, AddContigArrayScalarFunctor,
511+
add_contig_array_scalar_kernel>(exec_q, nelems, arg1_p, arg1_offset,
512+
arg2_p, arg2_offset, res_p, res_offset,
513+
depends);
514+
}
515+
516+
template <typename fnT, typename T1, typename T2>
517+
struct AddScalarContigArrayFactory
518+
{
519+
fnT get()
520+
{
521+
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
522+
void>)
523+
{
524+
fnT fn = nullptr;
525+
return fn;
526+
}
527+
else {
528+
fnT fn = add_scalar_contig_array_impl<T1, T2>;
529+
return fn;
530+
}
531+
}
532+
};
533+
534+
template <typename fnT, typename T1, typename T2>
535+
struct AddContigArrayScalarFactory
536+
{
537+
fnT get()
538+
{
539+
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
540+
void>)
541+
{
542+
fnT fn = nullptr;
543+
return fn;
544+
}
545+
else {
546+
fnT fn = add_contig_array_scalar_impl<T1, T2>;
547+
return fn;
548+
}
549+
}
550+
};
551+
396552
template <typename argT, typename resT> struct AddInplaceFunctor
397553
{
398554

@@ -409,6 +565,12 @@ template <typename argT, typename resT> struct AddInplaceFunctor
409565
{
410566
res += in;
411567
}
568+
569+
template <int vec_sz>
570+
void operator()(sycl::vec<resT, vec_sz> &res, const argT &in)
571+
{
572+
res += in;
573+
}
412574
};
413575

414576
template <typename argT,
@@ -606,6 +768,58 @@ struct AddInplaceRowMatrixBroadcastFactory
606768
}
607769
};
608770

771+
template <typename argT,
772+
typename resT,
773+
unsigned int vec_sz = 4,
774+
unsigned int n_vecs = 2,
775+
bool enable_sg_loadstore = true>
776+
using AddInplaceScalarContigFunctor =
777+
elementwise_common::BinaryInplaceScalarContigFunctor<
778+
argT,
779+
resT,
780+
AddInplaceFunctor<argT, resT>,
781+
vec_sz,
782+
n_vecs,
783+
enable_sg_loadstore>;
784+
785+
template <typename argT,
786+
typename resT,
787+
unsigned int vec_sz,
788+
unsigned int n_vecs>
789+
class add_inplace_scalar_contig_kernel;
790+
791+
template <typename argTy, typename resTy>
792+
sycl::event
793+
add_inplace_scalar_contig_impl(sycl::queue &exec_q,
794+
size_t nelems,
795+
const char *arg_p,
796+
ssize_t arg_offset,
797+
char *res_p,
798+
ssize_t res_offset,
799+
const std::vector<sycl::event> &depends = {})
800+
{
801+
return elementwise_common::binary_inplace_scalar_contig_impl<
802+
argTy, resTy, AddInplaceScalarContigFunctor,
803+
add_inplace_scalar_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
804+
res_p, res_offset, depends);
805+
}
806+
807+
template <typename fnT, typename T1, typename T2>
808+
struct AddInplaceScalarContigFactory
809+
{
810+
fnT get()
811+
{
812+
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
813+
fnT fn = nullptr;
814+
return fn;
815+
}
816+
else {
817+
fnT fn = add_inplace_scalar_contig_impl<T1, T2>;
818+
return fn;
819+
}
820+
}
821+
};
822+
609823
} // namespace add
610824
} // namespace kernels
611825
} // namespace tensor

0 commit comments

Comments
 (0)