Skip to content

Commit 678a3fc

Browse files
authored
[SYCL][Matrix] Add new matrix tests for PVC (intel#646)
1 parent 0f47e44 commit 678a3fc

15 files changed

+1039
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//==-------- joint_matrix_bf16_pvc.cpp - DPC++ joint_matrix---------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-pvc
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
14+
#include <CL/sycl.hpp>

#include <cstdint>
#include <cstring>
#include <iostream>
16+
17+
using namespace sycl;
18+
using namespace sycl::ext::oneapi::experimental::matrix;
19+
20+
#define SG_SZ 16
21+
22+
#define TM 8
23+
#define TN SG_SZ
24+
#define TK 16
25+
26+
// Thin, non-owning wrapper pairing a raw pointer to matrix storage with its
// compile-time dimensions. The caller retains ownership of the memory.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

  // explicit: a bare T* should never silently convert to a big_matrix.
  explicit big_matrix(T *data) : mat(data) {}
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
};
35+
36+
// Device-side bf16 GEMM using the joint_matrix API:
//   C (M x N, fp32) += A (M x K, bf16 stored as ushort) * B (bf16, VNNI-packed
//   as K/2 rows of N*2 ushorts).
// Each sub-group computes one TM x TN tile of C; requires M % TM == 0,
// N % TN == 0, K % TK == 0.
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;

  // B arrives pre-packed (VNNI): NUM_ROWS_B == K / 2 and NUM_COLS_B == N * 2.
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<unsigned short, 2> bufB(B.get_data(), range<2>(K / 2, N * 2));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         // One work-group row per TM tile of C; SG_SZ work-items (one
         // sub-group) cooperate on each TN-wide tile.
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item)
             [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup; these functions are called once per subgroup, so there
           // must be no code divergence between the workitems.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           // Top-left coordinates of this sub-group's tile.
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sub_group sg = spmd_item.get_sub_group();
           joint_matrix<unsigned short, TM, TK> sub_a(sg);
           // For B, since current implementation does not support non-packed
           // layout, users need to specify the packed_b layout.
           // By default, the layout is row_major.
           joint_matrix<unsigned short, TK, TN, matrix_layout::packed_b> sub_b(
               sg);
           joint_matrix<float, TM, TN> sub_c(sg);
           // Seed the accumulator tile with the existing contents of C.
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K; k += TK) {
             joint_matrix_load(sg, sub_a,
                               accA.get_pointer() + (sg_startx * TM) * K + k, K,
                               matrix_layout::row_major);
             // B data is assumed to already be in VNNI format: logical row k
             // lives in packed row k / 2 with pitch N * 2, so the flat offset
             // (k / 2) * (N * 2) simplifies to k * N.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k) * (N) +
                                   sg_starty / SG_SZ * TN * 2,
                               N * 2, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           // Write the finished TM x TN tile back to C.
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}
103+
104+
// Problem sizes: two joint-matrix tiles in each dimension.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
unsigned short A[MATRIX_M][MATRIX_K];       // bf16 input, row-major
unsigned short B[MATRIX_K / 2][MATRIX_N * 2]; // bf16 input, VNNI-packed
float C[MATRIX_M][MATRIX_N]; // device (joint_matrix) result
float D[MATRIX_M][MATRIX_N]; // host reference result
111+
112+
// Expand a bfloat16 bit pattern (carried in a short) into a full fp32 value.
// bf16 is the upper 16 bits of an IEEE-754 binary32, so shifting the pattern
// into the high half and reinterpreting the bits yields the float exactly.
float make_fp32(short x) {
  // Widen via unsigned short so no sign-extension bits survive the shift.
  unsigned int y = static_cast<unsigned short>(x);
  y = y << 16;
  float res;
  std::memcpy(&res, &y, sizeof(res)); // bit copy; avoids strict-aliasing UB
  return res;
}
118+
119+
// Truncate an fp32 value to a bfloat16 bit pattern by dropping the low 16
// mantissa bits (round-toward-zero; no rounding, matching the original).
unsigned short make_bf16(float x) {
  unsigned int bits;
  std::memcpy(&bits, &x, sizeof(bits)); // bit copy; avoids strict-aliasing UB
  // Unsigned shift: well-defined for negative floats (sign bit set), unlike
  // the arithmetic right-shift of a negative int.
  return (unsigned short)(bits >> 16);
}
124+
125+
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
126+
int K) {
127+
// tiling
128+
for (int m = 0; m < M; m++)
129+
for (int n = 0; n < N; n++) {
130+
for (int k = 0; k < K; k++) {
131+
short *va = (short *)(A_mem + m * K + k);
132+
short *vb = (short *)(B_mem + k * N + n);
133+
float acc = *((float *)(C_mem + m * N + n));
134+
// FIXME: Should we do reduce-add in another version?
135+
for (int i = 0; i < 2; i++) {
136+
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
137+
}
138+
*((float *)(C_mem + m * N + n)) = acc;
139+
}
140+
}
141+
}
142+
143+
int main() {
144+
for (int i = 0; i < MATRIX_M; i++) {
145+
for (int j = 0; j < MATRIX_K; j++) {
146+
A[i][j] = make_bf16(1.0f * (i + j));
147+
}
148+
}
149+
for (int i = 0; i < MATRIX_K / 2; i++) {
150+
for (int j = 0; j < MATRIX_N * 2; j++) {
151+
B[i][j] = make_bf16(2.0f * i + 3.0f * j);
152+
}
153+
}
154+
for (int i = 0; i < MATRIX_M; i++) {
155+
for (int j = 0; j < MATRIX_N; j++) {
156+
C[i][j] = 1.0;
157+
D[i][j] = 1.0;
158+
}
159+
}
160+
161+
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
162+
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
163+
big_matrix<unsigned short, MATRIX_M, MATRIX_K> MA((unsigned short *)&A);
164+
big_matrix<unsigned short, MATRIX_K / 2, MATRIX_N * 2> MB(
165+
(unsigned short *)&B);
166+
matrix_multiply(MC, MA, MB);
167+
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
168+
MATRIX_N, MATRIX_K / 2);
169+
170+
bool res = true;
171+
for (int i = 0; i < MATRIX_M; i++) {
172+
for (int j = 0; j < MATRIX_N; j++) {
173+
if (C[i][j] != D[i][j])
174+
res = false;
175+
}
176+
}
177+
if (res)
178+
std::cout << "passed\n";
179+
else
180+
std::cout << "failed\n";
181+
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
//==-------- joint_matrix_half_pvc.cpp - DPC++ joint_matrix---------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-pvc
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out
11+
// Only run on the GPU because half is not supported on AMX hardware
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
14+
#include <CL/sycl.hpp>
15+
#include <iostream>
16+
17+
using namespace sycl;
18+
using namespace sycl::ext::oneapi::experimental::matrix;
19+
20+
#define SG_SZ 16
21+
22+
#define TM 8
23+
#define TN SG_SZ
24+
#define TK 16
25+
26+
// Thin, non-owning wrapper pairing a raw pointer to matrix storage with its
// compile-time dimensions. The caller retains ownership of the memory.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

  // explicit: a bare T* should never silently convert to a big_matrix.
  explicit big_matrix(T *data) : mat(data) {}
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
};
35+
36+
// Device-side fp16 GEMM using the joint_matrix API:
//   C (M x N, fp32) += A (M x K, half) * B (half, VNNI-packed as K/2 rows of
//   N*2 halves).
// Each sub-group computes one TM x TN tile of C; requires M % TM == 0,
// N % TN == 0, K % TK == 0.
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;

  // B arrives pre-packed (VNNI): NUM_ROWS_B == K / 2 and NUM_COLS_B == N * 2.
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<half, 2> bufA(A.get_data(), range<2>(M, K));
  // NOTE(review): declared as K x N although B's storage is K/2 x N*2; the
  // element count matches and the kernel below indexes with the N*2 pitch —
  // confirm this shape is intentional.
  buffer<half, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<float, 2> bufC(C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         // One work-group row per TM tile of C; SG_SZ work-items (one
         // sub-group) cooperate on each TN-wide tile.
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, SG_SZ}),
         [ accA, accB, accC, M, N,
           K ](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup; these functions are called once per subgroup, so there
           // must be no code divergence between the workitems.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           // Top-left coordinates of this sub-group's tile.
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<half, TM, TK> sub_a(sg);
           // For B, since current implementation does not support non-packed
           // layout, users need to specify the updated VNNI sizes along with
           // the packed_b layout. By default, the layout is row_major and size
           // is (TK, TN).
           joint_matrix<half, TK, TN, matrix_layout::packed_b> sub_b(sg);
           joint_matrix<float, TM, TN> sub_c(sg);

           // Seed the accumulator tile with the existing contents of C.
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, matrix_layout::row_major);
             // Assuming B data is already in VNNI format: logical row k * TK
             // lives in packed row k * TK / 2 with pitch N * 2.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 2) * (N * 2) +
                                   sg_starty / SG_SZ * TN * 2,
                               N * 2, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           // Write the finished TM x TN tile back to C.
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}
102+
103+
// Problem sizes: two joint-matrix tiles in each dimension.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
half A[MATRIX_M][MATRIX_K];         // fp16 input, row-major
half B[MATRIX_K / 2][MATRIX_N * 2]; // fp16 input, VNNI-packed
float C[MATRIX_M][MATRIX_N]; // device (joint_matrix) result
float D[MATRIX_M][MATRIX_N]; // host reference result
110+
111+
void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N,
112+
int K) {
113+
// tiling
114+
for (int m = 0; m < M; m++)
115+
for (int n = 0; n < N; n++) {
116+
for (int k = 0; k < K; k++) {
117+
half *va = (half *)(A_mem + m * K + k);
118+
half *vb = (half *)(B_mem + k * N + n);
119+
float acc = *(C_mem + m * N + n);
120+
for (int i = 0; i < 2; i++) {
121+
acc += ((float)va[i] * (float)vb[i]);
122+
}
123+
*((float *)(C_mem + m * N + n)) = acc;
124+
}
125+
}
126+
}
127+
128+
int main() {
129+
for (int i = 0; i < MATRIX_M; i++) {
130+
for (int j = 0; j < MATRIX_K; j++) {
131+
A[i][j] = i + 2 * j;
132+
}
133+
}
134+
for (int i = 0; i < MATRIX_K / 2; i++) {
135+
for (int j = 0; j < MATRIX_N * 2; j++) {
136+
B[i][j] = i + j;
137+
}
138+
}
139+
for (int i = 0; i < MATRIX_M; i++) {
140+
for (int j = 0; j < MATRIX_N; j++) {
141+
C[i][j] = 1.0;
142+
D[i][j] = 1.0;
143+
}
144+
}
145+
146+
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
147+
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
148+
big_matrix<half, MATRIX_M, MATRIX_K> MA((half *)&A);
149+
big_matrix<half, MATRIX_K / 2, MATRIX_N * 2> MB((half *)&B);
150+
matrix_multiply(MC, MA, MB);
151+
matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N,
152+
MATRIX_K / 2);
153+
154+
bool res = true;
155+
for (int i = 0; i < MATRIX_M; i++) {
156+
for (int j = 0; j < MATRIX_N; j++) {
157+
if (C[i][j] != D[i][j])
158+
res = false;
159+
}
160+
}
161+
if (res)
162+
std::cout << "passed\n";
163+
else
164+
std::cout << "failed\n";
165+
}

0 commit comments

Comments
 (0)