Skip to content

Commit a04135c

Browse files
dkhaldibb-sycl
authored and committed
[SYCL] Add a new query test that uses the new API (intel#1358)
1 parent cacceb0 commit a04135c

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
//==------ joint_matrix_query_use_default.cpp - DPC++ joint_matrix---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
13+
// CHECK: passed
14+
15+
// CHECK: passed
16+
17+
#include <iostream>
18+
#include <sycl/sycl.hpp>
19+
20+
using namespace sycl;
21+
using namespace sycl::ext::oneapi::experimental::matrix;
22+
23+
/// Lightweight handle around a raw pointer to matrix storage.
///
/// The struct never allocates or frees the memory it points to; it only
/// stores the pointer. NUM_ROWS and NUM_COLS are not used inside the struct
/// body: they exist so that the matrix dimensions are part of the type and
/// can be deduced by function templates that take a big_matrix.
///
/// @tparam T        Element type of the pointed-to storage.
/// @tparam NUM_ROWS Row count carried in the type.
/// @tparam NUM_COLS Column count carried in the type.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  /// Pointer to the first element of the wrapped storage.
  T *mat;

  /// Wraps @p data without taking ownership of it.
  big_matrix(T *data) : mat(data) {}

  /// Returns the wrapped pointer. Const-qualified so that it is also usable
  /// on a const big_matrix; the pointee itself stays mutable.
  T *get_data() const { return mat; }

  /// Re-points the wrapper at @p data.
  void set_data(T *data) { mat = data; }
};
32+
33+
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
34+
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
35+
size_t NUM_COLS_C>
36+
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
37+
big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
38+
big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
39+
size_t M = NUM_ROWS_C;
40+
size_t N = NUM_COLS_C;
41+
size_t K = NUM_COLS_A;
42+
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
43+
44+
using myparams2 = tpu_params<tpu::amx, int8_t, int8_t, int>;
45+
constexpr int TM = myparams2::M;
46+
constexpr int TN = myparams2::N;
47+
constexpr int TK = myparams2::K;
48+
49+
std::cout << "AMX query sizes are: M " << TM << " N " << TN << " K " << TK
50+
<< std::endl;
51+
if (TM == 16 && TN == 16 && TK == 64)
52+
std::cout << "passed\n";
53+
else
54+
std::cout << "failed\n";
55+
constexpr int SG_SZ = TN;
56+
size_t NDRangeM = M / TM;
57+
size_t NDRangeN = N / TN;
58+
buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
59+
buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
60+
buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));
61+
62+
queue q;
63+
q.submit([&](handler &cgh) {
64+
auto accC = bufC.get_access<access::mode::read_write>(cgh);
65+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
66+
auto accB = bufB.get_access<access::mode::read_write>(cgh);
67+
68+
cgh.parallel_for<class imatrix>(
69+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
70+
[accA, accB, accC, M, N, K](nd_item<2> spmd_item)
71+
[[intel::reqd_sub_group_size(SG_SZ)]]
72+
73+
{
74+
// The submatrix API has to be accessed by all the workitems in a
75+
// subgroup these functions will be called once by the subgroup no
76+
// code divergence between the workitems
77+
const auto global_idx = spmd_item.get_global_id(0);
78+
const auto global_idy = spmd_item.get_global_id(1);
79+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
80+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
81+
82+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
83+
84+
myparams2::joint_matrix_a<sub_group> sub_a(sg);
85+
myparams2::joint_matrix_b<sub_group> sub_b(sg);
86+
myparams2::joint_matrix_accumulator<sub_group> sub_c(sg);
87+
88+
joint_matrix_load(sg, sub_c,
89+
accC.get_pointer() + (sg_startx * TM) * N +
90+
sg_starty / SG_SZ * TN,
91+
N, layout::row_major);
92+
for (int k = 0; k < K / TK; k += 1) {
93+
joint_matrix_load(
94+
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
95+
K, layout::row_major);
96+
// Assuming B data is already in VNNI format.
97+
joint_matrix_load(sg, sub_b,
98+
accB.get_pointer() + (k * TK / 4) * (N * 4) +
99+
sg_starty / SG_SZ * TN * 4,
100+
N * 4, layout::packed_b);
101+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
102+
}
103+
joint_matrix_store(sg, sub_c,
104+
accC.get_pointer() + (sg_startx * TM) * N +
105+
sg_starty / SG_SZ * TN,
106+
N, layout::row_major);
107+
}); // parallel for
108+
}).wait();
109+
}
110+
111+
// Logical problem size shared by the device run and the host reference.
static constexpr size_t MATRIX_M = 128;
static constexpr size_t MATRIX_N = 128;
static constexpr size_t MATRIX_K = 128;

// int8 inputs. B is declared with K / 4 rows of N * 4 bytes, the shape used
// when it is loaded as a packed matrix.
int8_t A[MATRIX_M][MATRIX_K];
int8_t B[MATRIX_K / 4][MATRIX_N * 4];

// int32 outputs: C receives the device result, D the host reference result.
int32_t C[MATRIX_M][MATRIX_N];
int32_t D[MATRIX_M][MATRIX_N];
118+
119+
/// Host reference for the packed int8 matrix multiply: C += A * B.
///
/// Each int32_t element of A_mem and B_mem is treated as four packed int8
/// values. For every (m, n) the function adds, over all k, the 4-element dot
/// product of the bytes in A_mem[m * K + k] and B_mem[k * N + n] to
/// C_mem[m * N + n].
///
/// @param A_mem M x K array of 4-byte groups (left operand).
/// @param B_mem K x N array of 4-byte groups (right operand).
/// @param C_mem M x N int32 accumulator, updated in place.
/// @param M     Number of rows of A and C.
/// @param N     Number of columns of B and C.
/// @param K     Number of 4-byte groups along the reduction dimension.
void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M,
                         int N, int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      int acc = C_mem[m * N + n];
      for (int k = 0; k < K; k++) {
        // Read the bytes as int8_t, not plain char: the signedness of char is
        // implementation-defined, and the operands are signed 8-bit values.
        const auto *va = reinterpret_cast<const int8_t *>(A_mem + m * K + k);
        const auto *vb = reinterpret_cast<const int8_t *>(B_mem + k * N + n);
        for (int i = 0; i < 4; i++)
          acc += va[i] * vb[i];
      }
      C_mem[m * N + n] = acc;
    }
}
135+
136+
/// Fills the global operands, runs the device multiply into C and the host
/// reference into D, compares them element by element and prints "passed"
/// or "failed".
///
/// @return 0 when C and D are identical, 1 otherwise.
int main() {
  // Input patterns deliberately exceed the int8_t range; the explicit casts
  // keep the wrapped values that the implicit conversions produced.
  for (size_t i = 0; i < MATRIX_M; i++) {
    for (size_t j = 0; j < MATRIX_K; j++) {
      A[i][j] = static_cast<int8_t>(i + 2 * j);
    }
  }
  for (size_t i = 0; i < MATRIX_K / 4; i++) {
    for (size_t j = 0; j < MATRIX_N * 4; j++) {
      B[i][j] = static_cast<int8_t>(i + j);
    }
  }
  // Both accumulators start from the same non-zero value so that the
  // "C +=" part of the computation is exercised as well.
  for (size_t i = 0; i < MATRIX_M; i++) {
    for (size_t j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1;
      D[i][j] = 1;
    }
  }

  big_matrix<int32_t, MATRIX_M, MATRIX_N> MC(&C[0][0]);
  big_matrix<int32_t, MATRIX_M, MATRIX_N> MD(&D[0][0]);
  big_matrix<int8_t, MATRIX_M, MATRIX_K> MA(&A[0][0]);
  big_matrix<int8_t, MATRIX_K / 4, MATRIX_N * 4> MB(&B[0][0]);
  matrix_multiply(MC, MA, MB);
  // The reference consumes the int8 data as groups of four packed bytes,
  // hence the int32_t views and the K / 4 reduction length.
  matrix_multiply_ref(reinterpret_cast<int32_t *>(&A[0][0]),
                      reinterpret_cast<int32_t *>(&B[0][0]), &D[0][0],
                      MATRIX_M, MATRIX_N, MATRIX_K / 4);

  bool res = true;
  for (size_t i = 0; i < MATRIX_M; i++) {
    for (size_t j = 0; j < MATRIX_N; j++) {
      if (C[i][j] != D[i][j])
        res = false;
    }
  }
  if (res)
    std::cout << "passed\n";
  else
    std::cout << "failed\n";

  // Report a mismatch through the exit status too, so a harness that only
  // looks at the exit code still detects the failure.
  return res ? 0 : 1;
}

0 commit comments

Comments
 (0)