Commit 7936fb0

Author: arnamoy.bhattacharyya

[SYCL][Matrix] Add initial get_coord API.

This patch adds the initial API for retrieving the coordinates of a work-item element.
1 parent 08b2022 commit 7936fb0
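
For context, a minimal usage sketch of the new get_coord() call, adapted from the test added in this commit (the accumulator sub_c, the per-row accessor res_local_row_acc, and the surrounding joint_matrix kernel code are assumed to exist as in that test):

  // Inside the device kernel, after the joint_matrix_mad loop has filled
  // the accumulator sub_c:
  auto tCData = sub_c.get_wi_data();           // per-work-item view of sub_c
  for (int i = 0; i < tCData.length(); ++i) {
    size_t row, col;
    // New in this patch: map the work-item-local element index i back to
    // its (row, col) coordinate within the joint matrix.
    std::tie(row, col) = tCData[i].get_coord();
    res_local_row_acc[row] += tCData[i];       // e.g. a per-row reduction
  }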

File tree

3 files changed: +224 −0 lines changed


sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 6 additions & 0 deletions
@@ -113,6 +113,12 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
 extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
     __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);

+template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
+          __spv::MatrixLayout L,
+          __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
+extern SYCL_EXTERNAL std::tuple<T, T> __spirv_JointMatrixWorkItemElemCoord(
+    __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
+
 template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
           __spv::MatrixLayout L,
           __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>

sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp

Lines changed: 13 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include <CL/__spirv/spirv_ops.hpp>
 #include <sycl/detail/defines_elementary.hpp>
 #include <sycl/feature_test.hpp>
+#include <tuple>

 namespace sycl {
 __SYCL_INLINE_VER_NAMESPACE(_V1) {
@@ -256,6 +257,18 @@ class wi_element {
   wi_element(joint_matrix<T, NumRows, NumCols, Use, Layout, Group> &Mat,
              std::size_t i)
       : M(Mat), idx(i) {}
+
+  // Functions
+  std::tuple<size_t, size_t> get_coord() {
+#ifdef __SYCL_DEVICE_ONLY__
+    return __spirv_JointMatrixWorkItemElemCoord(M.spvm, idx);
+#else
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_ERROR_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  // Various Operations
   operator T() {
 #ifdef __SYCL_DEVICE_ONLY__
     return __spirv_VectorExtractDynamic(M.spvm, idx);
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
// RUN: %clangxx -fsycl -fsycl-targets=spir64_gen -DSYCL_EXT_ONEAPI_MATRIX=2 -S -emit-llvm %s -o %t.out
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;

static constexpr auto TILE_SZ = 16;
static constexpr auto TM = TILE_SZ - 1;
static constexpr auto TN = TILE_SZ - 1;
static constexpr auto TK = 2 * TILE_SZ - 2;

static constexpr auto SG_SZ = 16;

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
unsigned short Aref[MATRIX_M][MATRIX_K];
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];
int32_t *res_local_row;

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
          size_t NUM_COLS_C>
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
                     big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
                     big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
  size_t M = NUM_ROWS_C;
  size_t N = NUM_COLS_C;
  size_t K = NUM_COLS_A;
  // B => K/4 x N*4, A => M x K, C => M, N
  // stride should be X's cols, e.g., B's stride = N*4
  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  sycl::buffer<bfloat16, 2> bufA(A.get_data(), sycl::range<2>(M, K));
  sycl::buffer<bfloat16, 2> bufB(B.get_data(), sycl::range<2>(K, N));
  sycl::buffer<float, 2> bufC((float *)C.get_data(), sycl::range<2>(M, N));

  sycl::buffer<int32_t, 1> res_local_row_buf(res_local_row,
                                             sycl::range<1>(MATRIX_M));

  sycl::queue q;
  q.submit([&](sycl::handler &cgh) {
     auto accC = bufC.get_access<sycl::access::mode::read_write>(cgh);
     auto accA = bufA.get_access<sycl::access::mode::read_write>(cgh);
     auto accB = bufB.get_access<sycl::access::mode::read_write>(cgh);
     auto res_local_row_acc =
         res_local_row_buf.get_access<sycl::access::mode::read_write>(cgh);

     cgh.parallel_for<class imatrix>(
         sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [accA, accB, accC, M, N, K,
          res_local_row_acc](sycl::nd_item<2> spmd_item)

         {
           // The submatrix API has to be accessed by all the work items in a
           // subgroup; these functions will be called once by the subgroup
           // (no code divergence between the work items).
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group();
           joint_matrix<bfloat16, TM, TK, use::a> sub_a(sg);
           // For B, since the current implementation does not support a
           // non-packed layout, users need to specify the updated VNNI sizes
           // along with the packed_b layout. By default, the layout is
           // row_major and the size is (TK, TN).
           joint_matrix<bfloat16, TK, TN, use::b> sub_b(sg);
           joint_matrix<float, TM, TN, use::accumulator> sub_c(sg);

           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, layout::row_major);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K, layout::row_major);
             // Assuming B data is already in VNNI format.
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() + (k * TK / 2) * (N * 2) +
                                   sg_starty / SG_SZ * TN * 2,
                               N * 2, layout::packed_b);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, layout::row_major);
           // Element-wise operation
           auto tCData = sub_c.get_wi_data();

           for (int i = 0; i < tCData.length(); ++i) {
             size_t row, col;
             std::tie(row, col) = tCData[i].get_coord();
             res_local_row_acc[row] += tCData[i];
           }
         }); // parallel for
   }).wait();
}

float make_fp32(short x) {
  unsigned int y = x;
  y = y << 16;
  float *res = reinterpret_cast<float *>(&y);
  return *res;
}

unsigned short make_bf16(float x) {
  int *res = reinterpret_cast<int *>(&x);
  *res = *res >> 16;
  return (unsigned short)*res;
}

void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
                         int K) {
  // tiling
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        short *va = (short *)(A_mem + m * K + k);
        short *vb = (short *)(B_mem + k * N + n);
        float acc = *((float *)(C_mem + m * N + n));
        // FIXME: Should we do reduce-add in another version?
        for (int i = 0; i < 2; i++) {
          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
        }
        *((float *)(C_mem + m * N + n)) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      // We create bfloat16 from unsigned short since float-to-bfloat
      // conversion is not allowed.
      A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
      Aref[i][j] = make_bf16(1.0f * (i + j));
    }
  }
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j)));
      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
  big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);

  res_local_row = (int32_t *)calloc(MATRIX_M, sizeof(int32_t));

  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 2);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << C[i][j] << ", ";
    std::cout << "\n";
  }
  std::cout << std::endl;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++)
      std::cout << D[i][j] << ", ";
    std::cout << "\n";
  }
}
