Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit f1a22ff

Browse files
authored
[SYCL][matrix] Add basic bf16 test case for the joint matrix feature (#384)
1 parent 204b096 commit f1a22ff

File tree

1 file changed

+197
-0
lines changed

1 file changed

+197
-0
lines changed

SYCL/Matrix/joint_matrix_bf16.cpp

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
//==-------- joint_matrix_bf16.cpp - DPC++ joint_matrix--------------- ----==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
14+
// XFAIL: *
15+
16+
#include <CL/sycl.hpp>

#include <cstring>
#include <iostream>
18+
19+
using namespace sycl;
20+
using namespace sycl::ext::oneapi::experimental::matrix;
21+
22+
// Sub-group size required by the kernel, and the joint-matrix tile shape.
// NOTE: the original text had "#define TN SG_SIZE", but no SG_SIZE macro is
// defined anywhere in this file — only SG_SZ is.  TN must expand to SG_SZ.
#define SG_SZ 8

#define TM 8
#define TN SG_SZ
#define TK 16
27+
28+
// Thin non-owning view of a 2-D matrix: wraps a raw pointer whose logical
// shape is NUM_ROWS x NUM_COLS.  The caller retains ownership of the storage.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  big_matrix(T *data) : mat(data) {}
  // Accessors for the wrapped buffer pointer.
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }

public:
  T *mat;
};
37+
38+
// Device-side multiply-accumulate C += A * B using the SYCL joint_matrix
// extension.  A is M x K bf16 (carried as unsigned short), B is supplied
// already packed in VNNI ("packed_b") layout as a (K/2) x (N*2) buffer, and
// C is M x N float.  M, N, K are assumed to be multiples of TM, TN, TK
// respectively — TODO confirm; no remainder handling exists here.
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;

  // B's declared shape is halved/doubled because of the VNNI packing.
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
  // One work-group tile of TM x TN output per (x, y) ND-range slot.
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<unsigned short, 2> bufB(B.get_data(), range<2>(K / 2, N * 2));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         // Second dimension is scaled by SG_SZ so each sub-group of SG_SZ
         // work-items cooperatively owns one TM x TN tile of C.
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item)
             [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup; these functions will be called once by the subgroup —
           // no code divergence between the workitems.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           // Top-left ND-range coordinate of this sub-group's tile.
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ONEAPI::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<ONEAPI::sub_group, unsigned short, TM, TK> sub_a(sg);
           // For B, since current implementation does not support non-packed
           // layout, users need to specify the updated VNNI sizes along with
           // the packed_b layout. By default, the layout is row_major and size
           // is (TK, TN).
           joint_matrix<ONEAPI::sub_group, unsigned short, TK, TN,
                        matrix_layout::packed_b>
               sub_b(sg);
           joint_matrix<ONEAPI::sub_group, float, TM, TN> sub_c(sg);

           // Seed the accumulator with the existing contents of C
           // (C += A*B semantics, not C = A*B).
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K; k += TK) {
             joint_matrix_load(sg, sub_a,
                               accA.get_pointer() + (sg_startx * TM) * K + k, K,
                               matrix_layout::row_major);
             // Assume B is already in VNNI format; k*N here equals
             // (k/2) * (N*2), i.e. row k/2 of the packed buffer.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k) * (N) +
                                   sg_starty / SG_SZ * TN * 2,
                               N * 2, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}
108+
109+
// Problem size: two tiles in every dimension.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
// A: row-major bf16 input; B: VNNI-packed bf16 input (K/2 rows of N*2);
// C and D: float results from the device kernel and the host reference,
// respectively.
unsigned short A[MATRIX_M][MATRIX_K];
unsigned short B[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];
116+
117+
// Expands a bf16 value (carried in a 16-bit integer) to float by placing its
// bits in the upper half of an IEEE-754 binary32 word; the mantissa tail is
// zero, so the conversion is exact.
// Uses memcpy for the bit-cast: the original reinterpret_cast<float *> of an
// unsigned int violated strict aliasing (undefined behavior).
float make_fp32(short x) {
  unsigned int bits = (unsigned short)x; // only the low 16 bits of x matter
  bits <<= 16;
  float res;
  std::memcpy(&res, &bits, sizeof(res));
  return res;
}
123+
124+
// Truncates a float to bf16 by keeping its top 16 bits (round-toward-zero on
// the mantissa; no round-to-nearest-even).
// Uses memcpy for the bit-cast: the original reinterpret_cast<int *> of a
// float violated strict aliasing (undefined behavior) and also mutated its
// own parameter in place.
unsigned short make_bf16(float x) {
  unsigned int bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return (unsigned short)(bits >> 16);
}
129+
130+
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
131+
int K) {
132+
// tiling
133+
for (int m = 0; m < M; m++)
134+
for (int n = 0; n < N; n++) {
135+
for (int k = 0; k < K; k++) {
136+
short *va = (short *)(A_mem + m * K + k);
137+
short *vb = (short *)(B_mem + k * N + n);
138+
float acc = *((float *)(C_mem + m * N + n));
139+
// FIXME: Should we do reduce-add in another version?
140+
for (int i = 0; i < 2; i++) {
141+
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
142+
}
143+
*((float *)(C_mem + m * N + n)) = acc;
144+
}
145+
}
146+
}
147+
148+
int main() {
149+
for (int i = 0; i < MATRIX_M; i++) {
150+
for (int j = 0; j < MATRIX_K; j++) {
151+
A[i][j] = make_bf16(1.0f * (i + j));
152+
}
153+
}
154+
for (int i = 0; i < MATRIX_K / 2; i++) {
155+
for (int j = 0; j < MATRIX_N * 2; j++) {
156+
B[i][j] = make_bf16(2.0f * i + 3.0f * j);
157+
}
158+
}
159+
for (int i = 0; i < MATRIX_M; i++) {
160+
for (int j = 0; j < MATRIX_N; j++) {
161+
C[i][j] = 1.0;
162+
D[i][j] = 1.0;
163+
}
164+
}
165+
166+
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
167+
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
168+
big_matrix<unsigned short, MATRIX_M, MATRIX_K> MA((unsigned short *)&A);
169+
big_matrix<unsigned short, MATRIX_K / 2, MATRIX_N * 2> MB(
170+
(unsigned short *)&B);
171+
matrix_multiply(MC, MA, MB);
172+
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
173+
MATRIX_N, MATRIX_K / 2);
174+
175+
bool res = true;
176+
for (int i = 0; i < MATRIX_M; i++) {
177+
for (int j = 0; j < MATRIX_N; j++) {
178+
if (C[i][j] != D[i][j])
179+
res = false;
180+
}
181+
}
182+
if (res)
183+
std::cout << "passed\n";
184+
else
185+
std::cout << "failed\n";
186+
for (int i = 0; i < MATRIX_M; i++) {
187+
for (int j = 0; j < MATRIX_N; j++)
188+
std::cout << C[i][j] << ", ";
189+
std::cout << "\n";
190+
}
191+
std::cout << std::endl;
192+
for (int i = 0; i < MATRIX_M; i++) {
193+
for (int j = 0; j < MATRIX_N; j++)
194+
std::cout << D[i][j] << ", ";
195+
std::cout << "\n";
196+
}
197+
}

0 commit comments

Comments
 (0)