Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit f2e536c

Browse files
authored
[SYCL] Correct bfloat16 class namespace (#1468)
1 parent a795f30 commit f2e536c

16 files changed

+61
-320
lines changed

SYCL/Matrix/XMX8/element_wise_all_ops_bf16.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
using namespace sycl;
1919
using namespace sycl::ext::intel;
2020
using namespace sycl::ext::oneapi::experimental::matrix;
21+
using bfloat16 = sycl::ext::oneapi::bfloat16;
2122

2223
#define SG_SZ 8
2324

SYCL/Matrix/XMX8/joint_matrix_bf16.cpp

Lines changed: 0 additions & 22 deletions
This file was deleted.

SYCL/Matrix/XMX8/joint_matrix_bfloat16_use.cpp

Lines changed: 0 additions & 24 deletions
This file was deleted.

SYCL/Matrix/XMX8/joint_matrix_ss_int8_use.cpp

Lines changed: 0 additions & 24 deletions
This file was deleted.

SYCL/Matrix/element_wise_all_ops_bf16.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ using namespace sycl;
1919
using namespace sycl::ext::intel;
2020
using namespace sycl::ext::oneapi;
2121
using namespace sycl::ext::oneapi::experimental::matrix;
22+
using bfloat16 = sycl::ext::oneapi::bfloat16;
2223

2324
#define SG_SZ 16
2425

SYCL/Matrix/element_wise_all_ops_bf16_impl.hpp

Lines changed: 35 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -10,12 +10,6 @@ static float make_fp32(uint16_t x) {
1010
return *res;
1111
}
1212

13-
static uint16_t make_bf16(float x) {
14-
int *res = reinterpret_cast<int *>(&x);
15-
*res = *res >> 16;
16-
return (uint16_t)*res;
17-
}
18-
1913
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
2014
public:
2115
T *mat;
@@ -40,7 +34,7 @@ void assert_ops_ref(
4034
template <typename T, size_t M, size_t N>
4135
void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
4236
const float ref) {
43-
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
37+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
4438

4539
q.submit([&](handler &cgh) {
4640
auto accA = bufA.get_access<access::mode::read_write>(cgh);
@@ -55,12 +49,13 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
5549
sub_group sg = spmd_item.get_sub_group();
5650
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
5751

58-
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
52+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
5953

6054
auto wi_slice_a = get_wi_data(sg, sub_a);
6155
for (int i = 0; i < wi_slice_a.length(); i++) {
62-
wi_slice_a[i] = wi_slice_a[i] + make_bf16(2);
56+
wi_slice_a[i] = wi_slice_a[i] + bfloat16(2);
6357
}
58+
6459
ext::intel::experimental::matrix::joint_matrix_store(
6560
sg, sub_a,
6661
accA.get_pointer() + (sg_startx * TM) * N +
@@ -74,7 +69,7 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7469
template <typename T, size_t M, size_t N>
7570
void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7671
const float ref) {
77-
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
72+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
7873

7974
q.submit([&](handler &cgh) {
8075
auto accA = bufA.get_access<access::mode::read_write>(cgh);
@@ -89,11 +84,11 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
8984
sub_group sg = spmd_item.get_sub_group();
9085
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
9186

92-
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
87+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
9388

9489
auto wi_slice_a = get_wi_data(sg, sub_a);
9590
for (int i = 0; i < wi_slice_a.length(); i++) {
96-
wi_slice_a[i] = wi_slice_a[i] - make_bf16(2);
91+
wi_slice_a[i] = wi_slice_a[i] - bfloat16(2);
9792
}
9893
ext::intel::experimental::matrix::joint_matrix_store(
9994
sg, sub_a,
@@ -108,7 +103,7 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
108103
template <typename T, size_t M, size_t N>
109104
void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
110105
const float ref) {
111-
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
106+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
112107

113108
q.submit([&](handler &cgh) {
114109
auto accA = bufA.get_access<access::mode::read_write>(cgh);
@@ -122,11 +117,11 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
122117

123118
sub_group sg = spmd_item.get_sub_group();
124119
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
125-
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
120+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
126121

127122
auto wi_slice_a = get_wi_data(sg, sub_a);
128123
for (int i = 0; i < wi_slice_a.length(); i++) {
129-
wi_slice_a[i] = wi_slice_a[i] * make_bf16(3.0);
124+
wi_slice_a[i] = wi_slice_a[i] * bfloat16(3.0);
130125
}
131126
ext::intel::experimental::matrix::joint_matrix_store(
132127
sg, sub_a,
@@ -141,7 +136,7 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
141136
template <typename T, size_t M, size_t N>
142137
void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
143138
const float ref) {
144-
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
139+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
145140

146141
q.submit([&](handler &cgh) {
147142
auto accA = bufA.get_access<access::mode::read_write>(cgh);
@@ -156,11 +151,11 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
156151
sub_group sg = spmd_item.get_sub_group();
157152
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
158153

159-
joint_matrix_fill(sg, sub_a, make_bf16(4.0));
154+
joint_matrix_fill(sg, sub_a, bfloat16(4.0));
160155

161156
auto wi_slice_a = get_wi_data(sg, sub_a);
162157
for (int i = 0; i < wi_slice_a.length(); i++) {
163-
wi_slice_a[i] = wi_slice_a[i] / make_bf16(2.0);
158+
wi_slice_a[i] = wi_slice_a[i] / bfloat16(2.0);
164159
}
165160
ext::intel::experimental::matrix::joint_matrix_store(
166161
sg, sub_a,
@@ -175,7 +170,7 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
175170
template <typename T, size_t M, size_t N>
176171
void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
177172
const float ref) {
178-
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, N));
173+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
179174

180175
q.submit([&](handler &cgh) {
181176
auto accA = bufA.get_access<access::mode::read_write>(cgh);
@@ -189,26 +184,26 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
189184
sub_group sg = spmd_item.get_sub_group();
190185
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
191186

192-
joint_matrix_fill(sg, sub_a, make_bf16(5.0));
187+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
193188

194189
auto wi_slice_a = get_wi_data(sg, sub_a);
195190
for (int i = 0; i < wi_slice_a.length(); i++) {
196191
if (wi_slice_a[i]) {
197-
if (wi_slice_a[i] > make_bf16(2.0) ||
198-
wi_slice_a[i] >= make_bf16(2.0) ||
199-
wi_slice_a[i] < make_bf16(2.0) ||
200-
wi_slice_a[i] <= make_bf16(2.0)) {
201-
T val = (wi_slice_a[i] != make_bf16(2.0)) ? wi_slice_a[i]
202-
: make_bf16(2.0);
203-
val = make_bf16(make_fp32(val) - static_cast<float>(1));
204-
val = make_bf16(make_fp32(val) + static_cast<float>(1));
205-
if (wi_slice_a[i] == make_bf16(2.0)) {
206-
val = make_bf16(make_fp32(val) - static_cast<float>(2));
207-
val = make_bf16(make_fp32(val) * static_cast<float>(3));
208-
val = make_bf16(make_fp32(val) / static_cast<float>(2));
192+
if (wi_slice_a[i] > bfloat16(2.0) ||
193+
wi_slice_a[i] >= bfloat16(2.0) ||
194+
wi_slice_a[i] < bfloat16(2.0) ||
195+
wi_slice_a[i] <= bfloat16(2.0)) {
196+
T val = (wi_slice_a[i] != bfloat16(2.0)) ? wi_slice_a[i]
197+
: bfloat16(2.0);
198+
val = bfloat16(make_fp32(val) - static_cast<float>(1));
199+
val = bfloat16(make_fp32(val) + static_cast<float>(1));
200+
if (wi_slice_a[i] == bfloat16(2.0)) {
201+
val = bfloat16(make_fp32(val) - static_cast<float>(2));
202+
val = bfloat16(make_fp32(val) * static_cast<float>(3));
203+
val = bfloat16(make_fp32(val) / static_cast<float>(2));
209204

210205
} else {
211-
val = make_bf16(make_fp32(val) + static_cast<float>(2));
206+
val = bfloat16(make_fp32(val) + static_cast<float>(2));
212207
}
213208
wi_slice_a[i] = val;
214209
}
@@ -226,7 +221,7 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
226221

227222
static constexpr size_t MATRIX_M = TM * 2;
228223
static constexpr size_t MATRIX_N = TN * 2;
229-
unsigned short A[MATRIX_M][MATRIX_N];
224+
bfloat16 A[MATRIX_M][MATRIX_N];
230225
float D[MATRIX_M][MATRIX_N];
231226

232227
void matrix_ops_ref(float *D, int M, int N) {
@@ -240,18 +235,18 @@ void matrix_ops_ref(float *D, int M, int N) {
240235
int main() {
241236

242237
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
243-
big_matrix<unsigned short, MATRIX_M, MATRIX_N> MA((unsigned short *)&A);
238+
big_matrix<bfloat16, MATRIX_M, MATRIX_N> MA((bfloat16 *)&A);
244239

245240
size_t NDRangeM = MATRIX_M / TM;
246241
size_t NDRangeN = MATRIX_N / TN;
247242
queue q;
248243
nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});
249244

250-
matrix_verify_add<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
251-
matrix_verify_sub<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 3.0);
252-
matrix_verify_mul<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 15.0);
253-
matrix_verify_div<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 2.0);
254-
matrix_verify_logic<unsigned short, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
245+
matrix_verify_add<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
246+
matrix_verify_sub<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 3.0);
247+
matrix_verify_mul<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 15.0);
248+
matrix_verify_div<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 2.0);
249+
matrix_verify_logic<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
255250

256251
return 0;
257252
}

SYCL/Matrix/elemwise_irreg_size_ops_bf16.cpp

Lines changed: 11 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
using namespace sycl;
2020
using namespace sycl::ext::oneapi::experimental::matrix;
21+
using bfloat16 = sycl::ext::oneapi::bfloat16;
2122

2223
#define SG_SZ 16
2324

@@ -50,8 +51,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
5051
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2);
5152
size_t NDRangeM = M / TM;
5253
size_t NDRangeN = N / TN;
53-
buffer<unsigned short, 2> bufA(A.get_data(), range<2>(M, K));
54-
buffer<unsigned short, 2> bufB(B.get_data(), range<2>(K / 2, N * 2));
54+
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
55+
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K / 2, N * 2));
5556
buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
5657

5758
queue q;
@@ -75,11 +76,10 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
7576
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
7677

7778
sub_group sg = spmd_item.get_sub_group();
78-
joint_matrix<sub_group, unsigned short, use::a, TM, TK,
79-
layout::row_major>
79+
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
8080
sub_a;
8181
// For B, we assume B has been already VNNIed.
82-
joint_matrix<sub_group, unsigned short, use::b, TK, TN,
82+
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
8383
ext::intel::experimental::matrix::layout::packed>
8484
sub_b;
8585
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
@@ -112,8 +112,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
112112
static constexpr size_t MATRIX_M = TM * 2;
113113
static constexpr size_t MATRIX_N = TN * 2;
114114
static constexpr size_t MATRIX_K = TK * 2;
115-
unsigned short A[MATRIX_M][MATRIX_K];
116-
unsigned short B[MATRIX_K / 2][MATRIX_N * 2];
115+
bfloat16 A[MATRIX_M][MATRIX_K];
116+
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
117117
float C[MATRIX_M][MATRIX_N];
118118
float D[MATRIX_M][MATRIX_N];
119119

@@ -124,12 +124,6 @@ float make_fp32(short x) {
124124
return *res;
125125
}
126126

127-
unsigned short make_bf16(float x) {
128-
int *res = reinterpret_cast<int *>(&x);
129-
*res = *res >> 16;
130-
return (unsigned short)*res;
131-
}
132-
133127
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
134128
int K) {
135129
// tiling
@@ -152,12 +146,12 @@ void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
152146
int main() {
153147
for (int i = 0; i < MATRIX_M; i++) {
154148
for (int j = 0; j < MATRIX_K; j++) {
155-
A[i][j] = make_bf16(1.0f * (i + j));
149+
A[i][j] = bfloat16(1.0f * (i + j));
156150
}
157151
}
158152
for (int i = 0; i < MATRIX_K / 2; i++) {
159153
for (int j = 0; j < MATRIX_N * 2; j++) {
160-
B[i][j] = make_bf16(2.0f * i + 3.0f * j);
154+
B[i][j] = bfloat16(2.0f * i + 3.0f * j);
161155
}
162156
}
163157
for (int i = 0; i < MATRIX_M; i++) {
@@ -169,9 +163,8 @@ int main() {
169163

170164
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
171165
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
172-
big_matrix<unsigned short, MATRIX_M, MATRIX_K> MA((unsigned short *)&A);
173-
big_matrix<unsigned short, MATRIX_K / 2, MATRIX_N * 2> MB(
174-
(unsigned short *)&B);
166+
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
167+
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
175168
matrix_multiply(MC, MA, MB);
176169
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
177170
MATRIX_N, MATRIX_K / 2);

SYCL/Matrix/joint_matrix_bf16.cpp

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)