Skip to content

Commit 981e745

Browse files
authored
[SYCL][E2E][Joint Matrix] New test transpose A and B (#16684)
1 parent b9b1f88 commit 981e745

File tree

2 files changed

+149
-3
lines changed

2 files changed

+149
-3
lines changed

sycl/test-e2e/Matrix/Inputs/common.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
183183
}
184184
} else if constexpr (exact || std::is_integral_v<T1>) {
185185
if (src[i * cols + j] != ref[i * cols + j]) {
186-
std::cerr << "Incorrect result in matrix."
186+
std::cerr << "Incorrect result in matrix. "
187187
<< "i: " << i << ", j: " << j
188-
<< ", Ref: " << ref[i * cols + j]
189-
<< ", Val: " << src[i * cols + j] << "\n";
188+
<< ", Ref: " << (int)ref[i * cols + j]
189+
<< ", Val: " << (int)src[i * cols + j] << "\n";
190190
return false;
191191
}
192192
} else {
@@ -221,3 +221,16 @@ template <typename KernelName> size_t get_sg_size(queue q) {
221221
.template get_info<info::kernel_device_specific::max_sub_group_size>(
222222
q.get_device());
223223
}
224+
225+
template <typename T>
226+
void matrix_print(unsigned int rows, unsigned int cols, T *mat) {
227+
for (unsigned int i = 0; i < rows; i++) {
228+
for (unsigned int j = 0; j < cols; j++) {
229+
if constexpr (std::is_integral_v<T>)
230+
std::cout << (int)mat[i * cols + j] << " ";
231+
else
232+
std::cout << (float)mat[i * cols + j] << " ";
233+
}
234+
std::cout << "\n";
235+
}
236+
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
//===---joint_matrix_transposeAB.cpp - DPC++ joint_matrix--------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: aspect-ext_intel_matrix
9+
10+
// RUN: %{build} -o %t.out
11+
// RUN: %{run} %t.out
12+
13+
// SG size = 32 is not currently supported for SYCL Joint Matrix by IGC on DG2
14+
// RUN: %if !arch-intel_gpu_dg2 %{ %{build} -o %t_sg32.out -DSG_SZ=32 %}
15+
// RUN: %if !arch-intel_gpu_dg2 %{ %{run} %t_sg32.out %}
16+
17+
// XFAIL: gpu
18+
// XFAIL-TRACKER: GSD-5768
19+
20+
// XFAIL: cpu
21+
// XFAIL-TRACKER: CMPLRLLVM-52693
22+
23+
#include "common.hpp"
24+
#include <sycl/usm.hpp>
25+
26+
template <typename T, size_t TileRows, size_t TileCols> class MT;
27+
28+
template <size_t TR, size_t TC, typename T, size_t NR, size_t NC, use Use>
29+
void matrix_transpose(T *in, T *out, queue q) {
30+
static_assert((NR % TR) == 0);
31+
static_assert((NC % TC) == 0);
32+
size_t sg_size = get_sg_size<class MT<T, TR, TC>>(q);
33+
std::cout << "SG size " << sg_size << " ";
34+
35+
q.submit([&](handler &cgh) {
36+
cgh.parallel_for<class MT<T, TR, TC>>(
37+
nd_range<2>({NR / TR, NC / TC * sg_size}, {1, 1 * sg_size}),
38+
[=](nd_item<2> spmd_item)
39+
#ifdef SG_SZ
40+
[[sycl::reqd_sub_group_size(SG_SZ)]]
41+
#endif
42+
{
43+
auto in_ptr =
44+
address_space_cast<sycl::access::address_space::global_space,
45+
sycl::access::decorated::no>(in);
46+
auto out_ptr =
47+
address_space_cast<sycl::access::address_space::global_space,
48+
sycl::access::decorated::no>(out);
49+
50+
const auto global_idx = spmd_item.get_global_id(0);
51+
const auto global_idy = spmd_item.get_global_id(1);
52+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
53+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
54+
55+
sub_group sg = spmd_item.get_sub_group();
56+
joint_matrix<sub_group, T, Use, TR, TC, layout::row_major>
57+
matrix_row_major;
58+
joint_matrix<sub_group, T, Use, TR, TC, layout::col_major>
59+
matrix_col_major;
60+
61+
auto row_major_offset =
62+
(sg_startx * TR) * NC + sg_starty / sg_size * TC;
63+
auto col_major_offset =
64+
(sg_startx * TR) + (sg_starty / sg_size * TC) * NR;
65+
66+
joint_matrix_load(sg, matrix_row_major, in_ptr + row_major_offset,
67+
NC);
68+
joint_matrix_copy(sg, matrix_row_major, matrix_col_major);
69+
ext::intel::experimental::matrix::joint_matrix_store(
70+
sg, matrix_col_major, out_ptr + col_major_offset, NR);
71+
}); // parallel for
72+
}).wait();
73+
}
74+
75+
template <typename T, size_t TR, size_t TC, use Use> void test() {
76+
std::cout << "Test " << TR << " x " << TC << " ";
77+
static constexpr size_t SCALE = 2;
78+
static constexpr size_t MATRIX_R = TR * SCALE;
79+
static constexpr size_t MATRIX_C = TC * SCALE;
80+
81+
queue q;
82+
T *in = malloc_shared<T>(MATRIX_R * MATRIX_C, q);
83+
T *col_major = malloc_shared<T>(MATRIX_C * MATRIX_R, q);
84+
T *ref_col_major = malloc_shared<T>(MATRIX_C * MATRIX_R, q);
85+
86+
matrix_rand(MATRIX_R, MATRIX_C, in, (T)5);
87+
matrix_transpose<TR, TC, T, MATRIX_R, MATRIX_C, Use>(in, col_major, q);
88+
matrix_transpose(MATRIX_R, MATRIX_C, ref_col_major, in);
89+
assert((matrix_compare<T, T, true>(MATRIX_C, MATRIX_R, col_major,
90+
ref_col_major)));
91+
std::cout << "PASSED\n";
92+
93+
free(in, q);
94+
free(col_major, q);
95+
free(ref_col_major, q);
96+
}
97+
98+
int main() {
99+
queue q;
100+
std::vector<combination> combinations =
101+
q.get_device().get_info<syclex::info::device::matrix_combinations>();
102+
bool bf16_run = false;
103+
bool half_run = false;
104+
bool int8_run = false;
105+
106+
for (auto &combination : combinations) {
107+
if (!bf16_run && combination.atype == matrix_type::bf16) {
108+
std::cout << "bf16:\n";
109+
test<bfloat16, 8, 16, use::a>();
110+
test<bfloat16, 16, 16, use::b>();
111+
bf16_run = true;
112+
}
113+
114+
if (!half_run && combination.atype == matrix_type::fp16) {
115+
std::cout << "half:\n";
116+
test<half, 8, 16, use::a>();
117+
test<half, 16, 16, use::b>();
118+
half_run = true;
119+
}
120+
121+
if (!int8_run && combination.atype == matrix_type::sint8) {
122+
std::cout << "int8:\n";
123+
test<int8_t, 8, 32, use::a>();
124+
test<int8_t, 32, 16, use::b>();
125+
int8_run = true;
126+
}
127+
128+
if (bf16_run && half_run && int8_run)
129+
break;
130+
}
131+
132+
return 0;
133+
}

0 commit comments

Comments
 (0)