Skip to content

Commit 841ea9c

Browse files
committed
Fix test failures, address review comments.
Signed-off-by: kbobrovs <[email protected]>
1 parent 5b40088 commit 841ea9c

File tree

10 files changed

+74
-51
lines changed

10 files changed

+74
-51
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/detail/operators.hpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,21 @@ namespace __SEIEE {
235235
template <class SimdT1, class RegionT1, class SimdT2, class RegionT2, \
236236
class T1 = typename __SEIEE::shape_type<RegionT1>::element_type, \
237237
class T2 = typename __SEIEE::shape_type<RegionT2>::element_type, \
238+
auto N1 = __SEIEE::shape_type<RegionT1>::length, \
239+
auto N2 = __SEIEE::shape_type<RegionT2>::length, \
238240
class = \
239241
std::enable_if_t<__SEIEED::is_simd_type_v<SimdT1> == \
240242
__SEIEED::is_simd_type_v<SimdT2> && \
241-
(__SEIEE::shape_type<RegionT1>::length == \
242-
__SEIEE::shape_type<RegionT2>::length) && \
243-
COND>> \
243+
(N1 == N2 || N1 == 1 || N2 == 1) && COND>> \
244244
inline auto operator BINOP( \
245245
const __SEIEE::simd_view<SimdT1, RegionT1> &LHS, \
246246
const __SEIEE::simd_view<SimdT2, RegionT2> &RHS) { \
247-
return LHS.read() BINOP RHS.read(); \
247+
if constexpr (N1 == 1) \
248+
return (T1)LHS.read()[0] BINOP RHS.read(); \
249+
else if constexpr (N2 == 1) \
250+
return LHS.read() BINOP(T2) RHS.read()[0]; \
251+
else \
252+
return LHS.read() BINOP RHS.read(); \
248253
} \
249254
\
250255
/* simd* BINOP simd_view<simd*...> */ \
@@ -337,20 +342,27 @@ __ESIMD_DEF_SIMD_VIEW_BIN_OP(||, __SEIEED::is_simd_mask_type_v<SimdT1>)
337342
\
338343
/* simd_view CMPOP simd_view */ \
339344
template <class SimdT1, class RegionT1, class SimdT2, class RegionT2, \
340-
class = \
341-
std::enable_if_t</* both views must have the same base type \
342-
kind - simds or masks: */ \
343-
(__SEIEED::is_simd_type_v<SimdT1> == \
344-
__SEIEED::is_simd_type_v< \
345-
SimdT2>)&&/* the length of the views \
346-
must match as well: */ \
347-
(__SEIEE::shape_type<RegionT1>::length == \
348-
__SEIEE::shape_type<RegionT2>::length) && \
349-
COND>> \
345+
auto N1 = __SEIEE::shape_type<RegionT1>::length, \
346+
auto N2 = __SEIEE::shape_type<RegionT2>::length, \
347+
class = std::enable_if_t</* both views must have the same base \
348+
type kind - simds or masks: */ \
349+
(__SEIEED::is_simd_type_v<SimdT1> == \
350+
__SEIEED::is_simd_type_v< \
351+
SimdT2>)&&/* the length of the views \
352+
must match as well: */ \
353+
(N1 == N2 || N1 == 1 || N2 == 1) && \
354+
COND>> \
350355
inline auto operator CMPOP( \
351356
const __SEIEE::simd_view<SimdT1, RegionT1> &LHS, \
352357
const __SEIEE::simd_view<SimdT2, RegionT2> &RHS) { \
353-
return LHS.read() CMPOP RHS.read(); \
358+
using T1 = typename __SEIEE::shape_type<RegionT1>::element_type; \
359+
using T2 = typename __SEIEE::shape_type<RegionT2>::element_type; \
360+
if constexpr (N1 == 1) \
361+
return (T1)LHS.read()[0] CMPOP RHS.read(); \
362+
else if constexpr (N2 == 1) \
363+
return LHS.read() CMPOP(T2) RHS.read()[0]; \
364+
else \
365+
return LHS.read() CMPOP RHS.read(); \
354366
} \
355367
\
356368
/* simd_view CMPOP simd_obj_impl */ \

sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//==------------ - simd_obj_impl.hpp - DPC++ Explicit SIMD API
2-
//--------------------==//
1+
//==------------ - simd_obj_impl.hpp - DPC++ Explicit SIMD API -------------==//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.

sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_view_impl.hpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,16 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
286286

287287
/// @{
288288
/// Assignment operators.
289-
Derived &operator=(const Derived &Other) { return write(Other.read()); }
289+
simd_view_impl &operator=(const simd_view_impl &Other) {
290+
return write(Other.read());
291+
}
290292

291293
Derived &operator=(const value_type &Val) { return write(Val); }
292294

293295
/// Move assignment operator.
294-
Derived &operator=(Derived &&Other) { return write(Other.read()); }
296+
simd_view_impl &operator=(simd_view_impl &&Other) {
297+
return write(Other.read());
298+
}
295299

296300
template <class T, int N, class SimdT,
297301
class = std::enable_if_t<(is_simd_type_v<SimdT> ==
@@ -342,7 +346,7 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
342346
template <typename T = Derived,
343347
typename = sycl::detail::enable_if_t<T::is2D()>>
344348
auto row(int i) {
345-
return select<1, 0, getSizeX(), 1>(i, 0)
349+
return select<1, 1, getSizeX(), 1>(i, 0)
346350
.template bit_cast_view<element_type>();
347351
}
348352

@@ -351,7 +355,7 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
351355
template <typename T = Derived,
352356
typename = sycl::detail::enable_if_t<T::is2D()>>
353357
auto column(int i) {
354-
return select<getSizeY(), 1, 1, 0>(0, i);
358+
return select<getSizeY(), 1, 1, 1>(0, i);
355359
}
356360

357361
/// Read a single element from a 1D region, by value only.
@@ -375,15 +379,15 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
375379
template <typename T = Derived,
376380
typename = sycl::detail::enable_if_t<T::is1D()>>
377381
auto operator[](int i) {
378-
return select<1, 0>(i);
382+
return select<1, 1>(i);
379383
}
380384

381385
/// Return a writeable view of a single element.
382386
template <typename T = Derived,
383387
typename = sycl::detail::enable_if_t<T::is1D()>>
384388
__SYCL_DEPRECATED("use operator[] form.")
385389
auto operator()(int i) {
386-
return select<1, 0>(i);
390+
return select<1, 1>(i);
387391
}
388392

389393
/// \name Replicate
@@ -402,7 +406,7 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
402406
/// \return replicated simd instance.
403407
template <int Rep, int W>
404408
get_simd_t<element_type, Rep * W> replicate(uint16_t OffsetX) {
405-
return replicate<Rep, 0, W>(0, OffsetX);
409+
return replicate<Rep, 1, W>(0, OffsetX);
406410
}
407411

408412
/// \tparam Rep is number of times region has to be replicated.
@@ -413,7 +417,7 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
413417
template <int Rep, int W>
414418
get_simd_t<element_type, Rep * W> replicate(uint16_t OffsetY,
415419
uint16_t OffsetX) {
416-
return replicate<Rep, 0, W>(OffsetY, OffsetX);
420+
return replicate<Rep, 1, W>(OffsetY, OffsetX);
417421
}
418422

419423
/// \tparam Rep is number of times region has to be replicated.

sycl/include/sycl/ext/intel/experimental/esimd/detail/types.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,16 @@ static inline constexpr bool is_simd_mask_type_v = is_simd_mask_type<Ty>::value;
329329

330330
// @{
331331
// Checks if given type is a view of the simd type.
332-
template <typename Ty> struct is_simd_view_type : std::false_type {};
332+
template <typename Ty> struct is_simd_view_type_impl : std::false_type {};
333333

334334
template <class BaseT, class RegionT>
335-
struct is_simd_view_type<simd_view<BaseT, RegionT>>
335+
struct is_simd_view_type_impl<simd_view<BaseT, RegionT>>
336336
: std::conditional_t<is_simd_type_v<BaseT>, std::true_type,
337337
std::false_type> {};
338338

339+
template <class Ty>
340+
struct is_simd_view_type : is_simd_view_type_impl<remove_cvref_t<Ty>> {};
341+
339342
template <typename Ty>
340343
static inline constexpr bool is_simd_view_type_v = is_simd_view_type<Ty>::value;
341344
// @}

sycl/include/sycl/ext/intel/experimental/esimd/simd_view.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ class simd_view : public detail::simd_view_impl<BaseTy, RegionTy> {
7070
simd_view(const simd_view &Other) = default;
7171
simd_view(simd_view &&Other) = default;
7272

73+
simd_view &operator=(const simd_view &Other) {
74+
BaseClass::operator=(Other);
75+
return *this;
76+
}
77+
7378
using BaseClass::operator--;
7479
using BaseClass::operator++;
7580
using BaseClass::operator=;
@@ -162,7 +167,7 @@ class simd_view<BaseTy, std::pair<region1d_scalar_t<T>, NestedRegion>>
162167
: public detail::simd_view_impl<
163168
BaseTy, std::pair<region1d_scalar_t<T>, NestedRegion>> {
164169
template <typename, int> friend class simd;
165-
template <typename, typename, typename> friend class detail::simd_view_impl;
170+
template <typename, typename> friend class detail::simd_view_impl;
166171

167172
public:
168173
using RegionTy = std::pair<region1d_scalar_t<T>, NestedRegion>;

sycl/test/esimd/esimd-util-compiler-eval.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,8 @@ static_assert(log2<1024 * 1024>() == 20, "");
1919

2020
using BaseTy = simd<float, 4>;
2121
using RegionTy = region1d_t<float, 2, 1>;
22-
using RegionTy1 = region1d_scalar_t<float, 0, 0>;
23-
static_assert(
24-
!is_simd_view_v<
25-
simd_view_impl<BaseTy, RegionTy, simd_view<BaseTy, RegionTy>>>::value,
26-
"");
27-
static_assert(is_simd_view_v<simd_view<BaseTy, RegionTy>>::value, "");
28-
static_assert(is_simd_view_v<simd_view<BaseTy, RegionTy1>>::value, "");
29-
static_assert(!is_simd_view_v<BaseTy>::value, "");
22+
using RegionTy1 = region1d_scalar_t<float>;
23+
static_assert(!is_simd_view_type_v<simd_view_impl<BaseTy, RegionTy>>, "");
24+
static_assert(is_simd_view_type_v<simd_view<BaseTy, RegionTy>>, "");
25+
static_assert(is_simd_view_type_v<simd_view<BaseTy, RegionTy1>>, "");
26+
static_assert(!is_simd_view_type_v<BaseTy>, "");

sycl/test/esimd/lane_id.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<int, 16> foo(int x) {
2626
SIMT_BEGIN(16, lane)
2727
//CHECK: define internal spir_func void @_ZZ3fooiENKUlvE_clEv({{.*}}) {{.*}} #[[ATTR:[0-9]+]]
2828
//CHECK: %{{[0-9a-zA-Z_.]+}} = tail call spir_func i32 @_Z15__esimd_lane_idv()
29-
v.select<1, 0>(lane) = x++;
29+
v.select<1, 1>(lane) = x++;
3030
SIMT_END
3131
return v;
3232
}

sycl/test/esimd/simd_subscript.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ void test_simd_writable_subscript() SYCL_ESIMD_FUNCTION {
6565
v[1] = 0; // returns simd_view
6666

6767
// CHECK: simd_subscript.cpp:69{{.*}}warning: {{.*}} deprecated
68-
// CHECK: sycl/ext/intel/experimental/esimd/simd.hpp:{{.*}} note: {{.*}} has been explicitly marked deprecated here
68+
// CHECK: sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp:{{.*}} note: {{.*}} has been explicitly marked deprecated here
6969
v(1) = 0;
7070
}
7171

@@ -76,7 +76,7 @@ void test_simd_const_subscript() SYCL_ESIMD_FUNCTION {
7676
cv[1] = 0;
7777

7878
// CHECK: simd_subscript.cpp:80{{.*}}warning: {{.*}} deprecated
79-
// CHECK: sycl/ext/intel/experimental/esimd/simd.hpp:{{.*}} note: {{.*}} has been explicitly marked deprecated here
79+
// CHECK: sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp:{{.*}} note: {{.*}} has been explicitly marked deprecated here
8080
int val3 = cv(0);
8181
}
8282

sycl/test/esimd/simd_view.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
using namespace sycl::ext::intel::experimental::esimd;
88

9-
SYCL_ESIMD_FUNCTION bool test_simd_view_bin_ops() {
9+
SYCL_ESIMD_FUNCTION auto test_simd_view_bin_ops() {
1010
simd<int, 16> v0 = 1;
1111
simd<int, 16> v1 = 2;
1212
auto ref0 = v0.select<8, 2>(0);
@@ -51,7 +51,7 @@ SYCL_ESIMD_FUNCTION bool test_simd_view_unary_ops() {
5151
ref0 <<= ref1;
5252
ref1 = -ref0;
5353
ref0 = ~ref1;
54-
auto mask = !ref0;
54+
auto mask = !(ref0 < ref1);
5555
return v1[0] == 1;
5656
}
5757

@@ -143,21 +143,21 @@ void test_simd_view_impl_api_ret_types() SYCL_ESIMD_FUNCTION {
143143
simd<float, 4> x = 0;
144144
auto v1 =
145145
x.select<2, 1>(0); // simd_view<simd<float, 4>, region1d_t<float, 2, 1>>
146-
static_assert(detail::is_simd_view_v<decltype(v1)>::value, "");
146+
static_assert(detail::is_simd_view_type_v<decltype(v1)>, "");
147147
auto v2 = v1.select<1, 1>(
148148
0); // simd_view<simd<float, 4>, std::pair<region_base<false, float, 1, 0,
149149
// 1, 1>, region_base<false, float, 1, 0, 2, 1>>>
150-
static_assert(detail::is_simd_view_v<decltype(v1)>::value, "");
150+
static_assert(detail::is_simd_view_type_v<decltype(v1)>, "");
151151

152152
auto v2_int = v2.bit_cast_view<int>();
153-
static_assert(detail::is_simd_view_v<decltype(v2_int)>::value, "");
153+
static_assert(detail::is_simd_view_type_v<decltype(v2_int)>, "");
154154
auto v2_int_2D = v2.bit_cast_view<int, 1, 1>();
155-
static_assert(detail::is_simd_view_v<decltype(v2_int_2D)>::value, "");
155+
static_assert(detail::is_simd_view_type_v<decltype(v2_int_2D)>, "");
156156

157157
auto v3 = x.select<2, 1>(2);
158158
auto &v4 = (v1 += v3);
159-
static_assert(detail::is_simd_view_v<decltype(v4)>::value, "");
160-
static_assert(detail::is_simd_view_v<decltype(++v4)>::value, "");
159+
static_assert(detail::is_simd_view_type_v<decltype(v4)>, "");
160+
static_assert(detail::is_simd_view_type_v<decltype(++v4)>, "");
161161
}
162162

163163
void test_simd_view_subscript() SYCL_ESIMD_FUNCTION {
@@ -189,9 +189,12 @@ void test_simd_view_writeable_subscript() SYCL_ESIMD_FUNCTION {
189189
void test_simd_view_binop_with_conv_to_scalar() SYCL_ESIMD_FUNCTION {
190190
simd<ushort, 64> s = 0;
191191
auto g = s.bit_cast_view<ushort, 4, 16>();
192-
auto x = g.row(1) - (g.row(1))[0]; // binary op
193-
auto y = g.row(1) & (g.row(1))[0]; // bitwise op
194-
auto z = g.row(1) < (g.row(1))[0]; // relational op
192+
auto x1 = g.row(1) - (g.row(1))[0]; // binary op
193+
auto x2 = (g.row(1))[0] - g.row(1); // binary op
194+
auto y1 = g.row(1) & (g.row(1))[0]; // bitwise op
195+
auto y2 = (g.row(1))[0] & g.row(1); // bitwise op
196+
auto z1 = g.row(1) < (g.row(1))[0]; // relational op
197+
auto z2 = (g.row(1))[0] < g.row(1); // relational op
195198
}
196199

197200
// This code is OK. The result of bit_cast_view should be mapped

sycl/test/esimd/simd_view_ret_warn.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using namespace sycl::ext::intel::experimental::esimd;
77
// and it should be programmers fault, similar to string_view.
88
// However, sometimes we could return simd_view from a function
99
// implicitly. This test checks that users will see a warning in such situation.
10-
simd_view<simd<float, 4>, region1d_t<float, 1, 0>> f1(simd<float, 4> x) {
10+
simd_view<simd<float, 4>, region1d_t<float, 1, 1>> f1(simd<float, 4> x) {
1111
// expected-warning@+1 {{address of stack memory associated with parameter 'x' returned}}
1212
return x[0];
1313
}

0 commit comments

Comments
 (0)