Skip to content

Commit afe7f9b

Browse files
committed
Additional fixes
Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 0d4fa27 commit afe7f9b

File tree

3 files changed

+157
-71
lines changed

3 files changed

+157
-71
lines changed

sycl/include/sycl/builtins_marray.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,13 @@ __SYCL_MATH_FUNCTION_OVERLOAD_FM(sqrt)
10171017
__SYCL_MATH_FUNCTION_OVERLOAD_FM(rsqrt)
10181018
#undef __SYCL_MATH_FUNCTION_OVERLOAD_FM
10191019

1020+
template <typename T, size_t N>
1021+
inline __SYCL_ALWAYS_INLINE
1022+
std::enable_if_t<std::is_same_v<T, float>, marray<T, N>>
1023+
powr(marray<T, N> x, marray<T, N> y) __NOEXC {
1024+
return native::powr(x, y);
1025+
}
1026+
10201027
#endif // __FAST_MATH__
10211028
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
10221029
} // namespace sycl

sycl/include/sycl/builtins_utils.hpp

Lines changed: 95 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <sycl/detail/common.hpp>
1212
#include <sycl/detail/generic_type_traits.hpp>
13+
#include <sycl/types.hpp>
1314

1415
namespace sycl {
1516
__SYCL_INLINE_VER_NAMESPACE(_V1) {
@@ -78,62 +79,131 @@ template <int N, int... Ns> constexpr bool CheckSizeIn() {
7879
}
7980

8081
template <typename T, typename... Ts>
81-
struct is_valid_vec_type : std::false_type {};
82+
struct is_valid_elem_type : std::false_type {};
8283
template <typename T, int N, typename... Ts>
83-
struct is_valid_vec_type<vec<T, N>, Ts...>
84+
struct is_valid_elem_type<vec<T, N>, Ts...>
8485
: std::bool_constant<CheckTypeIn<T, Ts...>()> {};
85-
86-
template <typename T, int... Ns> struct is_valid_vec_size : std::false_type {};
86+
template <typename VecT, typename OperationLeftT, typename OperationRightT,
87+
template <typename> class OperationCurrentT, int... Indexes,
88+
typename... Ts>
89+
struct is_valid_elem_type<SwizzleOp<VecT, OperationLeftT, OperationRightT,
90+
OperationCurrentT, Indexes...>,
91+
Ts...>
92+
: std::bool_constant<CheckTypeIn<typename VecT::element_type, Ts...>()> {};
93+
94+
template <typename T, int... Ns> struct is_valid_size : std::false_type {};
8795
template <typename T, int N, int... Ns>
88-
struct is_valid_vec_size<vec<T, N>, Ns...>
96+
struct is_valid_size<vec<T, N>, Ns...>
8997
: std::bool_constant<CheckSizeIn<N, Ns...>()> {};
98+
template <typename VecT, typename OperationLeftT, typename OperationRightT,
99+
template <typename> class OperationCurrentT, int... Indexes,
100+
int... Ns>
101+
struct is_valid_size<SwizzleOp<VecT, OperationLeftT, OperationRightT,
102+
OperationCurrentT, Indexes...>,
103+
Ns...>
104+
: std::bool_constant<CheckSizeIn<sizeof...(Indexes), Ns...>()> {};
90105

91106
template <typename T, typename... Ts>
92-
constexpr bool is_valid_vec_type_v = is_valid_vec_type<T, Ts...>::value;
107+
constexpr bool is_valid_elem_type_v = is_valid_elem_type<T, Ts...>::value;
93108
template <typename T, int... Ns>
94-
constexpr bool is_valid_vec_size_v = is_valid_vec_size<T, Ns...>::value;
109+
constexpr bool is_valid_size_v = is_valid_size<T, Ns...>::value;
95110

96-
template <typename T> struct vec_return;
97-
template <typename T, int N> struct vec_return<vec<T, N>> {
111+
template <typename T> struct get_vec;
112+
template <typename T, int N> struct get_vec<vec<T, N>> {
98113
using type = vec<T, N>;
99114
};
100-
// TODO: Make specialization for swizzle.
115+
template <typename VecT, typename OperationLeftT, typename OperationRightT,
116+
template <typename> class OperationCurrentT, int... Indexes>
117+
struct get_vec<SwizzleOp<VecT, OperationLeftT, OperationRightT,
118+
OperationCurrentT, Indexes...>> {
119+
using type = vec<typename VecT::element_type, sizeof...(Indexes)>;
120+
};
101121

102-
template <typename T> using vec_return_t = typename vec_return<T>::type;
122+
template <typename T> using get_vec_t = typename get_vec<T>::type;
103123

104-
template <typename T> struct same_size_signed_int {
124+
template <size_t Size> struct get_signed_int_by_size {
125+
using type = std::conditional_t<
126+
Size == 1, int8_t,
127+
std::conditional_t<
128+
Size == 2, int16_t,
129+
std::conditional_t<Size == 4, int32_t,
130+
std::conditional_t<Size == 8, int64_t, void>>>>;
131+
};
132+
133+
template <size_t Size> struct get_unsigned_int_by_size {
105134
using type = std::conditional_t<
106-
sizeof(T) == 1, int8_t,
135+
Size == 1, uint8_t,
107136
std::conditional_t<
108-
sizeof(T) == 2, int16_t,
109-
std::conditional_t<
110-
sizeof(T) == 4, int32_t,
111-
std::conditional_t<sizeof(T) == 8, int64_t, void>>>>;
137+
Size == 2, uint16_t,
138+
std::conditional_t<Size == 4, uint32_t,
139+
std::conditional_t<Size == 8, uint64_t, void>>>>;
140+
};
141+
142+
template <size_t Size> struct get_float_by_size {
143+
using type = std::conditional_t<
144+
Size == 2, half,
145+
std::conditional_t<Size == 4, float,
146+
std::conditional_t<Size == 8, double, void>>>;
147+
};
148+
149+
template <typename T> struct same_size_signed_int {
150+
using type = typename get_signed_int_by_size<sizeof(T)>::type;
112151
};
113152

114153
template <typename T, int N> struct same_size_signed_int<vec<T, N>> {
115154
using type = vec<typename same_size_signed_int<T>::type, N>;
116155
};
156+
// TODO: Should same_size_signed_int also be specialized for SwizzleOp?
117157

118158
template <typename T>
119159
using same_size_signed_int_t = typename same_size_signed_int<T>::type;
120160

121161
template <typename T> struct same_size_unsigned_int {
122-
using type = std::conditional_t<
123-
sizeof(T) == 1, uint8_t,
124-
std::conditional_t<
125-
sizeof(T) == 2, uint16_t,
126-
std::conditional_t<
127-
sizeof(T) == 4, uint32_t,
128-
std::conditional_t<sizeof(T) == 8, uint64_t, void>>>>;
162+
using type = typename get_unsigned_int_by_size<sizeof(T)>::type;
129163
};
130164

131165
template <typename T, int N> struct same_size_unsigned_int<vec<T, N>> {
132166
using type = vec<typename same_size_unsigned_int<T>::type, N>;
133167
};
168+
// TODO: Should same_size_unsigned_int also be specialized for SwizzleOp?
169+
170+
template <typename T> struct same_size_float {
171+
using type = typename get_float_by_size<sizeof(T)>::type;
172+
};
173+
174+
template <typename T, int N> struct same_size_float<vec<T, N>> {
175+
using type = vec<typename same_size_float<T>::type, N>;
176+
};
177+
// TODO: Should same_size_float also be specialized for SwizzleOp?
178+
179+
template <typename T>
180+
using same_size_float_t = typename same_size_float<T>::type;
181+
182+
// For upsampling, select the integer type that is twice the size of the
183+
// specified type: unsigned if that type is unsigned, signed otherwise.
184+
template <typename T> struct upsampled_int {
185+
using type =
186+
std::conditional_t<std::is_unsigned_v<T>,
187+
typename get_unsigned_int_by_size<sizeof(T) * 2>::type,
188+
typename get_signed_int_by_size<sizeof(T) * 2>::type>;
189+
};
190+
template <typename T, int N> struct upsampled_int<vec<T, N>> {
191+
using type = vec<typename upsampled_int<T>::type, N>;
192+
};
193+
// TODO: Should upsampled_int also be specialized for SwizzleOp?
194+
195+
template <typename T> using upsampled_int_t = typename upsampled_int<T>::type;
196+
197+
template <typename> struct is_swizzle : std::false_type {};
198+
template <typename VecT, typename OperationLeftT, typename OperationRightT,
199+
template <typename> class OperationCurrentT, int... Indexes>
200+
struct is_swizzle<SwizzleOp<VecT, OperationLeftT, OperationRightT,
201+
OperationCurrentT, Indexes...>> : std::true_type {};
202+
203+
template <typename T> constexpr bool is_swizzle_v = is_swizzle<T>::value;
134204

135205
template <typename T>
136-
using same_size_unsigned_int_t = typename same_size_unsigned_int<T>::type;
206+
constexpr bool is_vec_or_swizzle_v = is_vec_v<T> || is_swizzle_v<T>;
137207

138208
} // namespace detail
139209
} // __SYCL_INLINE_VER_NAMESPACE(_V1)

sycl/source/builtins_generator.py

Lines changed: 55 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, related_arg_type, vec_type):
2222
self.vec_type = vec_type
2323

2424
def __str__(self):
25-
return f'detail::vec_return_t<{self.related_arg_type}>'
25+
return f'detail::get_vec_t<{self.related_arg_type}>'
2626

2727
class MultiPtr:
2828
def __init__(self, element_type):
@@ -66,20 +66,19 @@ def __init__(self, signed_type, parent_idx):
6666
def __str__(self):
6767
return f'detail::make_unsigned_t<{self.signed_type}>'
6868

69-
class SameSizeIntType:
70-
def __init__(self, signed, parent_idx):
71-
self.signed = signed
69+
class ConversionTraitType:
70+
def __init__(self, trait, parent_idx):
71+
self.trait = trait
7272
self.parent_idx = parent_idx
7373

74-
class InstantiatedSameSizeIntType:
75-
def __init__(self, parent_type, signed, parent_idx):
74+
class InstantiatedConversionTraitType:
75+
def __init__(self, parent_type, trait, parent_idx):
7676
self.parent_type = parent_type
77-
self.signed = signed
77+
self.trait = trait
7878
self.parent_idx = parent_idx
7979

8080
def __str__(self):
81-
signedness = 'signed' if self.signed else 'unsigned'
82-
return f'detail::same_size_{signedness}_int_t<{self.parent_type}>'
81+
return f'detail::{self.trait}<{self.parent_type}>'
8382

8483
### GENTYPE DEFINITIONS
8584
# NOTE: Marray is currently explicitly defined.
@@ -190,8 +189,10 @@ def __str__(self):
190189
# argument.
191190
elementtype0 = [ElementType(0)]
192191
unsignedtype0 = [UnsignedType(0)]
193-
samesizesignedint0 = [SameSizeIntType(True, 0)]
194-
samesizeunsignedint0 = [SameSizeIntType(False, 0)]
192+
samesizesignedint0 = [ConversionTraitType("same_size_signed_int_t", 0)]
193+
samesizeunsignedint0 = [ConversionTraitType("same_size_unsigned_int_t", 0)]
194+
samesizefloat0 = [ConversionTraitType("same_size_float_t", 0)]
195+
upsampledint0 = [ConversionTraitType("upsampled_int_t", 0)]
195196

196197
builtin_types = {
197198
"floatn" : floatn,
@@ -276,6 +277,8 @@ def __str__(self):
276277
"unsignedtype0" : unsignedtype0,
277278
"samesizesignedint0" : samesizesignedint0,
278279
"samesizeunsignedint0" : samesizeunsignedint0,
280+
"samesizefloat0" : samesizefloat0,
281+
"upsampledint0" : upsampledint0,
279282
"char" : ["char"],
280283
"signed char" : ["signed char"],
281284
"short" : ["short"],
@@ -309,6 +312,11 @@ def find_first_vec_arg(arg_types):
309312
return arg_type
310313
return None
311314

315+
def convert_vec_arg_name(arg_type, arg_name):
316+
if isinstance(arg_type, InstantiatedVecArg):
317+
return f'typename detail::get_vec_t<{arg_type}>({arg_name})'
318+
return arg_name
319+
312320
class Def:
313321
def __init__(self, return_type, arg_types, invoke_name=None,
314322
invoke_prefix="", custom_invoke=None, fast_math_invoke_name=None,
@@ -328,7 +336,7 @@ def __init__(self, return_type, arg_types, invoke_name=None,
328336
self.vec_size_alias = vec_size_alias
329337

330338
def get_invoke_args(self, arg_types, arg_names):
331-
result = arg_names
339+
result = list(map(convert_vec_arg_name, arg_types, arg_names))
332340
for (arg_idx, type_conv) in self.convert_args:
333341
# type_conv is either an index or a conversion function/type.
334342
conv = type_conv if isinstance(type_conv, str) else arg_types[type_conv]
@@ -424,8 +432,8 @@ def custom_bool_select_invoke(return_type, _, arg_names):
424432
"fract": [Def("vgenfloat", ["vgenfloat", "vgenfloatptr"]),
425433
Def("float", ["float", "floatptr"]),
426434
Def("double", ["double", "doubleptr"]),
427-
Def("half", ["half", "halfptr"]),
428-
Def("vgenfloat", ["vgenfloat", "vint32nptr"]),
435+
Def("half", ["half", "halfptr"])],
436+
"frexp": [Def("vgenfloat", ["vgenfloat", "vint32nptr"]),
429437
Def("float", ["float", "intptr"]),
430438
Def("double", ["double", "intptr"]),
431439
Def("half", ["half", "intptr"])],
@@ -456,13 +464,13 @@ def custom_bool_select_invoke(return_type, _, arg_names):
456464
Def("float", ["float", "floatptr"]),
457465
Def("double", ["double", "doubleptr"]),
458466
Def("half", ["half", "halfptr"])],
459-
"nan": [Def("vfloatn", ["vuint32n"]),
460-
Def("float", ["unsigned int"]),
461-
Def("vdoublen", ["vuint64n_ext"]),
462-
Def("double", ["unsigned long"]),
463-
Def("double", ["unsigned long long"]),
464-
Def("vhalfn", ["vuint16n"]),
465-
Def("half", ["unsigned short"])],
467+
"nan": [Def("samesizefloat0", ["vuint32n"]),
468+
Def("samesizefloat0", ["unsigned int"]),
469+
Def("samesizefloat0", ["vuint64n_ext"]),
470+
Def("samesizefloat0", ["unsigned long"]),
471+
Def("samesizefloat0", ["unsigned long long"]),
472+
Def("samesizefloat0", ["vuint16n"]),
473+
Def("samesizefloat0", ["unsigned short"])],
466474
"nextafter": [Def("genfloat", ["genfloat", "genfloat"])],
467475
"pow": [Def("genfloat", ["genfloat", "genfloat"])],
468476
"pown": [Def("vgenfloat", ["vgenfloat", "vint32n"]),
@@ -498,9 +506,6 @@ def custom_bool_select_invoke(return_type, _, arg_names):
498506
"tgamma": [Def("genfloat", ["genfloat"])],
499507
"trunc": [Def("genfloat", ["genfloat"])],
500508
# Integer functions
501-
"abs": [Def("sigeninteger", ["sigeninteger"], custom_invoke=custom_signed_abs_scalar_invoke),
502-
Def("vigeninteger", ["vigeninteger"], custom_invoke=custom_signed_abs_vector_invoke),
503-
Def("ugeninteger", ["ugeninteger"], invoke_prefix="u_")],
504509
"abs_diff": [Def("unsignedtype0", ["igeninteger", "igeninteger"], invoke_prefix="s_"),
505510
Def("ugeninteger", ["ugeninteger", "ugeninteger"], invoke_prefix="u_")],
506511
"add_sat": [Def("igeninteger", ["igeninteger", "igeninteger"], invoke_prefix="s_"),
@@ -520,18 +525,19 @@ def custom_bool_select_invoke(return_type, _, arg_names):
520525
"rotate": [Def("geninteger", ["geninteger", "geninteger"])],
521526
"sub_sat": [Def("igeninteger", ["igeninteger", "igeninteger"], invoke_prefix="s_"),
522527
Def("ugeninteger", ["ugeninteger", "ugeninteger"], invoke_prefix="u_")],
523-
"upsample": [Def("int16_t", ["int8_t", "uint8_t"], invoke_prefix="s_"),
524-
Def("vint16n", ["vint8n", "vuint8n"], invoke_prefix="s_"),
525-
Def("uint16_t", ["uint8_t", "uint8_t"], invoke_prefix="u_"),
526-
Def("vuint16n", ["vuint8n", "vuint8n"], invoke_prefix="u_"),
527-
Def("int32_t", ["int16_t", "uint16_t"], invoke_prefix="s_"),
528-
Def("vint32n", ["vint16n", "vuint16n"], invoke_prefix="s_"),
529-
Def("uint32_t", ["uint16_t", "uint16_t"], invoke_prefix="u_"),
530-
Def("vuint32n", ["vuint16n", "vuint16n"], invoke_prefix="u_"),
531-
Def("int64_t", ["int32_t", "uint32_t"], invoke_prefix="s_"),
532-
Def("vint64n", ["vint32n", "vuint32n"], invoke_prefix="s_"),
533-
Def("uint64_t", ["uint32_t", "uint32_t"], invoke_prefix="u_"),
534-
Def("vuint64n", ["vuint32n", "vuint32n"], invoke_prefix="u_")],
528+
"upsample": [Def("upsampledint0", ["int8_t", "uint8_t"], invoke_prefix="s_"),
529+
Def("upsampledint0", ["char", "uint8_t"], invoke_prefix="s_"), # TODO: Non-standard. Deprecate.
530+
Def("upsampledint0", ["vint8n", "vuint8n"], invoke_prefix="s_"),
531+
Def("upsampledint0", ["uint8_t", "uint8_t"], invoke_prefix="u_"),
532+
Def("upsampledint0", ["vuint8n", "vuint8n"], invoke_prefix="u_"),
533+
Def("upsampledint0", ["int16_t", "uint16_t"], invoke_prefix="s_"),
534+
Def("upsampledint0", ["vint16n", "vuint16n"], invoke_prefix="s_"),
535+
Def("upsampledint0", ["uint16_t", "uint16_t"], invoke_prefix="u_"),
536+
Def("upsampledint0", ["vuint16n", "vuint16n"], invoke_prefix="u_"),
537+
Def("upsampledint0", ["int32_t", "uint32_t"], invoke_prefix="s_"),
538+
Def("upsampledint0", ["vint32n", "vuint32n"], invoke_prefix="s_"),
539+
Def("upsampledint0", ["uint32_t", "uint32_t"], invoke_prefix="u_"),
540+
Def("upsampledint0", ["vuint32n", "vuint32n"], invoke_prefix="u_")],
535541
"popcount": [Def("geninteger", ["geninteger"])],
536542
"mad24": [Def("igenint32", ["igenint32", "igenint32", "igenint32"], invoke_prefix="s_"),
537543
Def("ugenint32", ["ugenint32", "ugenint32", "ugenint32"], invoke_prefix="u_")],
@@ -571,7 +577,10 @@ def custom_bool_select_invoke(return_type, _, arg_names):
571577
Def("vfloatn", ["float", "float", "vfloatn"], convert_args=[(0,2),(1,2)]),
572578
Def("vdoublen", ["double", "double", "vdoublen"], convert_args=[(0,2),(1,2)])],
573579
"sign": [Def("genfloat", ["genfloat"])],
574-
"abs": [Def("genfloat", ["genfloat"], invoke_prefix="f")], # TODO: Non-standard. Deprecate.
580+
"abs": [Def("genfloat", ["genfloat"], invoke_prefix="f"), # TODO: Non-standard. Deprecate.
581+
Def("sigeninteger", ["sigeninteger"], custom_invoke=custom_signed_abs_scalar_invoke),
582+
Def("vigeninteger", ["vigeninteger"], custom_invoke=custom_signed_abs_vector_invoke),
583+
Def("ugeninteger", ["ugeninteger"], invoke_prefix="u_")],
575584
# Geometric functions
576585
"cross": [Def("vfloat3or4", ["vfloat3or4", "vfloat3or4"]),
577586
Def("vdouble3or4", ["vdouble3or4", "vdouble3or4"]),
@@ -698,9 +707,9 @@ def select_from_mapping(mappings, arg_types, arg_type):
698707
if isinstance(mapping, UnsignedType):
699708
parent_mapping = mappings[arg_types[mapping.parent_idx]]
700709
return InstantiatedUnsignedType(parent_mapping, mapping.parent_idx)
701-
if isinstance(mapping, SameSizeIntType):
710+
if isinstance(mapping, ConversionTraitType):
702711
parent_mapping = mappings[arg_types[mapping.parent_idx]]
703-
return InstantiatedSameSizeIntType(parent_mapping, mapping.signed, mapping.parent_idx)
712+
return InstantiatedConversionTraitType(parent_mapping, mapping.trait, mapping.parent_idx)
704713
return mapping
705714

706715
def instantiate_arg(idx, arg):
@@ -714,8 +723,8 @@ def instantiate_arg(idx, arg):
714723
return InstantiatedElementType(instantiate_arg(arg.parent_idx, arg.referenced_type), arg.parent_idx)
715724
if isinstance(arg, InstantiatedUnsignedType):
716725
return InstantiatedUnsignedType(instantiate_arg(arg.parent_idx, arg.signed_type), arg.parent_idx)
717-
if isinstance(arg, InstantiatedSameSizeIntType):
718-
return InstantiatedSameSizeIntType(instantiate_arg(arg.parent_idx, arg.parent_type), arg.signed, arg.parent_idx)
726+
if isinstance(arg, InstantiatedConversionTraitType):
727+
return InstantiatedConversionTraitType(instantiate_arg(arg.parent_idx, arg.parent_type), arg.trait, arg.parent_idx)
719728
return arg
720729

721730
def instantiate_return_type(return_type, instantiated_args):
@@ -730,8 +739,8 @@ def instantiate_return_type(return_type, instantiated_args):
730739
return InstantiatedElementType(instantiate_return_type(return_type.referenced_type, instantiated_args), return_type.parent_idx)
731740
if isinstance(return_type, InstantiatedUnsignedType):
732741
return InstantiatedUnsignedType(instantiate_return_type(return_type.signed_type, instantiated_args), return_type.parent_idx)
733-
if isinstance(return_type, InstantiatedSameSizeIntType):
734-
return InstantiatedSameSizeIntType(instantiate_return_type(return_type.parent_type, instantiated_args), return_type.signed, return_type.parent_idx)
742+
if isinstance(return_type, InstantiatedConversionTraitType):
743+
return InstantiatedConversionTraitType(instantiate_return_type(return_type.parent_type, instantiated_args), return_type.trait, return_type.parent_idx)
735744
return return_type
736745

737746
def type_combinations(return_type, arg_types):
@@ -772,9 +781,9 @@ def get_all_vec_args(arg_types):
772781
def get_vec_arg_requirement(vec_arg):
773782
valid_type_str = ', '.join(vec_arg.vec_type.valid_types)
774783
valid_sizes_str = ', '.join(map(str, vec_arg.vec_type.valid_sizes))
775-
checks = [f'detail::is_vec_v<{vec_arg}>',
776-
f'detail::is_valid_vec_type_v<{vec_arg.template_name}, {valid_type_str}>',
777-
f'detail::is_valid_vec_size_v<{vec_arg.template_name}, {valid_sizes_str}>']
784+
checks = [f'detail::is_vec_or_swizzle_v<{vec_arg}>',
785+
f'detail::is_valid_elem_type_v<{vec_arg.template_name}, {valid_type_str}>',
786+
f'detail::is_valid_size_v<{vec_arg.template_name}, {valid_sizes_str}>']
778787
return '(' + (' && '.join(checks)) + ')'
779788

780789
def get_func_return(return_type, arg_types):

0 commit comments

Comments
 (0)