Commit 74afa2b

[SYCL][Matrix] Add test for all sizes (#8842)
Based on this pull request: intel/llvm-test-suite#1523
1 parent 21afc0c commit 74afa2b

File tree: 4 files changed, +246 -1 lines changed
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
//==-------- joint_matrix_all_sizes.cpp - DPC++ joint_matrix---------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix-xmx8

// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;

#define SG_SZ 8

#include "../joint_matrix_all_sizes_impl.hpp"

sycl/test-e2e/Matrix/XMX8/joint_matrix_bfloat16_32x64.cpp

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-// REQUIRES: matrix
+// REQUIRES: matrix-xmx8

 // RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
 // RUN: %CPU_RUN_PLACEHOLDER %t.out
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
//==-------- joint_matrix_all_sizes.cpp - DPC++ joint_matrix---------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix

// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;

#define SG_SZ 16

#include "joint_matrix_all_sizes_impl.hpp"
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
#define BF16_EPSILON 0.00781250
static constexpr size_t M_MULTIPLIER = 16;

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T>
void matrix_vnni(unsigned int rows, unsigned int cols, T *src, T *dest,
                 unsigned int vnniFactor) {
  for (unsigned int i = 0; i < rows / vnniFactor; i++) {
    for (unsigned int j = 0; j < cols; j++) {
      for (unsigned int k = 0; k < vnniFactor; k++) {
        dest[i * cols * vnniFactor + j * vnniFactor + k] =
            src[(i * vnniFactor + k) * cols + j];
      }
    }
  }
}

template <typename T1, typename T2, size_t M, size_t N, size_t K,
          int vnniFactor, size_t TM, size_t TN, size_t TK>
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
                     big_matrix<T2, K / vnniFactor, N * vnniFactor> &B) {
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<T2, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<T2, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<T1, 2> bufC(C.get_data(), range<2>(M, N));

  queue q;
  q.submit([&](handler &cgh) {
     sycl::accessor accC{bufC, cgh, sycl::read_write};
     sycl::accessor accA{bufA, cgh, sycl::read_only};
     sycl::accessor accB{bufB, cgh, sycl::read_only};

     cgh.parallel_for(
         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]

         {
           // The submatrix API has to be accessed by all the workitems in a
           // subgroup; these functions will be called once by the subgroup,
           // with no code divergence between the workitems.
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sub_group sg = spmd_item.get_sub_group();
           joint_matrix<sub_group, T2, use::a, TM, TK, layout::row_major> sub_a;
           // For B, we assume B has been already VNNIed.
           joint_matrix<sub_group, T2, use::b, TK, TN,
                        ext::intel::experimental::matrix::layout::packed>
               sub_b;
           joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;

           joint_matrix_load(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, layout::row_major);
           for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
                 sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                 K);
             joint_matrix_load(sg, sub_b,
                               accB.get_pointer() +
                                   (k * TK / vnniFactor) * (N * vnniFactor) +
                                   sg_starty / SG_SZ * TN * vnniFactor,
                               N * vnniFactor);
             sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           }
           joint_matrix_store(sg, sub_c,
                              accC.get_pointer() + (sg_startx * TM) * N +
                                  sg_starty / SG_SZ * TN,
                              N, layout::row_major);
         }); // parallel for
   }).wait();
}

static constexpr size_t MATRIX_N = 128;
static constexpr size_t MATRIX_K = 128;

float make_fp32(bfloat16 x) {
  unsigned int y = *((int *)&x);
  y = y << 16;
  float *res = reinterpret_cast<float *>(&y);
  return *res;
}

template <typename Ta, typename Tc>
void matrix_multiply_ref(Ta *A, Ta *B, Tc *C, int M, int N, int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        if (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
          C[m * N + n] += make_fp32(A[m * K + k]) * make_fp32(B[k * N + n]);
        if (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>)
          C[m * N + n] += A[m * K + k] * B[k * N + n];
      }
    }
}

template <typename Ta, typename Tc, int vnni_factor, size_t tM, size_t tN,
          size_t tK>
int init_and_multiply() {

  static constexpr size_t MATRIX_M = tM * M_MULTIPLIER;
  std::cout << "MATRIX_M=" << MATRIX_M << "\n";

  Ta A[MATRIX_M][MATRIX_K];
  Ta B[MATRIX_K][MATRIX_N];
  Ta Bvnni[MATRIX_K / vnni_factor][MATRIX_N * vnni_factor];
  Tc C[MATRIX_M][MATRIX_N];
  Tc D[MATRIX_M][MATRIX_N];

  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      if (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
        A[i][j] = bfloat16(1.0f * (i + j));
      if (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>)
        A[i][j] = i + j;
    }
  }
  for (int i = 0; i < MATRIX_K; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>)
        B[i][j] = bfloat16(2.0f * i + 3.0f * j);
      if (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>)
        B[i][j] = i + 2 * j;
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1;
      D[i][j] = 1;
    }
  }

  big_matrix<Tc, MATRIX_M, MATRIX_N> MC((Tc *)&C);
  big_matrix<Tc, MATRIX_M, MATRIX_N> MD((Tc *)&D);
  big_matrix<Ta, MATRIX_M, MATRIX_K> MA((Ta *)&A);
  matrix_vnni<Ta>(MATRIX_K, MATRIX_N, (Ta *)&B, (Ta *)&Bvnni, vnni_factor);
  big_matrix<Ta, MATRIX_K / vnni_factor, MATRIX_N * vnni_factor> MBvnni(
      (Ta *)&Bvnni);

  matrix_multiply<Tc, Ta, MATRIX_M, MATRIX_N, MATRIX_K, vnni_factor, tM, tN,
                  tK>(MC, MA, MBvnni);
  matrix_multiply_ref((Ta *)A, (Ta *)B, (Tc *)D, MATRIX_M, MATRIX_N, MATRIX_K);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      if constexpr (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) {
        if (fabs(C[i][j] - D[i][j]) > BF16_EPSILON) {
          res = false;
          std::cout << "Failed bfloat16: C is " << C[i][j] << ", D is "
                    << D[i][j] << std::endl;
        }
      } else if (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) {
        if (C[i][j] != D[i][j]) {
          res = false;
          std::cout << "Failed int8_t: C is " << C[i][j] << ", D is " << D[i][j]
                    << std::endl;
        }
      }
    }
  }
  std::cout << (res ? "passed" : "failed") << std::endl;
  return !res;
}

int main() {
  init_and_multiply<bfloat16, float, 2, 1, SG_SZ, 16>();
  init_and_multiply<bfloat16, float, 2, 2, SG_SZ, 16>();
  init_and_multiply<bfloat16, float, 2, 3, SG_SZ, 16>();
  init_and_multiply<bfloat16, float, 2, 4, SG_SZ, 16>();
  init_and_multiply<bfloat16, float, 2, 5, SG_SZ, 16>();
  init_and_multiply<bfloat16, float, 2, 6, SG_SZ, 16>();
  init_and_multiply<bfloat16, float, 2, 7, SG_SZ, 16>();
  init_and_multiply<bfloat16, float, 2, 8, SG_SZ, 16>();

  init_and_multiply<int8_t, int32_t, 4, 1, SG_SZ, 32>();
  init_and_multiply<int8_t, int32_t, 4, 2, SG_SZ, 32>();
  init_and_multiply<int8_t, int32_t, 4, 3, SG_SZ, 32>();
  init_and_multiply<int8_t, int32_t, 4, 4, SG_SZ, 32>();
  init_and_multiply<int8_t, int32_t, 4, 5, SG_SZ, 32>();
  init_and_multiply<int8_t, int32_t, 4, 6, SG_SZ, 32>();
  init_and_multiply<int8_t, int32_t, 4, 7, SG_SZ, 32>();
  init_and_multiply<int8_t, int32_t, 4, 8, SG_SZ, 32>();

  return 0;
}
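
Note on the VNNI packing used above: matrix_vnni interleaves groups of vnniFactor rows of B so that the elements a single dot-product step consumes become contiguous in memory, which is the layout the packed B operand expects. The following is a minimal standalone sketch of that index arithmetic (not part of the commit; the sizes and values are illustrative only):

    // Hypothetical standalone example: pack a 4x2 matrix with VNNI factor 2 using
    // the same index arithmetic as matrix_vnni in joint_matrix_all_sizes_impl.hpp.
    #include <array>
    #include <cstdio>

    int main() {
      constexpr unsigned rows = 4, cols = 2, vnni = 2;
      std::array<int, rows * cols> src{}, dst{};
      for (unsigned i = 0; i < rows; ++i)
        for (unsigned j = 0; j < cols; ++j)
          src[i * cols + j] = 10 * i + j; // value encodes (row, column)

      // dest[i][j * vnni + k] = src[i * vnni + k][j]
      for (unsigned i = 0; i < rows / vnni; ++i)
        for (unsigned j = 0; j < cols; ++j)
          for (unsigned k = 0; k < vnni; ++k)
            dst[i * cols * vnni + j * vnni + k] = src[(i * vnni + k) * cols + j];

      // Packed result has rows/vnni rows of cols*vnni elements; the first row
      // prints "0 10 1 11": column 0 of source rows 0 and 1, then column 1.
      for (unsigned i = 0; i < rows / vnni; ++i) {
        for (unsigned j = 0; j < cols * vnni; ++j)
          std::printf("%d ", dst[i * cols * vnni + j]);
        std::printf("\n");
      }
      return 0;
    }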

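Similarly, make_fp32 in the implementation header relies on bfloat16 being the upper 16 bits of an IEEE-754 binary32 value, so widening to float is a 16-bit left shift of the bit pattern. A minimal standalone sketch of the same conversion (hypothetical helper, not part of the commit), using memcpy for the bit cast:

    // Hypothetical standalone example of bfloat16 -> float widening: shift the
    // 16-bit pattern into the high half of a 32-bit word and reinterpret as float.
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    float bf16_bits_to_float(uint16_t bits) {
      uint32_t widened = static_cast<uint32_t>(bits) << 16;
      float out;
      std::memcpy(&out, &widened, sizeof(out)); // bit cast without aliasing issues
      return out;
    }

    int main() {
      // 0x3F80 is the bfloat16 bit pattern for 1.0 (sign 0, exponent 127, mantissa 0).
      std::printf("%f\n", bf16_bits_to_float(0x3F80)); // prints 1.000000
      return 0;
    }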