Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 1cf4f4c

Browse files
authored
[SYCL][Matrix] Adding test cases for the joint_matrix_apply() and fixing namespace for get_wi_data() (#1636)
These changes apply to both ATS-M and PVC. Regarding the namespace change, the tests will pass once [this PR](intel/llvm#8417) gets approved.
1 parent 5c4c4ec commit 1cf4f4c

10 files changed

+189
-23
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//==----------- joint_matrix_apply_bf16.cpp - DPC++ joint_matrix-----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix-xmx8
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
14+
#include <iostream>
15+
#include <random>
16+
#include <sycl/sycl.hpp>
17+
18+
using namespace sycl;
19+
using namespace sycl::ext::oneapi::experimental::matrix;
20+
using bfloat16 = sycl::ext::oneapi::bfloat16;
21+
22+
#define SG_SZ 8
23+
24+
#include "../joint_matrix_apply_bf16_impl.hpp"

SYCL/Matrix/element_wise_all_ops_bf16_impl.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
5050

5151
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
5252

53-
auto wi_slice_a = get_wi_data(sg, sub_a);
53+
auto wi_slice_a =
54+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
5455
for (int i = 0; i < wi_slice_a.length(); i++) {
5556
wi_slice_a[i] = wi_slice_a[i] + bfloat16(2);
5657
}
@@ -85,7 +86,8 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
8586

8687
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
8788

88-
auto wi_slice_a = get_wi_data(sg, sub_a);
89+
auto wi_slice_a =
90+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
8991
for (int i = 0; i < wi_slice_a.length(); i++) {
9092
wi_slice_a[i] = wi_slice_a[i] - bfloat16(2);
9193
}
@@ -118,7 +120,8 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
118120
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
119121
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
120122

121-
auto wi_slice_a = get_wi_data(sg, sub_a);
123+
auto wi_slice_a =
124+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
122125
for (int i = 0; i < wi_slice_a.length(); i++) {
123126
wi_slice_a[i] = wi_slice_a[i] * bfloat16(3.0);
124127
}
@@ -152,7 +155,8 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
152155

153156
joint_matrix_fill(sg, sub_a, bfloat16(4.0));
154157

155-
auto wi_slice_a = get_wi_data(sg, sub_a);
158+
auto wi_slice_a =
159+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
156160
for (int i = 0; i < wi_slice_a.length(); i++) {
157161
wi_slice_a[i] = wi_slice_a[i] / bfloat16(2.0);
158162
}
@@ -185,7 +189,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
185189

186190
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
187191

188-
auto wi_slice_a = get_wi_data(sg, sub_a);
192+
auto wi_slice_a =
193+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
189194
for (int i = 0; i < wi_slice_a.length(); i++) {
190195
if (wi_slice_a[i]) {
191196
if (wi_slice_a[i] > bfloat16(2.0) ||

SYCL/Matrix/element_wise_all_ops_half_impl.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
4242

4343
joint_matrix_fill(sg, sub_a, 5);
4444

45-
auto wi_slice_a = get_wi_data(sg, sub_a);
45+
auto wi_slice_a =
46+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
4647
for (int i = 0; i < wi_slice_a.length(); i++) {
4748
wi_slice_a[i] = wi_slice_a[i] + static_cast<half>(2);
4849
}
@@ -76,7 +77,8 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7677

7778
joint_matrix_fill(sg, sub_a, 5);
7879

79-
auto wi_slice_a = get_wi_data(sg, sub_a);
80+
auto wi_slice_a =
81+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
8082
for (int i = 0; i < wi_slice_a.length(); i++) {
8183
wi_slice_a[i] = wi_slice_a[i] - static_cast<half>(2);
8284
}
@@ -110,7 +112,8 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
110112

111113
joint_matrix_fill(sg, sub_a, 5);
112114

113-
auto wi_slice_a = get_wi_data(sg, sub_a);
115+
auto wi_slice_a =
116+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
114117
for (int i = 0; i < wi_slice_a.length(); i++) {
115118
wi_slice_a[i] = wi_slice_a[i] * static_cast<half>(3.0);
116119
}
@@ -144,7 +147,8 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
144147

145148
joint_matrix_fill(sg, sub_a, 4);
146149

147-
auto wi_slice_a = get_wi_data(sg, sub_a);
150+
auto wi_slice_a =
151+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
148152
for (int i = 0; i < wi_slice_a.length(); i++) {
149153
wi_slice_a[i] = wi_slice_a[i] / static_cast<half>(2.0);
150154
}
@@ -178,7 +182,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
178182

179183
joint_matrix_fill(sg, sub_a, 5);
180184

181-
auto wi_slice_a = get_wi_data(sg, sub_a);
185+
auto wi_slice_a =
186+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
182187
for (int i = 0; i < wi_slice_a.length(); i++) {
183188
if (wi_slice_a[i]) {
184189
if (wi_slice_a[i] > static_cast<half>(2.0) ||

SYCL/Matrix/element_wise_all_ops_int8_impl.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
4141

4242
joint_matrix_fill(sg, sub_a, 5);
4343

44-
auto wi_slice_a = get_wi_data(sg, sub_a);
44+
auto wi_slice_a =
45+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
4546
for (int i = 0; i < wi_slice_a.length(); i++) {
4647
wi_slice_a[i] = wi_slice_a[i] + 2;
4748
}
@@ -75,7 +76,8 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7576

7677
joint_matrix_fill(sg, sub_a, 5);
7778

78-
auto wi_slice_a = get_wi_data(sg, sub_a);
79+
auto wi_slice_a =
80+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
7981
for (int i = 0; i < wi_slice_a.length(); i++) {
8082
wi_slice_a[i] = wi_slice_a[i] - 2;
8183
}
@@ -109,7 +111,8 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
109111

110112
joint_matrix_fill(sg, sub_a, 5);
111113

112-
auto wi_slice_a = get_wi_data(sg, sub_a);
114+
auto wi_slice_a =
115+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
113116
for (int i = 0; i < wi_slice_a.length(); i++) {
114117
wi_slice_a[i] = wi_slice_a[i] * 3;
115118
}
@@ -143,7 +146,8 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
143146

144147
joint_matrix_fill(sg, sub_a, 4);
145148

146-
auto wi_slice_a = get_wi_data(sg, sub_a);
149+
auto wi_slice_a =
150+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
147151
for (int i = 0; i < wi_slice_a.length(); i++) {
148152
wi_slice_a[i] = wi_slice_a[i] / 2;
149153
}
@@ -177,7 +181,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
177181

178182
joint_matrix_fill(sg, sub_a, 5);
179183

180-
auto wi_slice_a = get_wi_data(sg, sub_a);
184+
auto wi_slice_a =
185+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
181186
for (int i = 0; i < wi_slice_a.length(); i++) {
182187
if (wi_slice_a[i]) {
183188
if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 ||

SYCL/Matrix/element_wise_all_ops_int8_packed_impl.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
4343

4444
joint_matrix_fill(sg, sub_b, 5);
4545

46-
auto wi_slice_b = get_wi_data(sg, sub_b);
46+
auto wi_slice_b =
47+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
4748
for (int i = 0; i < wi_slice_b.length(); i++) {
4849
wi_slice_b[i] = wi_slice_b[i] + 2;
4950
}
@@ -79,7 +80,8 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7980

8081
joint_matrix_fill(sg, sub_b, 5);
8182

82-
auto wi_slice_b = get_wi_data(sg, sub_b);
83+
auto wi_slice_b =
84+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
8385
for (int i = 0; i < wi_slice_b.length(); i++) {
8486
wi_slice_b[i] = wi_slice_b[i] - 2;
8587
}
@@ -115,7 +117,8 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
115117

116118
joint_matrix_fill(sg, sub_b, 5);
117119

118-
auto wi_slice_b = get_wi_data(sg, sub_b);
120+
auto wi_slice_b =
121+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
119122
for (int i = 0; i < wi_slice_b.length(); i++) {
120123
wi_slice_b[i] = wi_slice_b[i] * 3;
121124
}
@@ -151,7 +154,8 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
151154

152155
joint_matrix_fill(sg, sub_b, 4);
153156

154-
auto wi_slice_b = get_wi_data(sg, sub_b);
157+
auto wi_slice_b =
158+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
155159
for (int i = 0; i < wi_slice_b.length(); i++) {
156160
wi_slice_b[i] = wi_slice_b[i] / 2;
157161
}
@@ -187,7 +191,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
187191

188192
joint_matrix_fill(sg, sub_b, 5);
189193

190-
auto wi_slice_b = get_wi_data(sg, sub_b);
194+
auto wi_slice_b =
195+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
191196
for (int i = 0; i < wi_slice_b.length(); i++) {
192197
if (wi_slice_b[i]) {
193198
if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 ||

SYCL/Matrix/element_wise_irreg_sum_rows_impl.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ void matrix_sum_rows(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
5757
// (tK/4)
5858
int32_t sum_local_rows[M] = {0}; // 8 local rows, M total
5959
// sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row
60-
auto data = get_wi_data(sg, sub_b);
60+
auto data =
61+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
6162

6263
// each WI calculates local sum of rows
6364
for (int row = 0; row < TK / 4; row++) { // there are 8 rows

SYCL/Matrix/element_wise_ops_impl.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
7171
N * 4);
7272
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
7373
}
74-
auto wi_slice_c = get_wi_data(sg, sub_c);
74+
auto wi_slice_c =
75+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
7576
for (int i = 0; i < wi_slice_c.length(); i++) {
7677
wi_slice_c[i] *= 2;
7778
}

SYCL/Matrix/elemwise_irreg_size_ops_bf16.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
9797
N * 2);
9898
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
9999
}
100-
auto wi_slice_c = get_wi_data(sg, sub_c);
100+
auto wi_slice_c =
101+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
101102
for (int i = 0; i < wi_slice_c.length(); i++) {
102103
wi_slice_c[i] += 5.0;
103104
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//==----------- joint_matrix_apply_bf16.cpp - DPC++ joint_matrix-----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
12+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
13+
14+
#include <iostream>
15+
#include <random>
16+
#include <sycl/sycl.hpp>
17+
18+
using namespace sycl;
19+
using namespace sycl::ext::oneapi::experimental::matrix;
20+
using bfloat16 = sycl::ext::oneapi::bfloat16;
21+
22+
#define SG_SZ 16
23+
24+
#include "joint_matrix_apply_bf16_impl.hpp"
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
2+
// Tile-size macros used by the definitions in this file.
// TM: first dimension argument of the joint_matrix declared in
// matrix_verify_add, and the row step between stored tiles.
#define TM 8
3+
// TN: column step between stored tiles and the unit MATRIX_N is built from;
// defined as SG_SZ, which the including .cpp sets before including this file.
#define TN SG_SZ
4+
// TK: second dimension argument of the joint_matrix declared in
// matrix_verify_add.
#define TK 16
5+
6+
// Widens a bfloat16 value to a 32-bit float.
//
// The 16 bits of `x` are moved into the upper half of a 32-bit word (the lower
// half is zero) and that word is then reinterpreted as a float.
//
// x       - bfloat16 value to convert.
// returns - the float whose bit pattern is the bits of `x` shifted left by 16.
static float make_fp32(bfloat16 x) {
  // Fixed-width type so the shifted pattern is exactly 32 bits wide.
  const uint32_t bits = static_cast<uint32_t>(sycl::bit_cast<uint16_t>(x))
                        << 16;
  // bit_cast rather than reinterpret_cast<float *>: reading an integer object
  // through a float lvalue breaks the strict-aliasing rule.
  return sycl::bit_cast<float>(bits);
}
12+
13+
// Thin handle around a caller-provided buffer of T.
//
// The struct only stores the pointer: it never allocates, copies or frees the
// storage. NUM_ROWS and NUM_COLS are carried in the type only; nothing inside
// the struct reads them.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  // Pointer to the wrapped storage.
  T *mat;

  // Wraps `data` without copying it.
  big_matrix(T *data) : mat(data) {}

  // Returns the wrapped pointer. const-qualified so it can also be called on
  // a const big_matrix; the pointed-to elements stay mutable.
  T *get_data() const { return mat; }

  // Re-points the handle at `data`.
  void set_data(T *data) { mat = data; }
};
22+
23+
// Function object that adds 2 to the element it is handed, in place.
//
// The constant is built as T(2) so the functor follows its own template
// parameter; for T = bfloat16 this is the same bfloat16(2) as before.
template <typename T> struct apply_add {
  // x - element to update; receives x + T(2).
  void operator()(T &x) const { x = x + T(2); }
};
26+
27+
template <typename T, size_t M, size_t N, typename F>
28+
void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
29+
const float ref, F &&lambda) {
30+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
31+
32+
q.submit([&](handler &cgh) {
33+
accessor accA{bufA, cgh};
34+
35+
cgh.parallel_for(r, [accA, lambda](
36+
nd_item<2> spmd_item) [[intel::reqd_sub_group_size(
37+
SG_SZ)]] {
38+
const auto global_idx = spmd_item.get_global_id(0);
39+
const auto global_idy = spmd_item.get_global_id(1);
40+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
41+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
42+
43+
sub_group sg = spmd_item.get_sub_group();
44+
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
45+
46+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
47+
48+
joint_matrix_apply(sg, sub_a, lambda);
49+
50+
ext::intel::experimental::matrix::joint_matrix_store(
51+
sg, sub_a,
52+
accA.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
53+
N);
54+
}); // parallel for
55+
}).wait();
56+
// Check if the results are correct
57+
{
58+
host_accessor Acc{bufA};
59+
assert(std::all_of(Acc.begin(), Acc.end(), [=](auto Elem) {
60+
return (std::fabs(static_cast<float>(make_fp32(Elem) - ref)) <
61+
std::numeric_limits<float>::epsilon());
62+
}));
63+
}
64+
}
65+
66+
// Row count of the global test matrices: two TM-row tiles.
static constexpr size_t MATRIX_M = TM * 2;
67+
// Column count of the global test matrices: two TN-column steps.
static constexpr size_t MATRIX_N = TN * 2;
68+
// bfloat16 matrix that main() wraps in a big_matrix and passes to
// matrix_verify_add.
bfloat16 A[MATRIX_M][MATRIX_N];
69+
// float matrix with the same dimensions as A.
float D[MATRIX_M][MATRIX_N];
70+
71+
// Writes 0 to every element of the M x N row-major matrix at D.
//
// The earlier form stored 0 and then multiplied the element by 2, which is
// still 0, so only the store is kept.
//
// D - pointer to the first element; M * N floats are written.
// M - number of rows; nothing is written when M <= 0.
// N - number of columns; nothing is written when N <= 0.
void matrix_ops_ref(float *D, int M, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      D[m * N + n] = 0.0f;
    }
  }
}
78+
79+
int main() {
80+
81+
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
82+
big_matrix<bfloat16, MATRIX_M, MATRIX_N> MA((bfloat16 *)&A);
83+
84+
size_t NDRangeM = MATRIX_M / TM;
85+
size_t NDRangeN = MATRIX_N / TN;
86+
queue q;
87+
nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});
88+
89+
matrix_verify_add<bfloat16, MATRIX_M, MATRIX_N>(
90+
q, MA, r, 7.0, [=](bfloat16 &x) { x = x + bfloat16(2); });
91+
matrix_verify_add<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 7.0,
92+
apply_add<bfloat16>());
93+
std::cout << "Passed\n";
94+
return 0;
95+
}

0 commit comments

Comments
 (0)