Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 48c0dbb

Browse files
authored
[SYCL][MATRIX] Adding test cases for testing get_coord() feature. (#1676)
1 parent 574cee1 commit 48c0dbb

9 files changed

+772
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//==----------- get_coord_bf16_gemm.cpp - DPC++ joint_matrix---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-xmx8
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
// XFAIL:*
14+
15+
#include <iostream>
16+
#include <random>
17+
#include <sycl/sycl.hpp>
18+
19+
using namespace sycl;
20+
using namespace sycl::ext::intel;
21+
using namespace sycl::ext::oneapi::experimental::matrix;
22+
using bfloat16 = sycl::ext::oneapi::bfloat16;
23+
24+
#define SG_SZ 8
25+
26+
#include "../get_coord_bf16_gemm_impl.hpp"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//==----------- get_coord_bf16_matA.cpp - DPC++ joint_matrix---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-xmx8
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
// XFAIL:*
14+
15+
#include <iostream>
16+
#include <random>
17+
#include <sycl/sycl.hpp>
18+
19+
using namespace sycl;
20+
using namespace sycl::ext::intel;
21+
using namespace sycl::ext::oneapi::experimental::matrix;
22+
using bfloat16 = sycl::ext::oneapi::bfloat16;
23+
24+
#define SG_SZ 8
25+
26+
#include "../get_coord_bf16_matA_impl.hpp"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//==----------- get_coord_bf16_matB.cpp - DPC++ joint_matrix---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-xmx8
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
// XFAIL:*
14+
15+
#include <iostream>
16+
#include <random>
17+
#include <sycl/sycl.hpp>
18+
19+
using namespace sycl;
20+
using namespace sycl::ext::intel;
21+
using namespace sycl::ext::oneapi::experimental::matrix;
22+
using bfloat16 = sycl::ext::oneapi::bfloat16;
23+
24+
#define SG_SZ 8
25+
26+
#include "../get_coord_bf16_matB_impl.hpp"

SYCL/Matrix/get_coord_bf16_gemm.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//==----------- get_coord_bf16_gemm.cpp - DPC++ joint_matrix---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
// XFAIL:*
14+
15+
#include <iostream>
16+
#include <random>
17+
#include <sycl/sycl.hpp>
18+
19+
using namespace sycl;
20+
using namespace sycl::ext::intel;
21+
using namespace sycl::ext::oneapi;
22+
using namespace sycl::ext::oneapi::experimental::matrix;
23+
using bfloat16 = sycl::ext::oneapi::bfloat16;
24+
25+
#define SG_SZ 16
26+
27+
#include "get_coord_bf16_gemm_impl.hpp"
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
#define TM 8
2+
#define TN SG_SZ
3+
#define TK 16
4+
5+
static constexpr size_t MATRIX_M = TM * 2;
6+
static constexpr size_t MATRIX_N = TN * 2;
7+
static constexpr size_t MATRIX_K = TK * 2;
8+
9+
#define BF16_EPSILON 0.00781250
10+
11+
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
12+
private:
13+
T *mat;
14+
15+
public:
16+
T *get_data() { return mat; }
17+
void set_data(T *data) { mat = data; }
18+
big_matrix(T *data) : mat(data) {}
19+
};
20+
21+
// clang-format off
22+
/*
23+
Here's how the data is distributed
24+
W0 --> 0 1 2 3 4 5 6 7
25+
wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 1] wi [0,15] --> i=0, [0, 15]
26+
i=1, [1, 0] i=1, [1, 1] i=1, [1, 15]
27+
i=2, [2, 0] i=2, [2, 1] ...
28+
... ....
29+
i=7, [7, 0] i=7, [7, 1]
30+
*/
31+
// clang-format on
32+
std::tuple<uint32_t, uint32_t> get_coord_ref(int i, int wi_number) {
33+
return std::make_tuple(i, wi_number);
34+
}
35+
36+
float sum_rows[MATRIX_M] = {0};
37+
38+
template <typename T1, typename T2, size_t M, size_t N, size_t K>
39+
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
40+
big_matrix<T2, K / 2, N * 2> &B) {
41+
size_t NDRangeM = M / TM;
42+
size_t NDRangeN = N / TN;
43+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
44+
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
45+
buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
46+
47+
buffer<float> sum_rows_v(sum_rows, M); // there are total of M rows
48+
49+
queue q;
50+
q.submit([&](handler &cgh) {
51+
auto accC = bufC.get_access<access::mode::read_write>(cgh);
52+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
53+
auto accB = bufB.get_access<access::mode::read_write>(cgh);
54+
55+
auto v = sum_rows_v.get_access<access::mode::read_write>(cgh);
56+
auto os = sycl::stream(100000, 6144, cgh);
57+
58+
cgh.parallel_for<class imatrix>(
59+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
60+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
61+
62+
{
63+
// The submatrix API has to be accessed by all the workitems in a
64+
// subgroup these functions will be called once by the subgroup no
65+
// code divergence between the workitems
66+
const auto global_idx = spmd_item.get_global_id(0);
67+
const auto global_idy = spmd_item.get_global_id(1);
68+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
69+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
70+
71+
sub_group sg = spmd_item.get_sub_group();
72+
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
73+
sub_a;
74+
// For B, we assume B has been already VNNIed.
75+
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
76+
ext::intel::experimental::matrix::layout::packed>
77+
sub_b;
78+
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
79+
80+
joint_matrix_load(sg, sub_c,
81+
accC.get_pointer() + (sg_startx * TM) * N +
82+
sg_starty / SG_SZ * TN,
83+
N, layout::row_major);
84+
for (int k = 0; k < K / TK; k += 1) { //
85+
joint_matrix_load(
86+
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
87+
K);
88+
joint_matrix_load(sg, sub_b,
89+
accB.get_pointer() + (k * TK / 2) * (N * 2) +
90+
sg_starty / SG_SZ * TN * 2,
91+
N * 2);
92+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
93+
}
94+
joint_matrix_store(sg, sub_c,
95+
accC.get_pointer() + (sg_startx * TM) * N +
96+
sg_starty / SG_SZ * TN,
97+
N, layout::row_major);
98+
99+
float sum_local_rows[M] = {0}; // 8 local rows, M total
100+
auto data =
101+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
102+
103+
// Keep track of rows handled in this WI
104+
int32_t handled_rows[M] = {-1};
105+
size_t
106+
global_index; // Index into the result array that holds the sums.
107+
108+
for (int i = 0; i < data.length(); ++i) {
109+
auto dataItem = data[i];
110+
auto [row, col] = dataItem.get_coord();
111+
// get_coord_ref(i, spmd_item.get_local_id(1));
112+
global_index = row + global_idx * TM;
113+
114+
sum_local_rows[global_index] += data[i];
115+
116+
handled_rows[global_index] = 1;
117+
}
118+
119+
for (int j = 0; j < M; j++) {
120+
if (handled_rows[j] == 1) {
121+
global_index = j;
122+
sum_local_rows[global_index] = reduce_over_group(
123+
sg, sum_local_rows[global_index], sycl::plus<>());
124+
// only Groups leader perform the global reduction
125+
if (global_idy % SG_SZ == 0) {
126+
sycl::atomic_ref<float, sycl::memory_order::relaxed,
127+
sycl::memory_scope::device>
128+
aref(v[global_index]);
129+
aref.fetch_add(sum_local_rows[global_index]);
130+
}
131+
}
132+
}
133+
}); // parallel for
134+
}).wait();
135+
}
136+
137+
bfloat16 A[MATRIX_M][MATRIX_K];
138+
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
139+
float C[MATRIX_M][MATRIX_N];
140+
float D[MATRIX_M][MATRIX_N];
141+
142+
float make_fp32(bfloat16 x) {
143+
unsigned int y = *((int *)&x);
144+
y = y << 16;
145+
float *res = reinterpret_cast<float *>(&y);
146+
return *res;
147+
}
148+
149+
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
150+
int K) {
151+
for (int m = 0; m < M; m++)
152+
for (int n = 0; n < N; n++) {
153+
for (int k = 0; k < K; k++) {
154+
// Because B was assumed VNNIed
155+
bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
156+
bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
157+
float acc = *((float *)(C_mem + m * N + n));
158+
for (int i = 0; i < 2; i++) {
159+
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
160+
}
161+
*((float *)(C_mem + m * N + n)) = acc;
162+
}
163+
}
164+
}
165+
166+
int main() {
167+
for (int i = 0; i < MATRIX_M; i++) {
168+
for (int j = 0; j < MATRIX_K; j++) {
169+
A[i][j] = bfloat16(1.0f * (i + j));
170+
}
171+
}
172+
for (int i = 0; i < MATRIX_K / 2; i++) {
173+
for (int j = 0; j < MATRIX_N * 2; j++) {
174+
B[i][j] = bfloat16(2.0f * i + 3.0f * j);
175+
}
176+
}
177+
for (int i = 0; i < MATRIX_M; i++) {
178+
for (int j = 0; j < MATRIX_N; j++) {
179+
C[i][j] = 1.0;
180+
D[i][j] = 1.0;
181+
}
182+
}
183+
184+
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
185+
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
186+
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
187+
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
188+
matrix_multiply(MC, MA, MB);
189+
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
190+
MATRIX_N, MATRIX_K / 2);
191+
192+
bool res = true;
193+
float sum_rows_ref[MATRIX_M] = {0};
194+
195+
for (int i = 0; i < MATRIX_M; i++) {
196+
for (int j = 0; j < MATRIX_N; j++) {
197+
// std::cout << C[i][j] << " ";
198+
if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON)
199+
res = false;
200+
sum_rows_ref[i] += C[i][j];
201+
}
202+
if ((fabs(sum_rows_ref[i]) - fabs(sum_rows[i])) > BF16_EPSILON)
203+
res = false;
204+
// std::cout << "\n";
205+
}
206+
std::cout << (res ? "passed" : "failed") << std::endl;
207+
return !res;
208+
}

SYCL/Matrix/get_coord_bf16_matA.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//==----------- get_coord_bf16_matA.cpp - DPC++ joint_matrix---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
// XFAIL:*
14+
15+
#include <iostream>
16+
#include <random>
17+
#include <sycl/sycl.hpp>
18+
19+
using namespace sycl;
20+
using namespace sycl::ext::intel;
21+
using namespace sycl::ext::oneapi;
22+
using namespace sycl::ext::oneapi::experimental::matrix;
23+
using bfloat16 = sycl::ext::oneapi::bfloat16;
24+
25+
#define SG_SZ 16
26+
27+
#include "get_coord_bf16_matA_impl.hpp"

0 commit comments

Comments
 (0)