Skip to content

Commit 46dc288

Browse files
authored
Merge pull request #1816 from IntelPython/dtype-matrices-for-in-place-element-wise-ops
Introduce dedicated type support matrices for in-place element-wise operations
2 parents e2c7425 + 2ffedad commit 46dc288

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

92 files changed

+1082
-935
lines changed

dpctl/tensor/_type_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,9 @@ def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
140140

141141

142142
def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
143-
# if the kind of result is different from
144-
# the kind of input, use the default data
145-
# we use default dtype for the resulting kind.
146-
# This guarantees alignment of reciprocal and
147-
# divide output types.
143+
# if the kind of result is different from the kind of input, we use the
144+
# default floating-point dtype for the resulting kind. This guarantees
145+
# alignment of reciprocal and divide output types.
148146
if buf_dt.kind != arg_dtype.kind:
149147
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
150148
if res_dt == default_dt:

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ using AbsContigFunctor =
102102

103103
template <typename T> struct AbsOutputType
104104
{
105-
using value_type = typename std::disjunction< // disjunction is C++17
106-
// feature, supported by DPC++
105+
using value_type = typename std::disjunction<
107106
td_ns::TypeMapResultEntry<T, bool>,
108107
td_ns::TypeMapResultEntry<T, std::uint8_t>,
109108
td_ns::TypeMapResultEntry<T, std::uint16_t>,
@@ -119,6 +118,8 @@ template <typename T> struct AbsOutputType
119118
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
120119
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
121120
td_ns::DefaultResultEntry<void>>::result_type;
121+
122+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
122123
};
123124

124125
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -140,9 +141,7 @@ template <typename fnT, typename T> struct AbsContigFactory
140141
{
141142
fnT get()
142143
{
143-
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
144-
void>)
145-
{
144+
if constexpr (!AbsOutputType<T>::is_defined) {
146145
fnT fn = nullptr;
147146
return fn;
148147
}
@@ -191,9 +190,7 @@ template <typename fnT, typename T> struct AbsStridedFactory
191190
{
192191
fnT get()
193192
{
194-
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
195-
void>)
196-
{
193+
if constexpr (!AbsOutputType<T>::is_defined) {
197194
fnT fn = nullptr;
198195
return fn;
199196
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,15 @@ using AcosStridedFunctor = elementwise_common::
145145

146146
template <typename T> struct AcosOutputType
147147
{
148-
using value_type = typename std::disjunction< // disjunction is C++17
149-
// feature, supported by DPC++
148+
using value_type = typename std::disjunction<
150149
td_ns::TypeMapResultEntry<T, sycl::half>,
151150
td_ns::TypeMapResultEntry<T, float>,
152151
td_ns::TypeMapResultEntry<T, double>,
153152
td_ns::TypeMapResultEntry<T, std::complex<float>>,
154153
td_ns::TypeMapResultEntry<T, std::complex<double>>,
155154
td_ns::DefaultResultEntry<void>>::result_type;
155+
156+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
156157
};
157158

158159
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -174,9 +175,7 @@ template <typename fnT, typename T> struct AcosContigFactory
174175
{
175176
fnT get()
176177
{
177-
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
178-
void>)
179-
{
178+
if constexpr (!AcosOutputType<T>::is_defined) {
180179
fnT fn = nullptr;
181180
return fn;
182181
}
@@ -222,9 +221,7 @@ template <typename fnT, typename T> struct AcosStridedFactory
222221
{
223222
fnT get()
224223
{
225-
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
226-
void>)
227-
{
224+
if constexpr (!AcosOutputType<T>::is_defined) {
228225
fnT fn = nullptr;
229226
return fn;
230227
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,15 @@ using AcoshStridedFunctor = elementwise_common::
172172

173173
template <typename T> struct AcoshOutputType
174174
{
175-
using value_type = typename std::disjunction< // disjunction is C++17
176-
// feature, supported by DPC++
175+
using value_type = typename std::disjunction<
177176
td_ns::TypeMapResultEntry<T, sycl::half>,
178177
td_ns::TypeMapResultEntry<T, float>,
179178
td_ns::TypeMapResultEntry<T, double>,
180179
td_ns::TypeMapResultEntry<T, std::complex<float>>,
181180
td_ns::TypeMapResultEntry<T, std::complex<double>>,
182181
td_ns::DefaultResultEntry<void>>::result_type;
182+
183+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
183184
};
184185

185186
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -201,9 +202,7 @@ template <typename fnT, typename T> struct AcoshContigFactory
201202
{
202203
fnT get()
203204
{
204-
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
205-
void>)
206-
{
205+
if constexpr (!AcoshOutputType<T>::is_defined) {
207206
fnT fn = nullptr;
208207
return fn;
209208
}
@@ -249,9 +248,7 @@ template <typename fnT, typename T> struct AcoshStridedFactory
249248
{
250249
fnT get()
251250
{
252-
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
253-
void>)
254-
{
251+
if constexpr (!AcoshOutputType<T>::is_defined) {
255252
fnT fn = nullptr;
256253
return fn;
257254
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ using AddStridedFunctor =
132132

133133
template <typename T1, typename T2> struct AddOutputType
134134
{
135-
using value_type = typename std::disjunction< // disjunction is C++17
136-
// feature, supported by DPC++
135+
using value_type = typename std::disjunction<
137136
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
138137
td_ns::BinaryTypeMapResultEntry<T1,
139138
std::uint8_t,
@@ -193,6 +192,8 @@ template <typename T1, typename T2> struct AddOutputType
193192
std::complex<double>,
194193
std::complex<double>>,
195194
td_ns::DefaultResultEntry<void>>::result_type;
195+
196+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
196197
};
197198

198199
template <typename argT1,
@@ -223,9 +224,7 @@ template <typename fnT, typename T1, typename T2> struct AddContigFactory
223224
{
224225
fnT get()
225226
{
226-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
227-
void>)
228-
{
227+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
229228
fnT fn = nullptr;
230229
return fn;
231230
}
@@ -273,9 +272,7 @@ template <typename fnT, typename T1, typename T2> struct AddStridedFactory
273272
{
274273
fnT get()
275274
{
276-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
277-
void>)
278-
{
275+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
279276
fnT fn = nullptr;
280277
return fn;
281278
}
@@ -324,12 +321,12 @@ struct AddContigMatrixContigRowBroadcastFactory
324321
{
325322
fnT get()
326323
{
327-
using resT = typename AddOutputType<T1, T2>::value_type;
328-
if constexpr (std::is_same_v<resT, void>) {
324+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
329325
fnT fn = nullptr;
330326
return fn;
331327
}
332328
else {
329+
using resT = typename AddOutputType<T1, T2>::value_type;
333330
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
334331
dpctl::tensor::type_utils::is_complex<T2>::value ||
335332
dpctl::tensor::type_utils::is_complex<resT>::value)
@@ -371,12 +368,12 @@ struct AddContigRowContigMatrixBroadcastFactory
371368
{
372369
fnT get()
373370
{
374-
using resT = typename AddOutputType<T1, T2>::value_type;
375-
if constexpr (std::is_same_v<resT, void>) {
371+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
376372
fnT fn = nullptr;
377373
return fn;
378374
}
379375
else {
376+
using resT = typename AddOutputType<T1, T2>::value_type;
380377
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
381378
dpctl::tensor::type_utils::is_complex<T2>::value ||
382379
dpctl::tensor::type_utils::is_complex<resT>::value)
@@ -438,6 +435,50 @@ template <typename argT,
438435
unsigned int n_vecs>
439436
class add_inplace_contig_kernel;
440437

438+
/*! @brief Type pairs supported by in-place add */
439+
template <typename argTy, typename resTy> struct AddInplaceTypePairSupport
440+
{
441+
/* value is true if a kernel for <argTy, resTy> must be instantiated */
442+
static constexpr bool is_defined = std::disjunction<
443+
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
444+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
445+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
446+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
447+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
448+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
449+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
450+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
451+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
452+
td_ns::TypePairDefinedEntry<argTy, sycl::half, resTy, sycl::half>,
453+
td_ns::TypePairDefinedEntry<argTy, float, resTy, float>,
454+
td_ns::TypePairDefinedEntry<argTy, double, resTy, double>,
455+
td_ns::TypePairDefinedEntry<argTy,
456+
std::complex<float>,
457+
resTy,
458+
std::complex<float>>,
459+
td_ns::TypePairDefinedEntry<argTy,
460+
std::complex<double>,
461+
resTy,
462+
std::complex<double>>,
463+
// fall-through
464+
td_ns::NotDefinedEntry>::is_defined;
465+
};
466+
467+
template <typename fnT, typename argT, typename resT>
468+
struct AddInplaceTypeMapFactory
469+
{
470+
/*! @brief get typeid for output type of x += y */
471+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
472+
{
473+
if constexpr (AddInplaceTypePairSupport<argT, resT>::is_defined) {
474+
return td_ns::GetTypeid<resT>{}.get();
475+
}
476+
else {
477+
return td_ns::GetTypeid<void>{}.get();
478+
}
479+
}
480+
};
481+
441482
template <typename argTy, typename resTy>
442483
sycl::event
443484
add_inplace_contig_impl(sycl::queue &exec_q,
@@ -457,9 +498,7 @@ template <typename fnT, typename T1, typename T2> struct AddInplaceContigFactory
457498
{
458499
fnT get()
459500
{
460-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
461-
void>)
462-
{
501+
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
463502
fnT fn = nullptr;
464503
return fn;
465504
}
@@ -497,9 +536,7 @@ struct AddInplaceStridedFactory
497536
{
498537
fnT get()
499538
{
500-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
501-
void>)
502-
{
539+
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
503540
fnT fn = nullptr;
504541
return fn;
505542
}
@@ -544,8 +581,7 @@ struct AddInplaceRowMatrixBroadcastFactory
544581
{
545582
fnT get()
546583
{
547-
using resT = typename AddOutputType<T1, T2>::value_type;
548-
if constexpr (!std::is_same_v<resT, T2>) {
584+
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
549585
fnT fn = nullptr;
550586
return fn;
551587
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,12 @@ using AngleStridedFunctor = elementwise_common::
9191

9292
template <typename T> struct AngleOutputType
9393
{
94-
using value_type = typename std::disjunction< // disjunction is C++17
95-
// feature, supported by DPC++
94+
using value_type = typename std::disjunction<
9695
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
9796
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
9897
td_ns::DefaultResultEntry<void>>::result_type;
98+
99+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
99100
};
100101

101102
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -117,9 +118,7 @@ template <typename fnT, typename T> struct AngleContigFactory
117118
{
118119
fnT get()
119120
{
120-
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
121-
void>)
122-
{
121+
if constexpr (!AngleOutputType<T>::is_defined) {
123122
fnT fn = nullptr;
124123
return fn;
125124
}
@@ -165,9 +164,7 @@ template <typename fnT, typename T> struct AngleStridedFactory
165164
{
166165
fnT get()
167166
{
168-
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
169-
void>)
170-
{
167+
if constexpr (!AngleOutputType<T>::is_defined) {
171168
fnT fn = nullptr;
172169
return fn;
173170
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,15 @@ using AsinStridedFunctor = elementwise_common::
165165

166166
template <typename T> struct AsinOutputType
167167
{
168-
using value_type = typename std::disjunction< // disjunction is C++17
169-
// feature, supported by DPC++
168+
using value_type = typename std::disjunction<
170169
td_ns::TypeMapResultEntry<T, sycl::half>,
171170
td_ns::TypeMapResultEntry<T, float>,
172171
td_ns::TypeMapResultEntry<T, double>,
173172
td_ns::TypeMapResultEntry<T, std::complex<float>>,
174173
td_ns::TypeMapResultEntry<T, std::complex<double>>,
175174
td_ns::DefaultResultEntry<void>>::result_type;
175+
176+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
176177
};
177178

178179
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -194,9 +195,7 @@ template <typename fnT, typename T> struct AsinContigFactory
194195
{
195196
fnT get()
196197
{
197-
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
198-
void>)
199-
{
198+
if constexpr (!AsinOutputType<T>::is_defined) {
200199
fnT fn = nullptr;
201200
return fn;
202201
}
@@ -242,9 +241,7 @@ template <typename fnT, typename T> struct AsinStridedFactory
242241
{
243242
fnT get()
244243
{
245-
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
246-
void>)
247-
{
244+
if constexpr (!AsinOutputType<T>::is_defined) {
248245
fnT fn = nullptr;
249246
return fn;
250247
}

0 commit comments

Comments
 (0)