Skip to content

Commit 7f21c27

Browse files
committed
[SYCL][Joint Matirx] Add test for multiple elem-wise ops
Add new test for multiple elem-wise ops for different matrix types in the same kernel
1 parent 474461c commit 7f21c27

File tree

3 files changed

+157
-0
lines changed

3 files changed

+157
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//==----------- element_wise_abc.cpp - DPC++ joint_matrix------------- ----==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-xmx8
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
#define SG_SZ 8
14+
15+
#include "../element_wise_abc_impl.hpp"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//==----------- element_wise_abc.cpp - DPC++ joint_matrix------------- ----==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
#define SG_SZ 16
14+
15+
#include "element_wise_abc_impl.hpp"
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//==----------- element_wise_abc_impl.hpp - DPC++ joint_matrix-------------
2+
//----==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include <iostream>
11+
#include <sycl/sycl.hpp>
12+
13+
using namespace sycl;
14+
using namespace sycl::ext::oneapi::experimental::matrix;
15+
16+
#define TM 8
17+
#define TN SG_SZ
18+
#define TK 32
19+
20+
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
21+
public:
22+
T *mat;
23+
24+
public:
25+
T *get_data() { return mat; }
26+
void set_data(T *data) { mat = data; }
27+
big_matrix(T *data) : mat(data) {}
28+
};
29+
30+
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
31+
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
32+
size_t NUM_COLS_C>
33+
void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
34+
big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
35+
big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
36+
size_t M = NUM_ROWS_C;
37+
size_t N = NUM_COLS_C;
38+
size_t K = NUM_COLS_A;
39+
40+
// B => K/4 x N*4, A => M x K, C => M, N
41+
// stride should be X's cols, e.g., B's stirde = N*4
42+
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
43+
44+
size_t NDRangeM = M / TM;
45+
size_t NDRangeN = N / TN;
46+
buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
47+
buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
48+
buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));
49+
50+
queue q;
51+
q.submit([&](handler &cgh) {
52+
auto accC = bufC.get_access<access::mode::read_write>(cgh);
53+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
54+
auto accB = bufB.get_access<access::mode::read_write>(cgh);
55+
56+
cgh.parallel_for(
57+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
58+
[accA, accB, accC, M, N,
59+
K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
60+
// The submatrix API has to be accessed by all the workitems in a
61+
// subgroup these functions will be called once by the subgroup no
62+
// code divergence between the workitems
63+
const auto global_idx = spmd_item.get_global_id(0);
64+
const auto global_idy = spmd_item.get_global_id(1);
65+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
66+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
67+
68+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
69+
joint_matrix<sub_group, int8_t, use::a, TM, TK, layout::row_major>
70+
sub_a;
71+
72+
// For B, we assume B has been already VNNIed.
73+
joint_matrix<sub_group, int8_t, use::b, TK, TN,
74+
ext::intel::experimental::matrix::layout::packed>
75+
sub_b;
76+
77+
joint_matrix<sub_group, int32_t, use::accumulator, TM, TN> sub_c;
78+
79+
joint_matrix_load(
80+
sg, sub_c,
81+
accC.template get_multi_ptr<access::decorated::no>() +
82+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
83+
N, layout::row_major);
84+
85+
joint_matrix_load(
86+
sg, sub_a,
87+
accA.template get_multi_ptr<access::decorated::no>() +
88+
(sg_startx * TM) * K,
89+
K);
90+
auto wi_slice_a =
91+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
92+
for (int i = 0; i < wi_slice_a.length(); i++) {
93+
wi_slice_a[i] += 1;
94+
}
95+
96+
joint_matrix_load(
97+
sg, sub_b,
98+
accB.template get_multi_ptr<access::decorated::no>() +
99+
+sg_starty / SG_SZ * TN * 4,
100+
N * 4);
101+
auto wi_slice_b =
102+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
103+
for (int i = 0; i < wi_slice_b.length(); i++) {
104+
wi_slice_b[i] += 1;
105+
}
106+
107+
auto wi_slice_c =
108+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
109+
for (int i = 0; i < wi_slice_c.length(); i++) {
110+
wi_slice_c[i] += 1;
111+
}
112+
}); // parallel for
113+
}).wait();
114+
}
115+
116+
int8_t A[TM][TK];
117+
int8_t B[TK / 4][TN * 4];
118+
int32_t C[TM][TN];
119+
120+
int main() {
121+
big_matrix<int32_t, TM, TN> MC((int32_t *)&C);
122+
big_matrix<int8_t, TM, TK> MA((int8_t *)&A);
123+
big_matrix<int8_t, TK / 4, TN * 4> MB((int8_t *)&B);
124+
matrix_elem_wise_ops(MC, MA, MB);
125+
126+
return 0;
127+
}

0 commit comments

Comments
 (0)