Skip to content

Commit d189cdf

Browse files
committed
Fixed bug in 1-elem simd_view element type definition.
Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent bd7ea4d commit d189cdf

File tree

5 files changed

+36
-40
lines changed

5 files changed

+36
-40
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,6 @@ static inline constexpr bool is_simd_flag_type_v = is_simd_flag_type<T>::value;
7474

7575
namespace detail {
7676

77-
// Determine element type of simd_obj_impl's Derived type w/o having to have
78-
// complete instantiation of the Derived type (is required by element_type_t,
79-
// hence can't be used here).
80-
template <class T> struct elem_type_of_derived;
81-
template <class T, int N> struct elem_type_of_derived<simd<T, N>> {
82-
using type = T;
83-
};
84-
template <class T, int N> struct elem_type_of_derived<simd_mask_impl<T, N>> {
85-
using type = simd_mask_elem_type; // equals T
86-
};
87-
8877
/// The simd_obj_impl vector class.
8978
///
9079
/// This is a base class for all ESIMD simd classes with real storage (simd,
@@ -113,7 +102,7 @@ class simd_obj_impl {
113102
template <typename, int> friend class simd;
114103
template <typename, int> friend class simd_mask_impl;
115104

116-
using element_type = typename elem_type_of_derived<Derived>::type;
105+
using element_type = simd_like_obj_element_type_t<Derived>;
117106
using Ty = element_type;
118107

119108
public:
@@ -296,7 +285,7 @@ class simd_obj_impl {
296285

297286
/// View this simd_obj_impl object in a different element type.
298287
template <typename EltTy> auto bit_cast_view() &[[clang::lifetimebound]] {
299-
using TopRegionTy = compute_format_type_t<simd_obj_impl, EltTy>;
288+
using TopRegionTy = compute_format_type_t<Derived, EltTy>;
300289
using RetTy = simd_view<Derived, TopRegionTy>;
301290
return RetTy{cast_this_to_derived(), TopRegionTy{0}};
302291
}
@@ -310,8 +299,7 @@ class simd_obj_impl {
310299
/// View as a 2-dimensional simd_view.
311300
template <typename EltTy, int Height, int Width>
312301
auto bit_cast_view() &[[clang::lifetimebound]] {
313-
using TopRegionTy =
314-
compute_format_type_2d_t<simd_obj_impl, EltTy, Height, Width>;
302+
using TopRegionTy = compute_format_type_2d_t<Derived, EltTy, Height, Width>;
315303
using RetTy = simd_view<Derived, TopRegionTy>;
316304
return RetTy{cast_this_to_derived(), TopRegionTy{0, 0}};
317305
}

sycl/include/sycl/ext/intel/experimental/esimd/detail/type_format.hpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@ template <typename Ty, int N, typename EltTy,
2929
struct compute_format_type<SimdT<Ty, N>, EltTy>
3030
: compute_format_type_impl<Ty, N, EltTy> {};
3131

32-
template <typename Ty, int N, typename EltTy, class SimdT>
33-
struct compute_format_type<simd_obj_impl<Ty, N, SimdT>, EltTy>
34-
: compute_format_type_impl<Ty, N, EltTy> {};
35-
3632
template <typename BaseTy, typename RegionTy, typename EltTy>
3733
struct compute_format_type<simd_view<BaseTy, RegionTy>, EltTy> {
3834
using ShapeTy = typename shape_type<RegionTy>::type;
@@ -65,11 +61,6 @@ template <typename Ty, int N, typename EltTy, int Height, int Width,
6561
struct compute_format_type_2d<SimdT<Ty, N>, EltTy, Height, Width>
6662
: compute_format_type_2d_impl<Ty, N, EltTy, Height, Width> {};
6763

68-
template <typename Ty, int N, typename EltTy, int Height, int Width,
69-
class SimdT>
70-
struct compute_format_type_2d<simd_obj_impl<Ty, N, SimdT>, EltTy, Height, Width>
71-
: compute_format_type_2d_impl<Ty, N, EltTy, Height, Width> {};
72-
7364
template <typename BaseTy, typename RegionTy, typename EltTy, int Height,
7465
int Width>
7566
struct compute_format_type_2d<simd_view<BaseTy, RegionTy>, EltTy, Height,

sycl/include/sycl/ext/intel/experimental/esimd/detail/types.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,26 @@ struct element_type<T, std::enable_if_t<is_clang_vector_type_v<T>>> {
310310
};
311311

312312
template <typename T> using element_type_t = typename element_type<T>::type;
313+
314+
// Determine element type of simd_obj_impl's Derived type w/o having to have
315+
// complete instantiation of the Derived type (is required by element_type_t,
316+
// hence can't be used here).
317+
template <class T> struct simd_like_obj_info;
318+
template <class T, int N> struct simd_like_obj_info<simd<T, N>> {
319+
using type = T;
320+
static inline constexpr int length = N;
321+
};
322+
template <class T, int N> struct simd_like_obj_info<simd_mask_impl<T, N>> {
323+
using type = simd_mask_elem_type; // equals T
324+
static inline constexpr int length = N;
325+
};
326+
327+
template <typename T>
328+
using simd_like_obj_element_type_t = typename simd_like_obj_info<T>::type;
329+
template <typename T>
330+
static inline constexpr int simd_like_obj_length =
331+
simd_like_obj_info<T>::length;
332+
313333
// @}
314334

315335
template <typename To, typename From>

sycl/include/sycl/ext/intel/experimental/esimd/simd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class simd : public detail::simd_obj_impl<
7575
(T::length == 1) && detail::is_valid_simd_elem_type_v<To>>>
7676
operator To() const {
7777
__esimd_dbg_print(operator To());
78-
return convert_scalar<To, element_type>(base_type::data()[0]);
78+
return detail::convert_scalar<To, element_type>(base_type::data()[0]);
7979
}
8080

8181
/// @{

sycl/include/sycl/ext/intel/experimental/esimd/simd_view.hpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,22 @@ class simd_view : public detail::simd_view_impl<BaseTy, RegionTy> {
119119
/// bool b = v[0] > v[1] && v[2] < 42;
120120
///
121121
/// \ingroup sycl_esimd
122-
template <typename BaseTy>
123-
class simd_view<BaseTy, region1d_scalar_t<typename BaseTy::element_type>>
124-
: public detail::simd_view_impl<
125-
BaseTy, region1d_scalar_t<typename BaseTy::element_type>> {
122+
template <typename BaseTy, class ViewedElemT>
123+
class simd_view<BaseTy, region1d_scalar_t<ViewedElemT>>
124+
: public detail::simd_view_impl<BaseTy, region1d_scalar_t<ViewedElemT>> {
126125
template <typename, int, class, class> friend class detail::simd_obj_impl;
127126
template <typename, typename> friend class detail::simd_view_impl;
128127

129128
public:
130-
using RegionTy = region1d_scalar_t<typename BaseTy::element_type>;
129+
using RegionTy = region1d_scalar_t<ViewedElemT>;
131130
using BaseClass = detail::simd_view_impl<BaseTy, RegionTy>;
132131
using ShapeTy = typename shape_type<RegionTy>::type;
133132
static constexpr int length = ShapeTy::Size_x * ShapeTy::Size_y;
134133
static_assert(1 == length, "length of this view is not equal to 1");
134+
static_assert(std::is_same_v<typename ShapeTy::element_type, ViewedElemT>);
135135
/// The element type of this class, which could be different from the element
136136
/// type of the base object type.
137-
using element_type = typename ShapeTy::element_type;
137+
using element_type = ViewedElemT;
138138
using base_type = BaseTy;
139139
template <typename ElT, int N>
140140
using get_simd_t = typename BaseClass::template get_simd_t<ElT, N>;
@@ -174,26 +174,23 @@ class simd_view<BaseTy, region1d_scalar_t<typename BaseTy::element_type>>
174174
/// simd<int, 4> v = 1;
175175
/// auto v1 = v.select<2, 1>(0);
176176
/// auto v2 = v1[0]; // simd_view of a nested region for a single element
177-
template <typename BaseTy, typename NestedRegion>
178-
class simd_view<
179-
BaseTy,
180-
std::pair<region1d_scalar_t<typename BaseTy::element_type>, NestedRegion>>
177+
template <typename BaseTy, typename NestedRegion, class ViewedElemT>
178+
class simd_view<BaseTy, std::pair<region1d_scalar_t<ViewedElemT>, NestedRegion>>
181179
: public detail::simd_view_impl<
182-
BaseTy, std::pair<region1d_scalar_t<typename BaseTy::element_type>,
183-
NestedRegion>> {
180+
BaseTy, std::pair<region1d_scalar_t<ViewedElemT>, NestedRegion>> {
184181
template <typename, int> friend class simd;
185182
template <typename, typename> friend class detail::simd_view_impl;
186183

187184
public:
188-
using RegionTy =
189-
std::pair<region1d_scalar_t<typename BaseTy::element_type>, NestedRegion>;
185+
using RegionTy = std::pair<region1d_scalar_t<ViewedElemT>, NestedRegion>;
190186
using BaseClass = detail::simd_view_impl<BaseTy, RegionTy>;
191187
using ShapeTy = typename shape_type<RegionTy>::type;
192188
static constexpr int length = ShapeTy::Size_x * ShapeTy::Size_y;
193189
static_assert(1 == length, "length of this view is not equal to 1");
190+
static_assert(std::is_same_v<typename ShapeTy::element_type, ViewedElemT>);
194191
/// The element type of this class, which could be different from the element
195192
/// type of the base object type.
196-
using element_type = typename ShapeTy::element_type;
193+
using element_type = ViewedElemT;
197194

198195
private:
199196
simd_view(BaseTy &Base, RegionTy Region) : BaseClass(Base, Region) {}

0 commit comments

Comments
 (0)