This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit e4bfaa8

[SYCL][Matrix] Add the 8 bit type variants (#443)
1 parent 401db0a commit e4bfaa8

5 files changed: +719, -3 lines

SYCL/Matrix/joint_matrix_bf16.cpp

Lines changed: 1 addition & 3 deletions
@@ -11,8 +11,6 @@
 // RUN: %CPU_RUN_PLACEHOLDER %t.out
 // RUN: %GPU_RUN_PLACEHOLDER %t.out
 
-// XFAIL: *
-
 #include <CL/sycl.hpp>
 #include <iostream>
 
@@ -22,7 +20,7 @@ using namespace sycl::ext::oneapi::experimental::matrix;
 #define SG_SZ 8
 
 #define TM 8
-#define TN SG_SIZE
+#define TN SG_SZ
 #define TK 16
 
 template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {

SYCL/Matrix/joint_matrix_ss_int8.cpp

Lines changed: 178 additions & 0 deletions
//==-------- joint_matrix_ss_int8.cpp - DPC++ joint_matrix------------ ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

#include <CL/sycl.hpp>
#include <iostream>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 8

#define TM 8
#define TN SG_SZ
#define TK 32

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;
  // B => K/4 x N*4, A => M x K, C => M, N
  // stride should be X's cols, e.g., B's stride = N*4
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [accA, accB, accC, M, N, K](nd_item<2> spmd_item)

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup; these functions will be called once by the subgroup (no
           // code divergence between the workitems).
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<int8_t, TM, TK> sub_a(sg);
           // For B, since current implementation does not support non-packed
           // layout, users need to specify the updated VNNI sizes along with
           // the packed_b layout. By default, the layout is row_major and size
           // is (TK, TN).
           joint_matrix<int8_t, TK, TN, matrix_layout::packed_b> sub_b(sg);
           joint_matrix<int32_t, TM, TN> sub_c(sg);

           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, matrix_layout::packed_a);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 4) * (N * 4) +
                                   sg_starty / SG_SZ * TN * 4,
                               N * 4, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
int8_t A[MATRIX_M][MATRIX_K];
int8_t B[MATRIX_K / 4][MATRIX_N * 4];
int32_t C[MATRIX_M][MATRIX_N];
int32_t D[MATRIX_M][MATRIX_N];

void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M,
                         int N, int K) {
  // tiling
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        char *va = (char *)(A_mem + m * K + k);
        char *vb = (char *)(B_mem + k * N + n);
        int acc = *(C_mem + m * N + n);
        for (int i = 0; i < 4; i++) {
          acc += (va[i] * vb[i]);
        }
        *(C_mem + m * N + n) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      A[i][j] = i + 2 * j;
    }
  }
  for (int i = 0; i < MATRIX_K / 4; i++) {
    for (int j = 0; j < MATRIX_N * 4; j++) {
      B[i][j] = i + j;
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1;
      D[i][j] = 1;
    }
  }

  big_matrix<int32_t, MATRIX_M, MATRIX_N> MC((int32_t *)&C);
  big_matrix<int32_t, MATRIX_M, MATRIX_N> MD((int32_t *)&D);
  big_matrix<int8_t, MATRIX_M, MATRIX_K> MA((int8_t *)&A);
  big_matrix<int8_t, MATRIX_K / 4, MATRIX_N * 4> MB((int8_t *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 4);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << C[i][j] << ", ";
    std::cout << "\n";
  }
  std::cout << std::endl;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << D[i][j] << ", ";
    std::cout << "\n";
  }
}
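Note on the B layout: the test declares B directly in its packed shape (MATRIX_K/4 x MATRIX_N*4), and the kernel comment states that B is assumed to already be in VNNI format, i.e. a row-major K x N int8 matrix stored as K/4 x N*4 with each group of four consecutive rows of a column packed into four adjacent bytes (one int32 lane per dot-product step). A minimal host-side repacking sketch, for illustration only (the helper name vnni_pack_b and its plain-vector buffers are not part of the test):

#include <cstdint>
#include <vector>

// Illustrative helper (not part of the test): repack a row-major K x N int8
// matrix into the K/4 x (N*4) VNNI layout expected by the packed_b loads
// above, i.e. dst[k / 4][n * 4 + k % 4] = src[k][n]. Assumes K % 4 == 0.
std::vector<int8_t> vnni_pack_b(const int8_t *src, size_t K, size_t N) {
  std::vector<int8_t> dst(K * N);
  for (size_t k = 0; k < K; ++k)
    for (size_t n = 0; n < N; ++n)
      dst[(k / 4) * (N * 4) + n * 4 + (k % 4)] = src[k * N + n];
  return dst;
}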

SYCL/Matrix/joint_matrix_su_int8.cpp

Lines changed: 180 additions & 0 deletions
//==-------- joint_matrix_su_int8.cpp - DPC++ joint_matrix------------ ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

#include <CL/sycl.hpp>
#include <iostream>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 8

#define TM 8
#define TN SG_SZ
#define TK 32

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T1, typename T2, typename T3, size_t NUM_ROWS_A,
          size_t NUM_COLS_A, size_t NUM_ROWS_B, size_t NUM_COLS_B,
          size_t NUM_ROWS_C, size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T3, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;
  // B => K/4 x N*4, A => M x K, C => M, N
  // stride should be X's cols, e.g., B's stride = N*4
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<uint8_t, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     auto accC = bufC.get_access<access::mode::read_write>(cgh);
     auto accA = bufA.get_access<access::mode::read_write>(cgh);
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [accA, accB, accC, M, N, K](nd_item<2> spmd_item)

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup; these functions will be called once by the subgroup (no
           // code divergence between the workitems).
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<int8_t, TM, TK> sub_a(sg);
           // For B, since current implementation does not support non-packed
           // layout, users need to specify the updated VNNI sizes along with
           // the packed_b layout. By default, the layout is row_major and size
           // is (TK, TN).
           joint_matrix<uint8_t, TK, TN, matrix_layout::packed_b> sub_b(sg);
           joint_matrix<int32_t, TM, TN> sub_c(sg);

           // AMX: 8 register tiles : 1k byte size, SMmax x SKmax = 16x64
           // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4
           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, matrix_layout::row_major);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, matrix_layout::packed_a);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 4) * (N * 4) +
                                   sg_starty / SG_SZ * TN * 4,
                               N * 4, matrix_layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, matrix_layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
int8_t A[MATRIX_M][MATRIX_K];
uint8_t B[MATRIX_K / 4][MATRIX_N * 4];
int32_t C[MATRIX_M][MATRIX_N];
int32_t D[MATRIX_M][MATRIX_N];

void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M,
                         int N, int K) {
  // tiling
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        char *va = (char *)(A_mem + m * K + k);
        char *vb = (char *)(B_mem + k * N + n);
        int acc = *(C_mem + m * N + n);
        for (int i = 0; i < 4; i++) {
          acc += (va[i] * vb[i]);
        }
        *(C_mem + m * N + n) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      A[i][j] = i + 2 * j;
    }
  }
  for (int i = 0; i < MATRIX_K / 4; i++) {
    for (int j = 0; j < MATRIX_N * 4; j++) {
      B[i][j] = i + j;
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1;
      D[i][j] = 1;
    }
  }

  big_matrix<int32_t, MATRIX_M, MATRIX_N> MC((int32_t *)&C);
  big_matrix<int32_t, MATRIX_M, MATRIX_N> MD((int32_t *)&D);
  big_matrix<int8_t, MATRIX_M, MATRIX_K> MA((int8_t *)&A);
  big_matrix<uint8_t, MATRIX_K / 4, MATRIX_N * 4> MB((uint8_t *)&B);
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 4);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << C[i][j] << ", ";
    std::cout << "\n";
  }
  std::cout << std::endl;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << D[i][j] << ", ";
    std::cout << "\n";
  }
}
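Note on the su variant: it differs from joint_matrix_ss_int8.cpp only in B's element type (uint8_t instead of int8_t), so the device multiply-add is signed x unsigned. With the initial values used here (B[i][j] = i + j, at most 78) the signed and unsigned interpretations of each byte agree, which is why the same reference loop still matches. A minimal sketch, with purely illustrative values not taken from the test, of where the two interpretations would diverge once a byte of B reaches 128 or more:

#include <cstdint>
#include <iostream>

int main() {
  // One VNNI group: four bytes of A (signed) against four bytes of B.
  int8_t a[4] = {1, 2, 3, 4};
  uint8_t b[4] = {200, 10, 10, 10}; // illustrative: first byte >= 128

  int acc_su = 0, acc_ss = 0;
  for (int i = 0; i < 4; ++i) {
    acc_su += a[i] * static_cast<int>(b[i]);    // signed x unsigned: 200 stays 200
    acc_ss += a[i] * static_cast<int8_t>(b[i]); // signed x signed: 200 wraps to -56
  }
  std::cout << acc_su << " vs " << acc_ss << "\n"; // prints "290 vs 34"
  return 0;
}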

0 commit comments