This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit daa5aec

add a new test that handles big combination size 32x64 (#1350)
1 parent cb20604 commit daa5aec

File tree: 1 file changed (+183 −0 lines)

@@ -0,0 +1,183 @@
//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

// XFAIL: *

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;

#define SG_SZ 16

#define TM 32
#define TN 64
#define TK 16

#define BF16_EPSILON 0.00781250

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T1, typename T2, size_t M, size_t N, size_t K>
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
                     big_matrix<T2, K / 2, N * 2> &B) {
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the work-items in a
           // subgroup; these functions are called once per subgroup, so there
           // is no code divergence between the work-items.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<bfloat16, TM, TK> sub_a(sg);
           // For B, since current implementation does not support non-packed
           // layout, users need to specify the updated VNNI sizes along with
           // the packed_b layout. By default, the layout is row_major and size
           // is (TK, TN).
           joint_matrix<bfloat16, TK, TN, matrix_layout::packed_b> sub_b(sg);
           joint_matrix<float, TM, TN> sub_c(sg);

           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, matrix_layout::row_major);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 2) * (N * 2) +
                                   sg_starty / SG_SZ * TN * 2,
                               N * 2, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
unsigned short Aref[MATRIX_M][MATRIX_K];
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

float make_fp32(short x) {
  unsigned int y = x;
  y = y << 16;
  float *res = reinterpret_cast<float *>(&y);
  return *res;
}

unsigned short make_bf16(float x) {
  int *res = reinterpret_cast<int *>(&x);
  *res = *res >> 16;
  return (unsigned short)*res;
}

void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
                         int K) {
  // tiling
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        short *va = (short *)(A_mem + m * K + k);
        short *vb = (short *)(B_mem + k * N + n);
        float acc = *((float *)(C_mem + m * N + n));
        // FIXME: Should we do reduce-add in another version?
        for (int i = 0; i < 2; i++) {
          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
        }
        *((float *)(C_mem + m * N + n)) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      // bfloat16 is created using unsigned short since conversion from float
      // to bfloat16 is not supported on the host side yet
      A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
      Aref[i][j] = make_bf16(1.0f * (i + j));
    }
  }
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j)));
      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
  big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 2);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (fabs(C[i][j] - D[i][j]) > BF16_EPSILON)
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
}
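
The kernel comments note that B must already be in VNNI (packed_b) format, which is why the test declares B as [MATRIX_K / 2][MATRIX_N * 2] and why the reference multiply reads two adjacent bf16 values from each 32-bit element. As a rough illustration of that layout only, not part of this commit, the hypothetical helper below shows how a plain row-major K x N bf16 matrix could be repacked so that two consecutive K elements of a column sit next to each other in memory, consistent with the [K/2][N*2] shape used in the test:

// Illustrative sketch, not part of the committed test. pack_vnni_bf16 is a
// hypothetical helper; the test itself fills B directly in packed form.
#include <cstddef>
#include <vector>

using bf16_bits = unsigned short; // raw bfloat16 bits, as in the test

std::vector<bf16_bits> pack_vnni_bf16(const std::vector<bf16_bits> &row_major,
                                      size_t K, size_t N) {
  // The result is viewed as a [K/2][N*2] matrix.
  std::vector<bf16_bits> packed(K * N);
  for (size_t k = 0; k < K; ++k)
    for (size_t n = 0; n < N; ++n)
      // Element (k, n) of the row-major matrix lands in row k/2,
      // column n*2 + k%2 of the packed view.
      packed[(k / 2) * (N * 2) + n * 2 + (k % 2)] = row_major[k * N + n];
  return packed;
}

With this layout, each 32-bit word of the packed B holds the pair of bf16 values that matrix_multiply_ref consumes per k iteration, which is why the reference loop reads vb[0] and vb[1] from a single int element and is called with MATRIX_K / 2 as its K argument.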
