1
1
2
2
// REQUIRES: cuda
3
- // Temp xfail: test was merged early.
4
- // XFAIL: cuda
5
3
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out
6
4
// RUN: %t.out
7
5
//
14
12
#include <sycl/sycl.hpp>
15
13
16
14
using namespace sycl ;
17
- using namespace sycl ::ext::oneapi::experimental ;
15
+ using namespace sycl ::ext::oneapi;
18
16
using namespace sycl ::ext::oneapi::experimental::matrix;
19
17
constexpr float bf16_eps = 0.00390625 ;
20
18
@@ -146,9 +144,11 @@ void test(queue &q) {
146
144
// column id of current submatrix of BIG C matrix
147
145
const auto n = item.get_group ().get_group_id ()[1 ];
148
146
149
- joint_matrix<T3, use::a, M, K, layout::row_major> sub_a;
150
- joint_matrix<T3, use::b, K, N, layout::row_major> sub_b;
151
- joint_matrix<std::remove_const_t <T2>, use::accumulator, M, N> sub_c;
147
+ joint_matrix<sub_group, T3, use::a, M, K, layout::row_major> sub_a;
148
+ joint_matrix<sub_group, T3, use::b, K, N, layout::row_major> sub_b;
149
+ joint_matrix<sub_group, std::remove_const_t <T2>, use::accumulator,
150
+ M, N>
151
+ sub_c;
152
152
153
153
joint_matrix_load (sg, sub_c,
154
154
accC.get_pointer () + (m * M) * Big_N + n * N,
@@ -165,11 +165,13 @@ void test(queue &q) {
165
165
166
166
// round values to correct precision if using tf32
167
167
if constexpr (std::is_same<T3, precision::tf32>::value) {
168
- auto wi_size = sub_a. wi_marray . size ();
169
- assert (wi_size == sub_b. wi_marray . size ());
168
+ auto wi_size = get_wi_data (sg, sub_a). length ();
169
+ assert (wi_size == get_wi_data (sg, sub_b). length ());
170
170
for (auto i = 0 ; i < wi_size; ++i) {
171
- sub_a.wi_marray [i] = round_to_tf32 (sub_a.wi_marray [i]);
172
- sub_b.wi_marray [i] = round_to_tf32 (sub_b.wi_marray [i]);
171
+ get_wi_data (sg, sub_a)[i] =
172
+ round_to_tf32 (get_wi_data (sg, sub_a)[i]);
173
+ get_wi_data (sg, sub_b)[i] =
174
+ round_to_tf32 (get_wi_data (sg, sub_b)[i]);
173
175
}
174
176
}
175
177
0 commit comments