Skip to content

Commit 3656d7d

Browse files
dkhaldibb-sycl
authored andcommitted
[SYCL] Add matrix tests that use the new API (unified API) (intel#1391)
This PR adapts t the matrix tests to the new API (unified API) and move the old API tests to a new folder "Legacy".
1 parent 3bbcbb9 commit 3656d7d

20 files changed

+343
-2
lines changed

SYCL/Matrix/Legacy/XMX8/joint_matrix_bf16.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
//===----------------------------------------------------------------------===//
88
// REQUIRES: matrix-xmx8
99

10+
<<<<<<<< HEAD:SYCL/Matrix/Legacy/XMX8/joint_matrix_bf16.cpp
1011
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1
12+
========
13+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
14+
>>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391)):SYCL/Matrix/XMX8/joint_matrix_bf16.cpp
1115
// RUN: %CPU_RUN_PLACEHOLDER %t.out
1216
// RUN: %GPU_RUN_PLACEHOLDER %t.out
1317

SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77
//===----------------------------------------------------------------------===//
88
// REQUIRES: matrix-xmx8
99

10+
<<<<<<< HEAD
11+
<<<<<<<< HEAD:SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16.cpp
1012
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1
13+
========
14+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
15+
>>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391)):SYCL/Matrix/XMX8/joint_matrix_bfloat16_use.cpp
16+
=======
17+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1
18+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
1119
// RUN: %CPU_RUN_PLACEHOLDER %t.out
1220
// RUN: %GPU_RUN_PLACEHOLDER %t.out
1321

@@ -16,7 +24,11 @@
1624

1725
using namespace sycl;
1826
using namespace sycl::ext::oneapi::experimental::matrix;
27+
<<<<<<< HEAD
1928
using bfloat16 = sycl::ext::oneapi::bfloat16;
29+
=======
30+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
31+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
2032

2133
#define SG_SZ 8
2234

SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16_32x64.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818

1919
using namespace sycl;
2020
using namespace sycl::ext::oneapi::experimental::matrix;
21+
<<<<<<< HEAD
2122
using bfloat16 = sycl::ext::oneapi::bfloat16;
23+
=======
24+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
25+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
2226

2327
#define SG_SZ 8
2428

SYCL/Matrix/Legacy/element_wise_all_ops_bf16.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
// RUN: %CPU_RUN_PLACEHOLDER %t.out
1212
// RUN: %GPU_RUN_PLACEHOLDER %t.out
1313

14+
<<<<<<< HEAD
1415

16+
=======
17+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
1518
#include <iostream>
1619
#include <random>
1720
#include <sycl/sycl.hpp>

SYCL/Matrix/Legacy/element_wise_all_ops_half_impl.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
4141
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
4242
joint_matrix<T, TM, TK> sub_a(sg);
4343

44+
<<<<<<< HEAD
4445
joint_matrix_fill(sg, sub_a, 5);
46+
=======
47+
joint_matrix_fill(sg, sub_a, 5.0);
48+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
4549

4650
auto wi_slice_a = sub_a.get_wi_data();
4751
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -74,7 +78,11 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7478
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
7579
joint_matrix<T, TM, TK> sub_a(sg);
7680

81+
<<<<<<< HEAD
7782
joint_matrix_fill(sg, sub_a, 5);
83+
=======
84+
joint_matrix_fill(sg, sub_a, 5.0);
85+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
7886

7987
auto wi_slice_a = sub_a.get_wi_data();
8088
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -107,7 +115,11 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
107115
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
108116
joint_matrix<T, TM, TK> sub_a(sg);
109117

118+
<<<<<<< HEAD
110119
joint_matrix_fill(sg, sub_a, 5);
120+
=======
121+
joint_matrix_fill(sg, sub_a, 5.0);
122+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
111123

112124
auto wi_slice_a = sub_a.get_wi_data();
113125
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -140,7 +152,11 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
140152
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
141153
joint_matrix<T, TM, TK> sub_a(sg);
142154

155+
<<<<<<< HEAD
143156
joint_matrix_fill(sg, sub_a, 4);
157+
=======
158+
joint_matrix_fill(sg, sub_a, 4.0);
159+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
144160

145161
auto wi_slice_a = sub_a.get_wi_data();
146162
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -173,7 +189,11 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
173189
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
174190
joint_matrix<T, TM, TK> sub_a(sg);
175191

192+
<<<<<<< HEAD
176193
joint_matrix_fill(sg, sub_a, 5);
194+
=======
195+
joint_matrix_fill(sg, sub_a, 5.0);
196+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
177197

178198
auto wi_slice_a = sub_a.get_wi_data();
179199
for (int i = 0; i < wi_slice_a.length(); i++) {
@@ -189,8 +209,13 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
189209
val++;
190210
if (wi_slice_a[i] == static_cast<half>(2.0)) {
191211
val -= 2;
212+
<<<<<<< HEAD
192213
val *= 3;
193214
val /= 2;
215+
=======
216+
val *= 3.0;
217+
val /= 2.0;
218+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
194219
} else {
195220
val += 2;
196221
}

SYCL/Matrix/Legacy/joint_matrix_bf16.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
//===----------------------------------------------------------------------===//
88
// REQUIRES: matrix
99

10+
<<<<<<<< HEAD:SYCL/Matrix/Legacy/joint_matrix_bf16.cpp
1011
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1
12+
========
13+
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
14+
>>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391)):SYCL/Matrix/joint_matrix_bf16.cpp
1115
// RUN: %CPU_RUN_PLACEHOLDER %t.out
1216
// RUN: %GPU_RUN_PLACEHOLDER %t.out
1317

SYCL/Matrix/Legacy/joint_matrix_bf16_impl.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,27 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
5252
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
5353

5454
sub_group sg = spmd_item.get_sub_group();
55+
<<<<<<< HEAD
56+
joint_matrix<sub_group, unsigned short, use::a, TM, TK,
57+
layout::row_major>
58+
sub_a;
59+
// For B, we assume B has been already VNNIed.
60+
joint_matrix<sub_group, unsigned short, use::b, TK, TN,
61+
ext::intel::experimental::matrix::layout::packed>
62+
sub_b;
63+
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
64+
joint_matrix_load(sg, sub_c,
65+
accC.get_pointer() + (sg_startx * TM) * N +
66+
sg_starty / SG_SZ * TN,
67+
N, layout::row_major);
68+
for (int k = 0; k < K; k += TK) {
69+
joint_matrix_load(
70+
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k, K);
71+
joint_matrix_load(sg, sub_b,
72+
accB.get_pointer() + k * N +
73+
sg_starty / SG_SZ * TN * 2,
74+
N * 2);
75+
=======
5576
joint_matrix<unsigned short, TM, TK> sub_a(sg);
5677
// For B, since current implementation does not support non-packed
5778
// layout, users need to specify the packed_b layout.
@@ -72,12 +93,17 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
7293
accB.get_pointer() + (k) * (N) +
7394
sg_starty / SG_SZ * TN * 2,
7495
N * 2, matrix_layout::packed_b);
96+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
7597
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
7698
}
7799
joint_matrix_store(sg, sub_c,
78100
accC.get_pointer() + (sg_startx * TM) * N +
79101
sg_starty / SG_SZ * TN,
102+
<<<<<<< HEAD
103+
N, layout::row_major);
104+
=======
80105
N, matrix_layout::row_major);
106+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
81107
}); // parallel for
82108
}).wait();
83109
}
@@ -105,14 +131,20 @@ unsigned short make_bf16(float x) {
105131

106132
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
107133
int K) {
134+
<<<<<<< HEAD
135+
=======
108136
// tiling
137+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
109138
for (int m = 0; m < M; m++)
110139
for (int n = 0; n < N; n++) {
111140
for (int k = 0; k < K; k++) {
112141
short *va = (short *)(A_mem + m * K + k);
113142
short *vb = (short *)(B_mem + k * N + n);
114143
float acc = *((float *)(C_mem + m * N + n));
144+
<<<<<<< HEAD
145+
=======
115146
// FIXME: Should we do reduce-add in another version?
147+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
116148
for (int i = 0; i < 2; i++) {
117149
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
118150
}
@@ -155,6 +187,13 @@ int main() {
155187
res = false;
156188
}
157189
}
190+
<<<<<<< HEAD
158191
std::cout << (res ? "passed" : "failed") << std::endl;
159192
return !res;
193+
=======
194+
if (res)
195+
std::cout << "passed\n";
196+
else
197+
std::cout << "failed\n";
198+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
160199
}

SYCL/Matrix/Legacy/joint_matrix_bfloat16.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
using namespace sycl;
1818
using namespace sycl::ext::oneapi::experimental::matrix;
19+
<<<<<<< HEAD
1920
using bfloat16 = sycl::ext::oneapi::bfloat16;
21+
=======
22+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
23+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
2024

2125
#define SG_SZ 16
2226

SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1+
<<<<<<< HEAD
2+
<<<<<<<< HEAD:SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64.cpp
13
//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==//
4+
========
5+
//==-------- joint_matrix_bf16_vnni.cpp - DPC++ joint_matrix---------------==//
6+
>>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391)):SYCL/Matrix/Legacy/joint_matrix_int8_vnni.cpp
7+
=======
8+
//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==//
9+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
210
//
311
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
412
// See https://llvm.org/LICENSE.txt for license information.
@@ -18,8 +26,19 @@
1826

1927
using namespace sycl;
2028
using namespace sycl::ext::oneapi::experimental::matrix;
21-
using bfloat16 = sycl::ext::oneapi::bfloat16;
29+
<<<<<<< HEAD
30+
31+
#define SG_SZ 16
32+
33+
<<<<<<<< HEAD:SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64.cpp
34+
#include "joint_matrix_bfloat16_32x64_impl.hpp"
35+
========
36+
#include "joint_matrix_int8_vnni_impl.hpp"
37+
>>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391)):SYCL/Matrix/Legacy/joint_matrix_int8_vnni.cpp
38+
=======
39+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
2240

2341
#define SG_SZ 16
2442

2543
#include "joint_matrix_bfloat16_32x64_impl.hpp"
44+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))

SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,21 @@ int main() {
120120
for (int j = 0; j < MATRIX_K; j++) {
121121
// bfloat16 is created using unsigned short since conversion from float to
122122
// bfloat16 is not supported on the host side yet
123+
<<<<<<< HEAD
123124
A[i][j] = bfloat16(1.0f * (i + j));
125+
=======
126+
A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j)));
127+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
124128
Aref[i][j] = make_bf16(1.0f * (i + j));
125129
}
126130
}
127131
for (int i = 0; i < MATRIX_K / 2; i++) {
128132
for (int j = 0; j < MATRIX_N * 2; j++) {
133+
<<<<<<< HEAD
129134
B[i][j] = bfloat16(2.0f * i + 3.0f * j);
135+
=======
136+
B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j)));
137+
>>>>>>> cbbfcc6c1 ([SYCL] Add matrix tests that use the new API (unified API) (#1391))
130138
Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
131139
}
132140
}

0 commit comments

Comments
 (0)