Skip to content

Commit 1755f66

Browse files
committed
Decouples in-place and out-of-place type support tables
Improves readability of in-place code
1 parent 1b26f0e commit 1755f66

23 files changed

+658
-162
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,53 @@ template <typename argT,
438438
unsigned int n_vecs>
439439
class add_inplace_contig_kernel;
440440

441+
/* @brief Types supported by in-place add */
442+
template <typename argTy, typename resTy> struct AddInplaceTypePairSupport
443+
{
444+
/* value if true a kernel for <argTy, resTy> must be instantiated */
445+
static constexpr bool is_defined = std::disjunction< // disjunction is
446+
// C++17 feature,
447+
// supported by
448+
// DPC++ input bool
449+
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
450+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
451+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
452+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
453+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
454+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
455+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
456+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
457+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
458+
td_ns::TypePairDefinedEntry<argTy, sycl::half, resTy, sycl::half>,
459+
td_ns::TypePairDefinedEntry<argTy, float, resTy, float>,
460+
td_ns::TypePairDefinedEntry<argTy, double, resTy, double>,
461+
td_ns::TypePairDefinedEntry<argTy,
462+
std::complex<float>,
463+
resTy,
464+
std::complex<float>>,
465+
td_ns::TypePairDefinedEntry<argTy,
466+
std::complex<double>,
467+
resTy,
468+
std::complex<double>>,
469+
// fall-through
470+
td_ns::NotDefinedEntry>::is_defined;
471+
};
472+
473+
template <typename fnT, typename argT, typename resT>
474+
struct AddInplaceTypeMapFactory
475+
{
476+
/*! @brief get typeid for output type of x += y */
477+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
478+
{
479+
if constexpr (AddInplaceTypePairSupport<argT, resT>::is_defined) {
480+
return td_ns::GetTypeid<resT>{}.get();
481+
}
482+
else {
483+
return td_ns::GetTypeid<void>{}.get();
484+
}
485+
}
486+
};
487+
441488
template <typename argTy, typename resTy>
442489
sycl::event
443490
add_inplace_contig_impl(sycl::queue &exec_q,
@@ -457,9 +504,7 @@ template <typename fnT, typename T1, typename T2> struct AddInplaceContigFactory
457504
{
458505
fnT get()
459506
{
460-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
461-
void>)
462-
{
507+
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
463508
fnT fn = nullptr;
464509
return fn;
465510
}
@@ -497,9 +542,7 @@ struct AddInplaceStridedFactory
497542
{
498543
fnT get()
499544
{
500-
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
501-
void>)
502-
{
545+
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
503546
fnT fn = nullptr;
504547
return fn;
505548
}
@@ -544,8 +587,7 @@ struct AddInplaceRowMatrixBroadcastFactory
544587
{
545588
fnT get()
546589
{
547-
using resT = typename AddOutputType<T1, T2>::value_type;
548-
if constexpr (!std::is_same_v<resT, T2>) {
590+
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
549591
fnT fn = nullptr;
550592
return fn;
551593
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,44 @@ template <typename argT,
322322
unsigned int n_vecs>
323323
class bitwise_and_inplace_contig_kernel;
324324

325+
/* @brief Types supported by in-place bitwise AND */
326+
template <typename argTy, typename resTy>
327+
struct BitwiseAndInplaceTypePairSupport
328+
{
329+
/* value if true a kernel for <argTy, resTy> must be instantiated */
330+
static constexpr bool is_defined = std::disjunction< // disjunction is
331+
// C++17 feature,
332+
// supported by
333+
// DPC++ input bool
334+
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
335+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
336+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
337+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
338+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
339+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
340+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
341+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
342+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
343+
// fall-through
344+
td_ns::NotDefinedEntry>::is_defined;
345+
};
346+
347+
template <typename fnT, typename argT, typename resT>
348+
struct BitwiseAndInplaceTypeMapFactory
349+
{
350+
/*! @brief get typeid for output type of x &= y */
351+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
352+
{
353+
if constexpr (BitwiseAndInplaceTypePairSupport<argT, resT>::is_defined)
354+
{
355+
return td_ns::GetTypeid<resT>{}.get();
356+
}
357+
else {
358+
return td_ns::GetTypeid<void>{}.get();
359+
}
360+
}
361+
};
362+
325363
template <typename argTy, typename resTy>
326364
sycl::event
327365
bitwise_and_inplace_contig_impl(sycl::queue &exec_q,
@@ -343,10 +381,7 @@ struct BitwiseAndInplaceContigFactory
343381
{
344382
fnT get()
345383
{
346-
if constexpr (std::is_same_v<
347-
typename BitwiseAndOutputType<T1, T2>::value_type,
348-
void>)
349-
{
384+
if constexpr (!BitwiseAndInplaceTypePairSupport<T1, T2>::is_defined) {
350385
fnT fn = nullptr;
351386
return fn;
352387
}
@@ -385,10 +420,7 @@ struct BitwiseAndInplaceStridedFactory
385420
{
386421
fnT get()
387422
{
388-
if constexpr (std::is_same_v<
389-
typename BitwiseAndOutputType<T1, T2>::value_type,
390-
void>)
391-
{
423+
if constexpr (!BitwiseAndInplaceTypePairSupport<T1, T2>::is_defined) {
392424
fnT fn = nullptr;
393425
return fn;
394426
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,44 @@ template <typename argT,
336336
unsigned int n_vecs>
337337
class bitwise_left_shift_inplace_contig_kernel;
338338

339+
/* @brief Types supported by in-place bitwise left shift */
340+
template <typename argTy, typename resTy>
341+
struct BitwiseLeftShiftInplaceTypePairSupport
342+
{
343+
/* value if true a kernel for <argTy, resTy> must be instantiated */
344+
static constexpr bool is_defined = std::disjunction< // disjunction is
345+
// C++17 feature,
346+
// supported by
347+
// DPC++ input bool
348+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
349+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
350+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
351+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
352+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
353+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
354+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
355+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
356+
// fall-through
357+
td_ns::NotDefinedEntry>::is_defined;
358+
};
359+
360+
template <typename fnT, typename argT, typename resT>
361+
struct BitwiseLeftShiftInplaceTypeMapFactory
362+
{
363+
/*! @brief get typeid for output type of x <<= y */
364+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
365+
{
366+
if constexpr (BitwiseLeftShiftInplaceTypePairSupport<argT,
367+
resT>::is_defined)
368+
{
369+
return td_ns::GetTypeid<resT>{}.get();
370+
}
371+
else {
372+
return td_ns::GetTypeid<void>{}.get();
373+
}
374+
}
375+
};
376+
339377
template <typename argTy, typename resTy>
340378
sycl::event bitwise_left_shift_inplace_contig_impl(
341379
sycl::queue &exec_q,
@@ -357,9 +395,8 @@ struct BitwiseLeftShiftInplaceContigFactory
357395
{
358396
fnT get()
359397
{
360-
if constexpr (std::is_same_v<typename BitwiseLeftShiftOutputType<
361-
T1, T2>::value_type,
362-
void>)
398+
if constexpr (!BitwiseLeftShiftInplaceTypePairSupport<T1,
399+
T2>::is_defined)
363400
{
364401
fnT fn = nullptr;
365402
return fn;
@@ -399,9 +436,8 @@ struct BitwiseLeftShiftInplaceStridedFactory
399436
{
400437
fnT get()
401438
{
402-
if constexpr (std::is_same_v<typename BitwiseLeftShiftOutputType<
403-
T1, T2>::value_type,
404-
void>)
439+
if constexpr (!BitwiseLeftShiftInplaceTypePairSupport<T1,
440+
T2>::is_defined)
405441
{
406442
fnT fn = nullptr;
407443
return fn;

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,42 @@ template <typename argT,
318318
unsigned int n_vecs>
319319
class bitwise_or_inplace_contig_kernel;
320320

321+
/* @brief Types supported by in-place bitwise OR */
322+
template <typename argTy, typename resTy> struct BitwiseOrInplaceTypePairSupport
323+
{
324+
/* value if true a kernel for <argTy, resTy> must be instantiated */
325+
static constexpr bool is_defined = std::disjunction< // disjunction is
326+
// C++17 feature,
327+
// supported by
328+
// DPC++ input bool
329+
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
330+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
331+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
332+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
333+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
334+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
335+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
336+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
337+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
338+
// fall-through
339+
td_ns::NotDefinedEntry>::is_defined;
340+
};
341+
342+
template <typename fnT, typename argT, typename resT>
343+
struct BitwiseOrInplaceTypeMapFactory
344+
{
345+
/*! @brief get typeid for output type of x |= y */
346+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
347+
{
348+
if constexpr (BitwiseOrInplaceTypePairSupport<argT, resT>::is_defined) {
349+
return td_ns::GetTypeid<resT>{}.get();
350+
}
351+
else {
352+
return td_ns::GetTypeid<void>{}.get();
353+
}
354+
}
355+
};
356+
321357
template <typename argTy, typename resTy>
322358
sycl::event
323359
bitwise_or_inplace_contig_impl(sycl::queue &exec_q,
@@ -339,10 +375,7 @@ struct BitwiseOrInplaceContigFactory
339375
{
340376
fnT get()
341377
{
342-
if constexpr (std::is_same_v<
343-
typename BitwiseOrOutputType<T1, T2>::value_type,
344-
void>)
345-
{
378+
if constexpr (!BitwiseOrInplaceTypePairSupport<T1, T2>::is_defined) {
346379
fnT fn = nullptr;
347380
return fn;
348381
}
@@ -381,10 +414,7 @@ struct BitwiseOrInplaceStridedFactory
381414
{
382415
fnT get()
383416
{
384-
if constexpr (std::is_same_v<
385-
typename BitwiseOrOutputType<T1, T2>::value_type,
386-
void>)
387-
{
417+
if constexpr (!BitwiseOrInplaceTypePairSupport<T1, T2>::is_defined) {
388418
fnT fn = nullptr;
389419
return fn;
390420
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,44 @@ template <typename argT,
340340
unsigned int n_vecs>
341341
class bitwise_right_shift_inplace_contig_kernel;
342342

343+
/* @brief Types supported by in-place bitwise right shift */
344+
template <typename argTy, typename resTy>
345+
struct BitwiseRightShiftInplaceTypePairSupport
346+
{
347+
/* value if true a kernel for <argTy, resTy> must be instantiated */
348+
static constexpr bool is_defined = std::disjunction< // disjunction is
349+
// C++17 feature,
350+
// supported by
351+
// DPC++ input bool
352+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
353+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
354+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
355+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
356+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
357+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
358+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
359+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
360+
// fall-through
361+
td_ns::NotDefinedEntry>::is_defined;
362+
};
363+
364+
template <typename fnT, typename argT, typename resT>
365+
struct BitwiseRightShiftInplaceTypeMapFactory
366+
{
367+
/*! @brief get typeid for output type of x >>= y */
368+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
369+
{
370+
if constexpr (BitwiseRightShiftInplaceTypePairSupport<argT,
371+
resT>::is_defined)
372+
{
373+
return td_ns::GetTypeid<resT>{}.get();
374+
}
375+
else {
376+
return td_ns::GetTypeid<void>{}.get();
377+
}
378+
}
379+
};
380+
343381
template <typename argTy, typename resTy>
344382
sycl::event bitwise_right_shift_inplace_contig_impl(
345383
sycl::queue &exec_q,
@@ -361,9 +399,8 @@ struct BitwiseRightShiftInplaceContigFactory
361399
{
362400
fnT get()
363401
{
364-
if constexpr (std::is_same_v<typename BitwiseRightShiftOutputType<
365-
T1, T2>::value_type,
366-
void>)
402+
if constexpr (!BitwiseRightShiftInplaceTypePairSupport<T1,
403+
T2>::is_defined)
367404
{
368405
fnT fn = nullptr;
369406
return fn;
@@ -403,9 +440,8 @@ struct BitwiseRightShiftInplaceStridedFactory
403440
{
404441
fnT get()
405442
{
406-
if constexpr (std::is_same_v<typename BitwiseRightShiftOutputType<
407-
T1, T2>::value_type,
408-
void>)
443+
if constexpr (!BitwiseRightShiftInplaceTypePairSupport<T1,
444+
T2>::is_defined)
409445
{
410446
fnT fn = nullptr;
411447
return fn;

0 commit comments

Comments
 (0)