Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 5ea081a

Browse files
committed
Adds separate test comparing wi_marray with get_wi_data usage.
Signed-off-by: JackAKirk <[email protected]>
1 parent b9ebfe9 commit 5ea081a

File tree

2 files changed

+68
-9
lines changed

2 files changed

+68
-9
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//==----------- element_wise_wi_marray.cpp - DPC++ joint_matrix------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: cuda
9+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 %s -o %t.out
10+
// RUN: %t.out
11+
12+
#include <sycl/sycl.hpp>
13+
14+
using namespace sycl;
15+
using namespace sycl::ext::oneapi::experimental::matrix;
16+
using sycl::ext::oneapi::experimental::bfloat16;
17+
18+
#define SG_SZ 32
19+
20+
// Checks that the two per-work-item views of a joint_matrix agree.
//
// Launches one work-group of SG_SZ work-items (required sub-group size
// SG_SZ). Inside the kernel, two matrix_use::a joint_matrix objects of element
// type T and shape M x K are filled with -1. fabs is then applied to the first
// one element by element through the object returned by get_wi_data(), and to
// the second one in a single call on its wi_marray member. Any element-wise
// mismatch between the two results sets the error flag to 1.
//
// Template parameters:
//   T    - element type of the joint_matrix objects.
//   M, K - row / column extents passed to joint_matrix.
// Parameters:
//   q - queue the kernel is submitted to.
//
// Fails via assert(err == 0) on the host if a mismatch was recorded.
template <typename T, size_t M, size_t K> void verify_wi_marray(queue q) {
21+
// Host-side error flag; stays 0 unless the kernel records a mismatch.
int err = 0;
22+
// Inner scope bounds the lifetime of err_buf. Presumably this relies on the
// buffer's destruction copying the device result back into err before the
// assert below - confirm against the SYCL buffer semantics.
{
23+
buffer<int> err_buf(&err, 1);
24+
q.submit([&](handler &cgh) {
25+
accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh);
26+
27+
// NOTE(review): the kernel name marray_kernel does not depend on T, M or K,
// so instantiating this template with a second set of arguments would
// likely reuse the same kernel name - confirm before adding more cases.
cgh.parallel_for<class marray_kernel>(
28+
nd_range<2>({1, 1 * SG_SZ}, {1, 1 * SG_SZ}),
29+
[ERR](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] {
30+
auto sg = spmd_item.get_sub_group();
31+
32+
joint_matrix<T, matrix_use::a, M, K> sub_a;
33+
joint_matrix<T, matrix_use::a, M, K> sub_a_2;
34+
35+
// Fill both matrices with the same negative value so that fabs
// actually changes every element.
joint_matrix_fill(sg, sub_a, -1);
36+
joint_matrix_fill(sg, sub_a_2, -1);
37+
38+
// Path 1: element-by-element fabs through get_wi_data().
// NOTE(review): i is a signed int compared against length() here and
// size() below; this may be a signed/unsigned comparison - confirm.
auto wi_slice_a = sub_a.get_wi_data();
39+
for (int i = 0; i < wi_slice_a.length(); i++) {
40+
wi_slice_a[i] = fabs(wi_slice_a[i]);
41+
}
42+
// Path 2: a single fabs call on the whole wi_marray member.
sub_a_2.wi_marray = fabs(sub_a_2.wi_marray);
43+
44+
// Compare the two results index by index; any difference flags an error.
for (int i = 0; i < sub_a_2.wi_marray.size(); i++) {
45+
if (sub_a_2.wi_marray[i] != wi_slice_a[i]) {
46+
ERR[0] = 1;
47+
}
48+
}
49+
}); // parallel for
50+
})
51+
.wait();
52+
}
53+
assert(err == 0);
54+
}
55+
56+
// Test entry point. Creates a default queue, parses the device's
// backend_version string as a float, and runs the bfloat16 16x16 check only
// when that value is at least 8.0. Always returns 0; a failure surfaces through
// the assert inside verify_wi_marray.
int main() {
57+
58+
queue q;
59+
// The variable name suggests backend_version is the CUDA compute capability
// (the RUN line builds for sm_80) - confirm for other backends.
auto computeCapability =
60+
std::stof(q.get_device().get_info<info::device::backend_version>());
61+
62+
// On devices reporting a lower version the check is skipped and the test
// still exits with 0.
if (computeCapability >= 8.0) {
63+
verify_wi_marray<bfloat16, 16, 16>(q);
64+
}
65+
66+
return 0;
67+
}

SYCL/Matrix/joint_matrix_tensorcore.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
6868
res += make_fp32(A[m * Big_K + k]) * make_fp32(B[k * Big_N + n]);
6969
} else if constexpr (std::is_same<T1, bfloat16>::value) {
7070
for (int k = 0; k < Big_K; k++)
71-
res += (make_fp32(A[m * Big_K + k].raw()) * 2 + 1) *
71+
res += make_fp32(A[m * Big_K + k].raw()) *
7272
make_fp32(B[k * Big_N + n].raw());
7373
} else {
7474
for (int k = 0; k < Big_K; k++)
@@ -192,14 +192,6 @@ void test(queue &q) {
192192
accA.get_pointer() + (k * K) + (m * M * Big_K),
193193
Big_K);
194194

195-
if constexpr (std::is_same<T1, bfloat16>::value) {
196-
marray<bfloat16, sub_a.wi_marray.size()> b, c;
197-
b = 2;
198-
c = 1;
199-
sub_a.wi_marray =
200-
sycl::ext::oneapi::experimental::fma(sub_a.wi_marray, b, c);
201-
}
202-
203195
joint_matrix_load(sg, sub_b,
204196
accB.get_pointer() + (k * K * Big_N) + (n * N),
205197
Big_N);

0 commit comments

Comments
 (0)