Skip to content

Commit 57098b0

Browse files
authored
[SYCL][Joint Matrix] Add test for multiple elem-wise ops (#10258)
Add new test for multiple elem-wise ops for different matrix types in the same kernel
1 parent dd89b19 commit 57098b0

File tree

3 files changed

+144
-0
lines changed

3 files changed

+144
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//==----------- element_wise_abc.cpp - DPC++ joint_matrix------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-xmx8
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
#define SG_SZ 8
14+
15+
#include "../element_wise_abc_impl.hpp"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//==----------- element_wise_abc.cpp - DPC++ joint_matrix------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
#define SG_SZ 16
14+
15+
#include "element_wise_abc_impl.hpp"
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//==----------- element_wise_abc_impl.hpp - DPC++ joint_matrix-------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <iostream>
10+
#include <sycl/sycl.hpp>
11+
12+
using namespace sycl;
13+
using namespace sycl::ext::oneapi::experimental::matrix;
14+
15+
#define TM 8
16+
#define TN SG_SZ
17+
#define TK 32
18+
19+
/// Non-owning handle to matrix storage held elsewhere.
///
/// NUM_ROWS and NUM_COLS only tag the type; nothing in this struct reads them
/// and the pointed-to storage is neither allocated nor freed here.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  /// Pointer to the first element of the wrapped storage.
  T *mat;

  /// Wraps \p data. Left implicit so existing conversions from `T *` keep
  /// compiling.
  big_matrix(T *data) : mat(data) {}

  /// Returns the wrapped pointer. Const-qualified so it can also be called on
  /// a const big_matrix; the pointee stays mutable, exactly as before.
  T *get_data() const { return mat; }

  /// Re-points this handle at \p data.
  void set_data(T *data) { mat = data; }
};
28+
29+
template <typename T1, typename T2, size_t M, size_t N, size_t K,
30+
int vnniFactor>
31+
void matrix_elem_wise_ops(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
32+
big_matrix<T2, K / vnniFactor, N * vnniFactor> &B) {
33+
size_t NDRangeM = M / TM;
34+
size_t NDRangeN = N / TN;
35+
buffer<T2, 2> bufA(A.get_data(), range<2>(M, K));
36+
buffer<T2, 2> bufB(B.get_data(), range<2>(K, N));
37+
buffer<T1, 2> bufC(C.get_data(), range<2>(M, N));
38+
39+
queue q;
40+
q.submit([&](handler &cgh) {
41+
accessor accC{bufC, cgh};
42+
accessor accA{bufA, cgh};
43+
accessor accB{bufB, cgh};
44+
45+
cgh.parallel_for(
46+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
47+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
48+
// The submatrix API has to be accessed by all the workitems in a
49+
// subgroup these functions will be called once by the subgroup no
50+
// code divergence between the workitems
51+
const auto global_idx = spmd_item.get_global_id(0);
52+
const auto global_idy = spmd_item.get_global_id(1);
53+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
54+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
55+
56+
sub_group sg = spmd_item.get_sub_group();
57+
joint_matrix<sub_group, T2, use::a, TM, TK, layout::row_major> sub_a;
58+
// For B, we assume B has been already VNNIed.
59+
joint_matrix<sub_group, T2, use::b, TK, TN,
60+
ext::intel::experimental::matrix::layout::packed>
61+
sub_b;
62+
joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;
63+
64+
joint_matrix_load(
65+
sg, sub_a,
66+
accA.template get_multi_ptr<access::decorated::no>() +
67+
(sg_startx * TM) * K,
68+
K);
69+
auto wi_slice_a =
70+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
71+
for (int i = 0; i < wi_slice_a.length(); i++) {
72+
wi_slice_a[i] += 1;
73+
}
74+
75+
joint_matrix_load(
76+
sg, sub_b,
77+
accB.template get_multi_ptr<access::decorated::no>() +
78+
sg_starty / SG_SZ * TN * vnniFactor,
79+
N * vnniFactor);
80+
auto wi_slice_b =
81+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
82+
for (int i = 0; i < wi_slice_b.length(); i++) {
83+
wi_slice_b[i] += 1;
84+
}
85+
86+
joint_matrix_load(
87+
sg, sub_c,
88+
accC.template get_multi_ptr<access::decorated::no>() +
89+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
90+
N, layout::row_major);
91+
auto wi_slice_c =
92+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
93+
for (int i = 0; i < wi_slice_c.length(); i++) {
94+
wi_slice_c[i] += 1;
95+
}
96+
}); // parallel for
97+
}).wait();
98+
}
99+
100+
/// Runs matrix_elem_wise_ops on a single tile: int8_t A (TM x TK), int8_t B
/// declared as (TK / vnniFactor) x (TN * vnniFactor), and int32_t C (TM x TN).
int main() {
  // int matches the type of the vnniFactor template parameter it is passed to.
  static constexpr int vnniFactor = 4;

  // Zero-initialize the host data: the kernel loads these values and adds 1
  // to them, so they must not be indeterminate.
  int8_t A[TM][TK] = {};
  int8_t B[TK / vnniFactor][TN * vnniFactor] = {};
  int32_t C[TM][TN] = {};

  big_matrix<int32_t, TM, TN> MC(&C[0][0]);
  big_matrix<int8_t, TM, TK> MA(&A[0][0]);
  big_matrix<int8_t, TK / vnniFactor, TN * vnniFactor> MB(&B[0][0]);

  matrix_elem_wise_ops<int32_t, int8_t, TM, TN, TK, vnniFactor>(MC, MA, MB);

  return 0;
}

0 commit comments

Comments
 (0)