
Commit 613ede6

[SYCL][Matrix] add a new joint matrix test that uses SYCL bfloat16 (#1005)
The existing tests use uint16 to represent the bf16 type. We are moving towards adding new tests that use the SYCL bfloat16 type, in an effort to replace the workaround we had adopted. Signed-off-by: Dounia Khaldi <[email protected]>
1 parent c4e3c97 commit 613ede6
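
For context, the older joint_matrix tests store bf16 data in plain 16-bit integers, while this test declares its matrices with the SYCL bfloat16 type and only falls back to bit-level construction on the host. A minimal sketch of the two styles (assuming the experimental bfloat16 class provided through <CL/sycl.hpp>; float_to_bf16_bits() is a hypothetical stand-in for the make_bf16() helper used in the test below):

#include <CL/sycl.hpp>
#include <cstring>

using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;

// Hypothetical helper mirroring make_bf16() in the test: keep the upper
// 16 bits of the float's bit pattern (truncation, no rounding).
unsigned short float_to_bf16_bits(float x) {
  unsigned int bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return static_cast<unsigned short>(bits >> 16);
}

int main() {
  // Old workaround: the element type itself is a raw 16-bit integer.
  unsigned short a_old = float_to_bf16_bits(1.5f);

  // New style: the element type is bfloat16; on the host it is still built
  // from bits because float-to-bfloat16 conversion is not supported there yet.
  bfloat16 a_new = bfloat16::from_bits(float_to_bf16_bits(1.5f));

  (void)a_old;
  (void)a_new;
  return 0;
}

Presumably, once host-side float-to-bfloat16 conversion is available, the from_bits() step can be dropped and values assigned directly.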

File tree

1 file changed: +179 -0 lines changed


SYCL/Matrix/joint_matrix_bfloat16.cpp

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
//==-------- joint_matrix_bfloat16.cpp - DPC++ joint_matrix----------- ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

#include <CL/sycl.hpp>
#include <iostream>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;

#define SG_SZ 8

#define TM 8
#define TN 8
#define TK 16

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T1, typename T2, size_t M, size_t N, size_t K>
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
                     big_matrix<T2, K / 2, N * 2> &B) {
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup; these functions will be called once by the subgroup,
           // with no code divergence between the workitems.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<bfloat16, TM, TK> sub_a(sg);
           // For B, since the current implementation does not support a
           // non-packed layout, users need to specify the updated VNNI sizes
           // along with the packed_b layout. By default, the layout is
           // row_major and the size is (TK, TN).
           joint_matrix<bfloat16, TK, TN, matrix_layout::packed_b> sub_b(sg);
           joint_matrix<float, TM, TN> sub_c(sg);

           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, matrix_layout::row_major);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 2) * (N * 2) +
                                   sg_starty / SG_SZ * TN * 2,
                               N * 2, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
unsigned short Aref[MATRIX_M][MATRIX_K];
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

// Reinterpret a bf16 bit pattern as the float whose upper 16 bits it holds.
float make_fp32(short x) {
  unsigned int y = x;
  y = y << 16;
  float *res = reinterpret_cast<float *>(&y);
  return *res;
}

// Truncate a float to its bf16 bit pattern (keep the upper 16 bits).
unsigned short make_bf16(float x) {
  int *res = reinterpret_cast<int *>(&x);
  *res = *res >> 16;
  return (unsigned short)*res;
}

// Reference multiply: each int element packs two bf16 values (VNNI layout),
// so the innermost loop accumulates both halves.
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
                         int K) {
  // tiling
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        short *va = (short *)(A_mem + m * K + k);
        short *vb = (short *)(B_mem + k * N + n);
        float acc = *((float *)(C_mem + m * N + n));
        // FIXME: Should we do reduce-add in another version?
        for (int i = 0; i < 2; i++) {
          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
        }
        *((float *)(C_mem + m * N + n)) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      // bfloat16 is created using unsigned short since conversion from float
      // to bfloat16 is not supported on the host side yet.
      A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
      Aref[i][j] = make_bf16(1.0f * (i + j));
    }
  }
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = bfloat16::from_bits(make_bf16(2.0f * i + 3.0f * j));
      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
  big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 2);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
}
