Skip to content

Commit 8bbfeba

Browse files
dkhaldibb-sycl
authored andcommitted
[SYCL][Matrix] Add a new test for bf16 slicing, remove XFAIL from the half test case (intel#774)
Signed-off-by: Dounia Khaldi <[email protected]>
1 parent 017807c commit 8bbfeba

File tree

2 files changed

+277
-6
lines changed

2 files changed

+277
-6
lines changed
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
//==----------- element_wise_all_ops_bf16.cpp - DPC++ joint_matrix---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
14+
#include <CL/sycl.hpp>
15+
#include <iostream>
16+
#include <random>
17+
18+
using namespace sycl;
19+
using namespace sycl::ext::intel;
20+
using namespace sycl::ext::oneapi::experimental::matrix;
21+
22+
#define SG_SZ 8
23+
24+
#define TM 8
25+
#define TN SG_SZ
26+
#define TK 16
27+
28+
static float make_fp32(uint16_t x) {
29+
unsigned int y = x;
30+
y = y << 16;
31+
float *res = reinterpret_cast<float *>(&y);
32+
return *res;
33+
}
34+
35+
static uint16_t make_bf16(float x) {
36+
int *res = reinterpret_cast<int *>(&x);
37+
*res = *res >> 16;
38+
return (uint16_t)*res;
39+
}
40+
41+
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
42+
public:
43+
T *mat;
44+
45+
public:
46+
T *get_data() { return mat; }
47+
void set_data(T *data) { mat = data; }
48+
big_matrix(T *data) : mat(data) {}
49+
};
50+
51+
template <typename T, size_t M, size_t N>
52+
void assert_ops_ref(
53+
accessor<T, 2, access::mode::read, access::target::host_buffer> C,
54+
const float ref) {
55+
for (size_t i = 0; i < M; i++)
56+
for (size_t j = 0; j < N; j++) {
57+
auto diff = make_fp32(C[i][j]) - ref;
58+
assert(std::fabs(static_cast<float>(diff)) <
59+
std::numeric_limits<float>::epsilon());
60+
}
61+
}
62+
template <typename T, size_t M, size_t N>
63+
void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
64+
const float ref) {
65+
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
66+
67+
q.submit([&](handler &cgh) {
68+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
69+
70+
cgh.parallel_for<class add_matrix>(
71+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
72+
const auto global_idx = spmd_item.get_global_id(0);
73+
const auto global_idy = spmd_item.get_global_id(1);
74+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
75+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
76+
77+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
78+
joint_matrix<T, TM, TK> sub_a(sg);
79+
80+
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
81+
82+
auto wi_slice_a = sub_a.get_wi_data();
83+
for (int i = 0; i < wi_slice_a.length(); i++) {
84+
wi_slice_a[i] = wi_slice_a[i] + make_bf16(2);
85+
}
86+
joint_matrix_store(sg, sub_a,
87+
accA.get_pointer() + (sg_startx * TM) * N +
88+
sg_starty / SG_SZ * TN,
89+
N, matrix_layout::row_major);
90+
}); // parallel for
91+
}).wait();
92+
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
93+
}
94+
95+
template <typename T, size_t M, size_t N>
96+
void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
97+
const float ref) {
98+
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
99+
100+
q.submit([&](handler &cgh) {
101+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
102+
103+
cgh.parallel_for<class sub_matrix>(
104+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
105+
const auto global_idx = spmd_item.get_global_id(0);
106+
const auto global_idy = spmd_item.get_global_id(1);
107+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
108+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
109+
110+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
111+
joint_matrix<T, TM, TK> sub_a(sg);
112+
113+
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
114+
115+
auto wi_slice_a = sub_a.get_wi_data();
116+
for (int i = 0; i < wi_slice_a.length(); i++) {
117+
wi_slice_a[i] = wi_slice_a[i] - make_bf16(2);
118+
}
119+
joint_matrix_store(sg, sub_a,
120+
accA.get_pointer() + (sg_startx * TM) * N +
121+
sg_starty / SG_SZ * TN,
122+
N, matrix_layout::row_major);
123+
}); // parallel for
124+
}).wait();
125+
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
126+
}
127+
128+
template <typename T, size_t M, size_t N>
129+
void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
130+
const float ref) {
131+
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
132+
133+
q.submit([&](handler &cgh) {
134+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
135+
136+
cgh.parallel_for<class mul_matrix>(
137+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
138+
const auto global_idx = spmd_item.get_global_id(0);
139+
const auto global_idy = spmd_item.get_global_id(1);
140+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
141+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
142+
143+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
144+
joint_matrix<T, TM, TK> sub_a(sg);
145+
146+
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
147+
148+
auto wi_slice_a = sub_a.get_wi_data();
149+
for (int i = 0; i < wi_slice_a.length(); i++) {
150+
wi_slice_a[i] = wi_slice_a[i] * make_bf16(3.0);
151+
}
152+
joint_matrix_store(sg, sub_a,
153+
accA.get_pointer() + (sg_startx * TM) * N +
154+
sg_starty / SG_SZ * TN,
155+
N, matrix_layout::row_major);
156+
}); // parallel for
157+
}).wait();
158+
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
159+
}
160+
161+
template <typename T, size_t M, size_t N>
162+
void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
163+
const float ref) {
164+
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
165+
166+
q.submit([&](handler &cgh) {
167+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
168+
169+
cgh.parallel_for<class div_matrix>(
170+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
171+
const auto global_idx = spmd_item.get_global_id(0);
172+
const auto global_idy = spmd_item.get_global_id(1);
173+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
174+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
175+
176+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
177+
joint_matrix<T, TM, TK> sub_a(sg);
178+
179+
joint_matrix_fill(sg, sub_a, make_bf16(4.0));
180+
181+
auto wi_slice_a = sub_a.get_wi_data();
182+
for (int i = 0; i < wi_slice_a.length(); i++) {
183+
wi_slice_a[i] = wi_slice_a[i] / make_bf16(2.0);
184+
}
185+
joint_matrix_store(sg, sub_a,
186+
accA.get_pointer() + (sg_startx * TM) * N +
187+
sg_starty / SG_SZ * TN,
188+
N, matrix_layout::row_major);
189+
}); // parallel for
190+
}).wait();
191+
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
192+
}
193+
194+
template <typename T, size_t M, size_t N>
195+
void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
196+
const float ref) {
197+
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
198+
199+
q.submit([&](handler &cgh) {
200+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
201+
cgh.parallel_for<class logic_matrix>(
202+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
203+
const auto global_idx = spmd_item.get_global_id(0);
204+
const auto global_idy = spmd_item.get_global_id(1);
205+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
206+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
207+
208+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
209+
joint_matrix<T, TM, TK> sub_a(sg);
210+
211+
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
212+
213+
auto wi_slice_a = sub_a.get_wi_data();
214+
for (int i = 0; i < wi_slice_a.length(); i++) {
215+
if (wi_slice_a[i]) {
216+
if (wi_slice_a[i] > make_bf16(2.0) ||
217+
wi_slice_a[i] >= make_bf16(2.0) ||
218+
wi_slice_a[i] < make_bf16(2.0) ||
219+
wi_slice_a[i] <= make_bf16(2.0)) {
220+
T val = (wi_slice_a[i] != make_bf16(2.0)) ? wi_slice_a[i]
221+
: make_bf16(2.0);
222+
val = make_bf16(make_fp32(val) - static_cast<float>(1));
223+
val = make_bf16(make_fp32(val) + static_cast<float>(1));
224+
if (wi_slice_a[i] == make_bf16(2.0)) {
225+
val = make_bf16(make_fp32(val) - static_cast<float>(2));
226+
val = make_bf16(make_fp32(val) * static_cast<float>(3));
227+
val = make_bf16(make_fp32(val) / static_cast<float>(2));
228+
229+
} else {
230+
val = make_bf16(make_fp32(val) + static_cast<float>(2));
231+
}
232+
wi_slice_a[i] = val;
233+
}
234+
}
235+
}
236+
joint_matrix_store(sg, sub_a,
237+
accA.get_pointer() + (sg_startx * TM) * N +
238+
sg_starty / SG_SZ * TN,
239+
N, matrix_layout::row_major);
240+
}); // parallel for
241+
}).wait();
242+
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
243+
}
244+
245+
static constexpr size_t MATRIX_M = TM * 2;
246+
static constexpr size_t MATRIX_N = TN * 2;
247+
unsigned short A[MATRIX_M][MATRIX_N];
248+
float D[MATRIX_M][MATRIX_N];
249+
250+
void matrix_ops_ref(float *D, int M, int N) {
251+
for (int m = 0; m < M; m++)
252+
for (int n = 0; n < N; n++) {
253+
*(D + m * N + n) = 0;
254+
*(D + m * N + n) *= 2;
255+
}
256+
}
257+
258+
int main() {
259+
260+
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
261+
big_matrix<unsigned short, MATRIX_M, MATRIX_N> MA((unsigned short *)&A);
262+
263+
size_t NDRangeM = MATRIX_M / TM;
264+
size_t NDRangeN = MATRIX_N / TN;
265+
queue q;
266+
nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});
267+
268+
matrix_verify_add<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
269+
matrix_verify_sub<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 3.0);
270+
matrix_verify_mul<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 15.0);
271+
matrix_verify_div<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 2.0);
272+
matrix_verify_logic<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
273+
274+
return 0;
275+
}

SYCL/Matrix/element_wise_all_ops_half.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,11 @@
77
//===----------------------------------------------------------------------===//
88
// REQUIRES: matrix
99

10+
// Only runs on DPAS because AMX implementation does not support half data type
11+
// yet
1012
// RUN: %clangxx -fsycl %s -o %t.out
11-
// RUN: %CPU_RUN_PLACEHOLDER %t.out
1213
// RUN: %GPU_RUN_PLACEHOLDER %t.out
1314

14-
// There is a known bug in joint_matrix_fill when type is half
15-
// A PR is being developed to fix the bug
16-
// Will remove the XFAIL once this is fixed
17-
// XFAIL: *
18-
1915
#include <CL/sycl.hpp>
2016
#include <iostream>
2117
#include <random>

0 commit comments

Comments
 (0)