Skip to content

Commit 478b30c

Browse files
Added missing include, Identity to use has_known_identity
Implementation of Identity trait should call sycl::known_identity if trait sycl::has_known_identity is a true_type. Added IsMultiplies, and identity value for it, since sycl::known_identity for multiplies is only defined for real-valued types.
1 parent df1c22f commit 478b30c

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <CL/sycl.hpp>
2727
#include <algorithm>
2828
#include <cstddef>
29+
#include <type_traits>
2930
#include <vector>
3031

3132
#include "math_utils.hpp"
@@ -272,6 +273,18 @@ struct GetIdentity<Op,
272273
template <typename T, class Op>
273274
using IsPlus = std::bool_constant<std::is_same_v<Op, sycl::plus<T>> ||
274275
std::is_same_v<Op, std::plus<T>>>;
276+
// Multiplies
277+
278+
template <typename T, class Op>
279+
using IsMultiplies =
280+
std::bool_constant<std::is_same_v<Op, sycl::multiplies<T>> ||
281+
std::is_same_v<Op, std::multiplies<T>>>;
282+
283+
template <typename Op, typename T>
284+
struct GetIdentity<Op, T, std::enable_if_t<IsMultiplies<T, Op>::value>>
285+
{
286+
static constexpr T value = static_cast<T>(1);
287+
};
275288

276289
// Identity
277290

@@ -280,13 +293,17 @@ template <typename Op, typename T, typename = void> struct Identity
280293
};
281294

282295
template <typename Op, typename T>
283-
struct Identity<Op, T, std::enable_if_t<!IsSyclOp<T, Op>::value>>
296+
using UseBuiltInIdentity =
297+
std::conjunction<IsSyclOp<T, Op>, sycl::has_known_identity<Op, T>>;
298+
299+
template <typename Op, typename T>
300+
struct Identity<Op, T, std::enable_if_t<!UseBuiltInIdentity<Op, T>::value>>
284301
{
285302
static constexpr T value = GetIdentity<Op, T>::value;
286303
};
287304

288305
template <typename Op, typename T>
289-
struct Identity<Op, T, std::enable_if_t<IsSyclOp<T, Op>::value>>
306+
struct Identity<Op, T, std::enable_if_t<UseBuiltInIdentity<Op, T>::value>>
290307
{
291308
static constexpr T value = sycl::known_identity<Op, T>::value;
292309
};

0 commit comments

Comments
 (0)