This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 191ddf7

add new matrix tests that use the new interface (#1231)

* add new matrix tests that use the new interface, which adds an argument called use
* address Andrei's review comments
* add XFAIL for GPU, as use is not handled there yet
* correct the verification for the bfloat16 test

1 parent 3bac74f
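
The "new interface" named in the commit message adds a use template argument to joint_matrix, so each tile declares its role (use::a, use::b, or use::accumulator) explicitly. The declarations below are taken verbatim from the bfloat16 test in this diff:

    joint_matrix<bfloat16, TM, TK, use::a> sub_a(sg);
    joint_matrix<bfloat16, TK, TN, use::b> sub_b(sg);
    joint_matrix<float, TM, TN, use::accumulator> sub_c(sg);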

File tree: 2 files changed (+338, -0 lines)
joint_matrix_bfloat16_use.cpp (new file): 179 additions, 0 deletions
//==-------- joint_matrix_bfloat16_use.cpp - DPC++ joint_matrix------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

// XFAIL: gpu

#include <cmath> // fabs
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;

#define SG_SZ 8

#define TM 8
#define TN 8
#define TK 16

// 2^-7, one ulp of a bf16 value in [1, 2).
#define BF16_EPSILON 0.00781250

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;
  static_assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  sycl::buffer<bfloat16, 2> bufA(A.get_data(), sycl::range<2>(M, K));
  sycl::buffer<bfloat16, 2> bufB(B.get_data(), sycl::range<2>(K, N));
  sycl::buffer<float, 2> bufC((float *)C.get_data(), sycl::range<2>(M, N));

  sycl::queue q;
  q.submit([&](sycl::handler &cgh) {
     sycl::accessor accC{bufC, cgh};
     sycl::accessor accA{bufA, cgh};
     sycl::accessor accB{bufB, cgh};

     cgh.parallel_for<class imatrix>(
         sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](sycl::nd_item<2> spmd_item)
             // The indexing below divides by SG_SZ, so the kernel requires a
             // sub-group size of exactly SG_SZ (as in the int8 variant).
             [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<bfloat16, TM, TK, use::a> sub_a(sg);
           joint_matrix<bfloat16, TK, TN, use::b> sub_b(sg);
           joint_matrix<float, TM, TN, use::accumulator> sub_c(sg);

           joint_matrix_fill(sg, sub_c, 1.0);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, layout::row_major);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 2) * (N * 2) +
                                   sg_starty / SG_SZ * TN * 2,
                               N * 2, layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
unsigned short Aref[MATRIX_M][MATRIX_K];
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

// Widen a bf16 bit pattern to fp32 by placing it in the high 16 bits.
float make_fp32(short x) {
  unsigned int y = x;
  y = y << 16;
  float *res = reinterpret_cast<float *>(&y);
  return *res;
}

// Truncate an fp32 value to a bf16 bit pattern by keeping the high 16 bits.
unsigned short make_bf16(float x) {
  int *res = reinterpret_cast<int *>(&x);
  *res = *res >> 16;
  return (unsigned short)*res;
}

// Reference multiply over the packed data: each int32 word of A_mem and B_mem
// holds two adjacent bf16 values, so K here is the packed (halved) depth.
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
                         int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        short *va = (short *)(A_mem + m * K + k);
        short *vb = (short *)(B_mem + k * N + n);
        float acc = *((float *)(C_mem + m * N + n));
        for (int i = 0; i < 2; i++) {
          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
        }
        *((float *)(C_mem + m * N + n)) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      // bfloat16 is created from an unsigned short since a direct
      // float-to-bfloat16 conversion is not allowed.
      A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
      Aref[i][j] = make_bf16(1.0f * (i + j));
    }
  }
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j)));
      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
  big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 2);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      // Flag elements whose absolute difference exceeds the bf16 tolerance.
      if (fabs(C[i][j] - D[i][j]) > BF16_EPSILON)
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
}
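
Both tests load B with layout::packed_b under the comment "Assuming B data is already in VNNI format"; the tests avoid an explicit repack by generating B directly in its packed storage shape ([MATRIX_K / 2][MATRIX_N * 2] here). As a minimal sketch of what that assumption means for bf16 (pack_vnni_bf16 is an illustrative helper, not part of the test, and assumes a plain row-major K x N source), VNNI interleaves each pair of consecutive rows along the columns:

    // Repack a row-major K x N bf16 matrix into the [K/2][N*2] VNNI layout:
    // logical element src[k][n] lands at dst[k / 2][2 * n + k % 2], so the two
    // values consumed together by one dot-product step are adjacent in memory.
    void pack_vnni_bf16(const bfloat16 *src, bfloat16 *dst, size_t K, size_t N) {
      for (size_t k = 0; k < K; ++k)
        for (size_t n = 0; n < N; ++n)
          dst[(k / 2) * (N * 2) + 2 * n + (k % 2)] = src[k * N + n];
    }

The reference path exploits the same pairing: matrix_multiply_ref walks Aref and Bref as int32 words, each holding two adjacent bf16 values, which is why it is invoked with a depth of MATRIX_K / 2.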
joint_matrix_ss_int8_use.cpp (new file): 159 additions, 0 deletions
//==-------- joint_matrix_ss_int8_use.cpp - DPC++ joint_matrix-------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

// XFAIL: gpu

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 8

#define TM 8
#define TN SG_SZ
#define TK 32

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;
  static_assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     sycl::accessor accC{bufC, cgh};
     sycl::accessor accA{bufA, cgh};
     sycl::accessor accB{bufB, cgh};

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [accA, accB, accC, M, N, K](nd_item<2> spmd_item)
             [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<int8_t, TM, TK, use::a> sub_a(sg);
           joint_matrix<int8_t, TK, TN, use::b> sub_b(sg);
           joint_matrix<int32_t, TM, TN, use::accumulator> sub_c(sg);

           joint_matrix_fill(sg, sub_c, 0);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, layout::row_major);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 4) * (N * 4) +
                                   sg_starty / SG_SZ * TN * 4,
                               N * 4, layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
int8_t A[MATRIX_M][MATRIX_K];
int8_t B[MATRIX_K / 4][MATRIX_N * 4];
int32_t C[MATRIX_M][MATRIX_N];
int32_t D[MATRIX_M][MATRIX_N];

void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M,
                         int N, int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        char *va = (char *)(A_mem + m * K + k);
        char *vb = (char *)(B_mem + k * N + n);
        int acc = *(C_mem + m * N + n);
        for (int i = 0; i < 4; i++) {
          acc += (va[i] * vb[i]);
        }
        *(C_mem + m * N + n) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      A[i][j] = i + 2 * j;
    }
  }
  for (int i = 0; i < MATRIX_K / 4; i++) {
    for (int j = 0; j < MATRIX_N * 4; j++) {
      B[i][j] = i + j;
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 0;
      D[i][j] = 0;
    }
  }

  big_matrix<int32_t, MATRIX_M, MATRIX_N> MC((int32_t *)&C);
  big_matrix<int32_t, MATRIX_M, MATRIX_N> MD((int32_t *)&D);
  big_matrix<int8_t, MATRIX_M, MATRIX_K> MA((int8_t *)&A);
  big_matrix<int8_t, MATRIX_K / 4, MATRIX_N * 4> MB((int8_t *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 4);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
}
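
The int8 test uses the same packing scheme with a factor of 4: each int32 word of the reinterpreted A and B holds four consecutive int8 values, hence the MATRIX_K / 4 depth passed to matrix_multiply_ref. A scalar equivalent of one output element (ref_element is an illustrative helper, not part of the test) makes the VNNI indexing explicit:

    // Compute one element of the logical M x K by K x N int8 product, reading
    // B through its VNNI storage: logical B[k][n] lives at B[k/4][4*n + k%4].
    int32_t ref_element(const int8_t A[MATRIX_M][MATRIX_K],
                        const int8_t B[MATRIX_K / 4][MATRIX_N * 4], int m,
                        int n) {
      int32_t acc = 0;
      for (int k = 0; k < MATRIX_K; ++k)
        acc += int32_t(A[m][k]) * int32_t(B[k / 4][4 * n + k % 4]);
      return acc;
    }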
