Skip to content

Commit 2ffedad

Browse files
committed
Add is_defined static boolean to output type tables and uses them in type dispatching
1 parent 0aa5321 commit 2ffedad

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+310
-502
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ template <typename T> struct AbsOutputType
118118
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
119119
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
120120
td_ns::DefaultResultEntry<void>>::result_type;
121+
122+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
121123
};
122124

123125
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -139,9 +141,7 @@ template <typename fnT, typename T> struct AbsContigFactory
139141
{
140142
fnT get()
141143
{
142-
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
143-
void>)
144-
{
144+
if constexpr (!AbsOutputType<T>::is_defined) {
145145
fnT fn = nullptr;
146146
return fn;
147147
}
@@ -190,9 +190,7 @@ template <typename fnT, typename T> struct AbsStridedFactory
190190
{
191191
fnT get()
192192
{
193-
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
194-
void>)
195-
{
193+
if constexpr (!AbsOutputType<T>::is_defined) {
196194
fnT fn = nullptr;
197195
return fn;
198196
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ template <typename T> struct AcosOutputType
152152
td_ns::TypeMapResultEntry<T, std::complex<float>>,
153153
td_ns::TypeMapResultEntry<T, std::complex<double>>,
154154
td_ns::DefaultResultEntry<void>>::result_type;
155+
156+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
155157
};
156158

157159
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -173,9 +175,7 @@ template <typename fnT, typename T> struct AcosContigFactory
173175
{
174176
fnT get()
175177
{
176-
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
177-
void>)
178-
{
178+
if constexpr (!AcosOutputType<T>::is_defined) {
179179
fnT fn = nullptr;
180180
return fn;
181181
}
@@ -221,9 +221,7 @@ template <typename fnT, typename T> struct AcosStridedFactory
221221
{
222222
fnT get()
223223
{
224-
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
225-
void>)
226-
{
224+
if constexpr (!AcosOutputType<T>::is_defined) {
227225
fnT fn = nullptr;
228226
return fn;
229227
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ template <typename T> struct AcoshOutputType
179179
td_ns::TypeMapResultEntry<T, std::complex<float>>,
180180
td_ns::TypeMapResultEntry<T, std::complex<double>>,
181181
td_ns::DefaultResultEntry<void>>::result_type;
182+
183+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
182184
};
183185

184186
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -200,9 +202,7 @@ template <typename fnT, typename T> struct AcoshContigFactory
200202
{
201203
fnT get()
202204
{
203-
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
204-
void>)
205-
{
205+
if constexpr (!AcoshOutputType<T>::is_defined) {
206206
fnT fn = nullptr;
207207
return fn;
208208
}
@@ -248,9 +248,7 @@ template <typename fnT, typename T> struct AcoshStridedFactory
248248
{
249249
fnT get()
250250
{
251-
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
252-
void>)
253-
{
251+
if constexpr (!AcoshOutputType<T>::is_defined) {
254252
fnT fn = nullptr;
255253
return fn;
256254
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ template <typename T1, typename T2> struct AddOutputType
192192
std::complex<double>,
193193
std::complex<double>>,
194194
td_ns::DefaultResultEntry<void>>::result_type;
195+
196+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
195197
};
196198

197199
template <typename argT1,
@@ -222,9 +224,7 @@ template <typename fnT, typename T1, typename T2> struct AddContigFactory
222224
{
223225
fnT get()
224226
{
225-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
226-
void>)
227-
{
227+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
228228
fnT fn = nullptr;
229229
return fn;
230230
}
@@ -272,9 +272,7 @@ template <typename fnT, typename T1, typename T2> struct AddStridedFactory
272272
{
273273
fnT get()
274274
{
275-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
276-
void>)
277-
{
275+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
278276
fnT fn = nullptr;
279277
return fn;
280278
}
@@ -323,12 +321,12 @@ struct AddContigMatrixContigRowBroadcastFactory
323321
{
324322
fnT get()
325323
{
326-
using resT = typename AddOutputType<T1, T2>::value_type;
327-
if constexpr (std::is_same_v<resT, void>) {
324+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
328325
fnT fn = nullptr;
329326
return fn;
330327
}
331328
else {
329+
using resT = typename AddOutputType<T1, T2>::value_type;
332330
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
333331
dpctl::tensor::type_utils::is_complex<T2>::value ||
334332
dpctl::tensor::type_utils::is_complex<resT>::value)
@@ -370,12 +368,12 @@ struct AddContigRowContigMatrixBroadcastFactory
370368
{
371369
fnT get()
372370
{
373-
using resT = typename AddOutputType<T1, T2>::value_type;
374-
if constexpr (std::is_same_v<resT, void>) {
371+
if constexpr (!AddOutputType<T1, T2>::is_defined) {
375372
fnT fn = nullptr;
376373
return fn;
377374
}
378375
else {
376+
using resT = typename AddOutputType<T1, T2>::value_type;
379377
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
380378
dpctl::tensor::type_utils::is_complex<T2>::value ||
381379
dpctl::tensor::type_utils::is_complex<resT>::value)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ template <typename T> struct AngleOutputType
9595
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
9696
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
9797
td_ns::DefaultResultEntry<void>>::result_type;
98+
99+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
98100
};
99101

100102
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -116,9 +118,7 @@ template <typename fnT, typename T> struct AngleContigFactory
116118
{
117119
fnT get()
118120
{
119-
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
120-
void>)
121-
{
121+
if constexpr (!AngleOutputType<T>::is_defined) {
122122
fnT fn = nullptr;
123123
return fn;
124124
}
@@ -164,9 +164,7 @@ template <typename fnT, typename T> struct AngleStridedFactory
164164
{
165165
fnT get()
166166
{
167-
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
168-
void>)
169-
{
167+
if constexpr (!AngleOutputType<T>::is_defined) {
170168
fnT fn = nullptr;
171169
return fn;
172170
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ template <typename T> struct AsinOutputType
172172
td_ns::TypeMapResultEntry<T, std::complex<float>>,
173173
td_ns::TypeMapResultEntry<T, std::complex<double>>,
174174
td_ns::DefaultResultEntry<void>>::result_type;
175+
176+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
175177
};
176178

177179
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -193,9 +195,7 @@ template <typename fnT, typename T> struct AsinContigFactory
193195
{
194196
fnT get()
195197
{
196-
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
197-
void>)
198-
{
198+
if constexpr (!AsinOutputType<T>::is_defined) {
199199
fnT fn = nullptr;
200200
return fn;
201201
}
@@ -241,9 +241,7 @@ template <typename fnT, typename T> struct AsinStridedFactory
241241
{
242242
fnT get()
243243
{
244-
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
245-
void>)
246-
{
244+
if constexpr (!AsinOutputType<T>::is_defined) {
247245
fnT fn = nullptr;
248246
return fn;
249247
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ template <typename T> struct AsinhOutputType
155155
td_ns::TypeMapResultEntry<T, std::complex<float>>,
156156
td_ns::TypeMapResultEntry<T, std::complex<double>>,
157157
td_ns::DefaultResultEntry<void>>::result_type;
158+
159+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
158160
};
159161

160162
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -176,9 +178,7 @@ template <typename fnT, typename T> struct AsinhContigFactory
176178
{
177179
fnT get()
178180
{
179-
if constexpr (std::is_same_v<typename AsinhOutputType<T>::value_type,
180-
void>)
181-
{
181+
if constexpr (!AsinhOutputType<T>::is_defined) {
182182
fnT fn = nullptr;
183183
return fn;
184184
}
@@ -224,9 +224,7 @@ template <typename fnT, typename T> struct AsinhStridedFactory
224224
{
225225
fnT get()
226226
{
227-
if constexpr (std::is_same_v<typename AsinhOutputType<T>::value_type,
228-
void>)
229-
{
227+
if constexpr (!AsinhOutputType<T>::is_defined) {
230228
fnT fn = nullptr;
231229
return fn;
232230
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ template <typename T> struct AtanOutputType
162162
td_ns::TypeMapResultEntry<T, std::complex<float>>,
163163
td_ns::TypeMapResultEntry<T, std::complex<double>>,
164164
td_ns::DefaultResultEntry<void>>::result_type;
165+
166+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
165167
};
166168

167169
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -183,9 +185,7 @@ template <typename fnT, typename T> struct AtanContigFactory
183185
{
184186
fnT get()
185187
{
186-
if constexpr (std::is_same_v<typename AtanOutputType<T>::value_type,
187-
void>)
188-
{
188+
if constexpr (!AtanOutputType<T>::is_defined) {
189189
fnT fn = nullptr;
190190
return fn;
191191
}
@@ -231,9 +231,7 @@ template <typename fnT, typename T> struct AtanStridedFactory
231231
{
232232
fnT get()
233233
{
234-
if constexpr (std::is_same_v<typename AtanOutputType<T>::value_type,
235-
void>)
236-
{
234+
if constexpr (!AtanOutputType<T>::is_defined) {
237235
fnT fn = nullptr;
238236
return fn;
239237
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ template <typename T1, typename T2> struct Atan2OutputType
9999
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
100100
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
101101
td_ns::DefaultResultEntry<void>>::result_type;
102+
103+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
102104
};
103105

104106
template <typename argT1,
@@ -129,9 +131,7 @@ template <typename fnT, typename T1, typename T2> struct Atan2ContigFactory
129131
{
130132
fnT get()
131133
{
132-
if constexpr (std::is_same_v<
133-
typename Atan2OutputType<T1, T2>::value_type, void>)
134-
{
134+
if constexpr (!Atan2OutputType<T1, T2>::is_defined) {
135135
fnT fn = nullptr;
136136
return fn;
137137
}
@@ -148,7 +148,6 @@ template <typename fnT, typename T1, typename T2> struct Atan2TypeMapFactory
148148
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
149149
{
150150
using rT = typename Atan2OutputType<T1, T2>::value_type;
151-
;
152151
return td_ns::GetTypeid<rT>{}.get();
153152
}
154153
};
@@ -182,9 +181,7 @@ template <typename fnT, typename T1, typename T2> struct Atan2StridedFactory
182181
{
183182
fnT get()
184183
{
185-
if constexpr (std::is_same_v<
186-
typename Atan2OutputType<T1, T2>::value_type, void>)
187-
{
184+
if constexpr (!Atan2OutputType<T1, T2>::is_defined) {
188185
fnT fn = nullptr;
189186
return fn;
190187
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ template <typename T> struct AtanhOutputType
156156
td_ns::TypeMapResultEntry<T, std::complex<float>>,
157157
td_ns::TypeMapResultEntry<T, std::complex<double>>,
158158
td_ns::DefaultResultEntry<void>>::result_type;
159+
160+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
159161
};
160162

161163
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
@@ -177,9 +179,7 @@ template <typename fnT, typename T> struct AtanhContigFactory
177179
{
178180
fnT get()
179181
{
180-
if constexpr (std::is_same_v<typename AtanhOutputType<T>::value_type,
181-
void>)
182-
{
182+
if constexpr (!AtanhOutputType<T>::is_defined) {
183183
fnT fn = nullptr;
184184
return fn;
185185
}
@@ -225,9 +225,7 @@ template <typename fnT, typename T> struct AtanhStridedFactory
225225
{
226226
fnT get()
227227
{
228-
if constexpr (std::is_same_v<typename AtanhOutputType<T>::value_type,
229-
void>)
230-
{
228+
if constexpr (!AtanhOutputType<T>::is_defined) {
231229
fnT fn = nullptr;
232230
return fn;
233231
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ template <typename T1, typename T2> struct BitwiseAndOutputType
156156
std::int64_t,
157157
std::int64_t>,
158158
td_ns::DefaultResultEntry<void>>::result_type;
159+
160+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
159161
};
160162

161163
template <typename argT1,
@@ -187,10 +189,7 @@ template <typename fnT, typename T1, typename T2> struct BitwiseAndContigFactory
187189
{
188190
fnT get()
189191
{
190-
if constexpr (std::is_same_v<
191-
typename BitwiseAndOutputType<T1, T2>::value_type,
192-
void>)
193-
{
192+
if constexpr (!BitwiseAndOutputType<T1, T2>::is_defined) {
194193
fnT fn = nullptr;
195194
return fn;
196195
}
@@ -243,10 +242,7 @@ struct BitwiseAndStridedFactory
243242
{
244243
fnT get()
245244
{
246-
if constexpr (std::is_same_v<
247-
typename BitwiseAndOutputType<T1, T2>::value_type,
248-
void>)
249-
{
245+
if constexpr (!BitwiseAndOutputType<T1, T2>::is_defined) {
250246
fnT fn = nullptr;
251247
return fn;
252248
}

0 commit comments

Comments
 (0)