Skip to content

Commit c873fe2

Browse files
authored
[SYCL][joint matrix] add missing licence to test and add combination-based query (#12489)
1 parent f3d30b6 commit c873fe2

File tree

5 files changed

+111
-107
lines changed

5 files changed

+111
-107
lines changed

sycl/test-e2e/Matrix/SG32/element_wise_all_ops.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,16 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
// REQUIRES: matrix
8+
// REQUIRES: cpu, gpu
99
// REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943
10+
// SG size = 32 is not currently supported for SYCL Joint Matrix by IGC on DG2
11+
// UNSUPPORTED: gpu-intel-dg2
1012

1113
// RUN: %{build} -o %t.out
1214
// RUN: %{run} %t.out
1315

14-
#include <iostream>
15-
#include <random>
16-
#include <sycl/sycl.hpp>
16+
#include "../common.hpp"
1717

18-
using namespace sycl;
19-
using namespace sycl::ext::oneapi::experimental::matrix;
20-
using bfloat16 = sycl::ext::oneapi::bfloat16;
21-
22-
constexpr size_t SG_SZ = 32;
23-
constexpr size_t TN = 16;
18+
#define SG_SZ 32
2419

2520
#include "../element_wise_all_ops_impl.hpp"

sycl/test-e2e/Matrix/XMX8/element_wise_all_ops.cpp

Lines changed: 0 additions & 25 deletions
This file was deleted.

sycl/test-e2e/Matrix/common.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
//==------------------ common.hpp - DPC++ joint_matrix---------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
18
#include <cmath>
29
#include <iostream>
310
#include <random>
@@ -173,3 +180,27 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
173180
}
174181
return true;
175182
}
183+
184+
bool is_type_supported_by_device(queue q, matrix_type type) {
185+
std::vector<combination> combinations =
186+
q.get_device()
187+
.get_info<sycl::ext::oneapi::experimental::info::device::
188+
matrix_combinations>();
189+
for (int i = 0; i < combinations.size(); i++) {
190+
if (combinations[i].atype == type) {
191+
return true;
192+
}
193+
}
194+
return false;
195+
}
196+
197+
template <typename KernelName> size_t get_sg_size(queue q) {
198+
auto KernelID = get_kernel_id<KernelName>();
199+
auto KB =
200+
get_kernel_bundle<bundle_state::executable>(q.get_context(), {KernelID});
201+
auto kernel = KB.get_kernel(KernelID);
202+
203+
return kernel
204+
.template get_info<info::kernel_device_specific::max_sub_group_size>(
205+
q.get_device());
206+
}

sycl/test-e2e/Matrix/element_wise_all_ops.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,14 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
// REQUIRES: matrix
8+
// REQUIRES: cpu, gpu
9+
// Test is flaky/timeouts on some variants of DG2 and temporary disabled. Needs
10+
// to be investigated.
11+
// UNSUPPORTED: gpu-intel-dg2
912

1013
// RUN: %{build} -o %t.out
1114
// RUN: %{run} %t.out
1215

13-
#include <iostream>
14-
#include <random>
15-
#include <sycl/sycl.hpp>
16-
17-
using namespace sycl;
18-
using namespace sycl::ext::oneapi::experimental::matrix;
19-
using bfloat16 = sycl::ext::oneapi::bfloat16;
20-
21-
#define SG_SZ 16
22-
constexpr size_t TN = 16;
16+
#include "common.hpp"
2317

2418
#include "element_wise_all_ops_impl.hpp"

sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 70 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,6 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
9-
static float make_fp32(bfloat16 x) {
10-
unsigned int y = *((int *)&x);
11-
y = y << 16;
12-
float *res = reinterpret_cast<float *>(&y);
13-
return *res;
14-
}
15-
16-
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
17-
public:
18-
T *mat;
19-
20-
public:
21-
T *get_data() { return mat; }
22-
void set_data(T *data) { mat = data; }
23-
big_matrix(T *data) : mat(data) {}
24-
};
25-
268
template <typename T, size_t NUM_ROWS, size_t NUM_COLS>
279
void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
2810
const float ref) {
@@ -39,20 +21,25 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
3921
}
4022

4123
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
42-
size_t SUB_COLS, typename OP>
24+
size_t SUB_COLS, class kernel_name, typename OP>
4325
void verify_op_a(const T l, const T r, const float ref, OP op) {
4426
T mat[NUM_ROWS][NUM_COLS];
4527
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat((T *)&mat);
4628

4729
buffer<T, 2> bufMat(big_mat.get_data(), range<2>(NUM_ROWS, NUM_COLS));
4830

4931
queue q;
32+
size_t sg_size = get_sg_size<kernel_name>(q);
5033
q.submit([&](handler &cgh) {
5134
sycl::accessor accessMat{bufMat, cgh, sycl::read_write};
52-
cgh.parallel_for(
53-
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * SG_SZ},
54-
{1, 1 * SG_SZ}),
55-
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
35+
cgh.parallel_for<kernel_name>(
36+
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * sg_size},
37+
{1, 1 * sg_size}),
38+
[=](nd_item<2> spmd_item)
39+
#ifdef SG_SZ
40+
[[intel::reqd_sub_group_size(SG_SZ)]]
41+
#endif
42+
{
5643
const auto global_idx = spmd_item.get_global_id(0);
5744
const auto global_idy = spmd_item.get_global_id(1);
5845
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
@@ -68,28 +55,32 @@ void verify_op_a(const T l, const T r, const float ref, OP op) {
6855
sg, sub_mat,
6956
accessMat.template get_multi_ptr<access::decorated::no>() +
7057
(sg_startx * SUB_ROWS) * NUM_COLS +
71-
sg_starty / SG_SZ * SUB_COLS,
58+
sg_starty / sg_size * SUB_COLS,
7259
NUM_COLS);
7360
}); // parallel for
7461
}).wait();
7562
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access(read_only), ref);
7663
}
7764

7865
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
79-
size_t SUB_COLS, typename OP>
66+
size_t SUB_COLS, class kernel_name, typename OP>
8067
void verify_op_c(const T l, const T r, const float ref, OP op) {
8168
T mat[NUM_ROWS][NUM_COLS];
8269
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat((T *)&mat);
8370

8471
buffer<T, 2> bufMat(big_mat.get_data(), range<2>(NUM_ROWS, NUM_COLS));
85-
8672
queue q;
73+
size_t sg_size = get_sg_size<kernel_name>(q);
8774
q.submit([&](handler &cgh) {
8875
sycl::accessor accessMat{bufMat, cgh, sycl::read_write};
89-
cgh.parallel_for(
90-
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * SG_SZ},
91-
{1, 1 * SG_SZ}),
92-
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
76+
cgh.parallel_for<kernel_name>(
77+
nd_range<2>({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * sg_size},
78+
{1, 1 * sg_size}),
79+
[=](nd_item<2> spmd_item)
80+
#ifdef SG_SZ
81+
[[intel::reqd_sub_group_size(SG_SZ)]]
82+
#endif
83+
{
9384
const auto global_idx = spmd_item.get_global_id(0);
9485
const auto global_idy = spmd_item.get_global_id(1);
9586
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
@@ -105,85 +96,103 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
10596
sg, sub_mat,
10697
accessMat.template get_multi_ptr<access::decorated::no>() +
10798
(sg_startx * SUB_ROWS) * NUM_COLS +
108-
sg_starty / SG_SZ * SUB_COLS,
99+
sg_starty / sg_size * SUB_COLS,
109100
NUM_COLS, layout::row_major);
110101
}); // parallel for
111102
}).wait();
112103
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access(read_only), ref);
113104
}
114105

106+
// Avoid same kernel name for different types
107+
template <typename T, class name> class ewops_a {};
115108
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
116109
void test_ewops_a() {
117110

118-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
111+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add>>(
119112
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
120-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
113+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_sub>>(
121114
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
122-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
115+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_mul>>(
123116
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
124-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
117+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_div>>(
125118
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
126-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
119+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_logical>>(
127120
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
128-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
121+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_eq>>(
129122
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
130-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
123+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ne>>(
131124
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
132-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
125+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_gt>>(
133126
T(5.0), T(2.0), 3.0,
134127
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
135-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
128+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_lt>>(
136129
T(5.0), T(2.0), 2.0,
137130
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
138-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
131+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ge>>(
139132
T(5.0), T(2.0), 3.0,
140133
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
141-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
134+
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_le>>(
142135
T(5.0), T(2.0), 2.0,
143136
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
144137
}
145-
138+
// Avoid same kernel name for different types and numbers of columns
139+
template <typename T, size_t COLS, class name> class ewops_c {};
146140
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
147141
void test_ewops_c() {
148142

149-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
143+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_add>>(
150144
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
151-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
145+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_sub>>(
152146
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
153-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
147+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_mul>>(
154148
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
155-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
149+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_div>>(
156150
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
157-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
151+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
152+
ewops_c<T, SCOLS, class c_logical>>(
158153
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
159-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
154+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_eq>>(
160155
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
161-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
156+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ne>>(
162157
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
163-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
158+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_gt>>(
164159
T(5.0), T(2.0), 3.0,
165160
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
166-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
161+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_lt>>(
167162
T(5.0), T(2.0), 2.0,
168163
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
169-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
164+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ge>>(
170165
T(5.0), T(2.0), 3.0,
171166
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
172-
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
167+
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_le>>(
173168
T(5.0), T(2.0), 2.0,
174169
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
175170
}
176171

177172
int main() {
178173
static constexpr size_t TM = 8;
179-
static constexpr size_t TK = 16;
180174

181175
static constexpr size_t MATRIX_M = TM * 2;
182-
static constexpr size_t MATRIX_N = TN * 2;
183-
static constexpr size_t MATRIX_K = TK * 2;
184-
185-
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, TK>();
186-
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, TN>();
187-
176+
static constexpr size_t MATRIX_N = 32;
177+
static constexpr size_t MATRIX_K = 32;
178+
queue q;
179+
std::vector<combination> combinations =
180+
q.get_device()
181+
.get_info<sycl::ext::oneapi::experimental::info::device::
182+
matrix_combinations>();
183+
for (unsigned int i = 0; i < combinations.size(); i++) {
184+
if (combinations[i].atype == matrix_type::bf16) {
185+
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
186+
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
187+
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 16>();
188+
break;
189+
}
190+
if (combinations[i].nsize == 8) {
191+
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
192+
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 8>();
193+
break;
194+
}
195+
}
196+
}
188197
return 0;
189198
}

0 commit comments

Comments
 (0)