5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
-
9
- static float make_fp32 (bfloat16 x) {
10
- unsigned int y = *((int *)&x);
11
- y = y << 16 ;
12
- float *res = reinterpret_cast <float *>(&y);
13
- return *res;
14
- }
15
-
16
- template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
17
- public:
18
- T *mat;
19
-
20
- public:
21
- T *get_data () { return mat; }
22
- void set_data (T *data) { mat = data; }
23
- big_matrix (T *data) : mat(data) {}
24
- };
25
-
26
8
template <typename T, size_t NUM_ROWS, size_t NUM_COLS>
27
9
void assert_ops_ref (host_accessor<T, 2 , access::mode::read> mat,
28
10
const float ref) {
@@ -39,20 +21,25 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
39
21
}
40
22
41
23
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
42
- size_t SUB_COLS, typename OP>
24
+ size_t SUB_COLS, class kernel_name , typename OP>
43
25
void verify_op_a (const T l, const T r, const float ref, OP op) {
44
26
T mat[NUM_ROWS][NUM_COLS];
45
27
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat ((T *)&mat);
46
28
47
29
buffer<T, 2 > bufMat (big_mat.get_data (), range<2 >(NUM_ROWS, NUM_COLS));
48
30
49
31
queue q;
32
+ size_t sg_size = get_sg_size<kernel_name>(q);
50
33
q.submit ([&](handler &cgh) {
51
34
sycl::accessor accessMat{bufMat, cgh, sycl::read_write};
52
- cgh.parallel_for (
53
- nd_range<2 >({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * SG_SZ},
54
- {1 , 1 * SG_SZ}),
55
- [=](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
35
+ cgh.parallel_for <kernel_name>(
36
+ nd_range<2 >({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * sg_size},
37
+ {1 , 1 * sg_size}),
38
+ [=](nd_item<2 > spmd_item)
39
+ #ifdef SG_SZ
40
+ [[intel::reqd_sub_group_size (SG_SZ)]]
41
+ #endif
42
+ {
56
43
const auto global_idx = spmd_item.get_global_id (0 );
57
44
const auto global_idy = spmd_item.get_global_id (1 );
58
45
const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
@@ -68,28 +55,32 @@ void verify_op_a(const T l, const T r, const float ref, OP op) {
68
55
sg, sub_mat,
69
56
accessMat.template get_multi_ptr <access::decorated::no>() +
70
57
(sg_startx * SUB_ROWS) * NUM_COLS +
71
- sg_starty / SG_SZ * SUB_COLS,
58
+ sg_starty / sg_size * SUB_COLS,
72
59
NUM_COLS);
73
60
}); // parallel for
74
61
}).wait ();
75
62
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access (read_only), ref);
76
63
}
77
64
78
65
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
79
- size_t SUB_COLS, typename OP>
66
+ size_t SUB_COLS, class kernel_name , typename OP>
80
67
void verify_op_c (const T l, const T r, const float ref, OP op) {
81
68
T mat[NUM_ROWS][NUM_COLS];
82
69
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat ((T *)&mat);
83
70
84
71
buffer<T, 2 > bufMat (big_mat.get_data (), range<2 >(NUM_ROWS, NUM_COLS));
85
-
86
72
queue q;
73
+ size_t sg_size = get_sg_size<kernel_name>(q);
87
74
q.submit ([&](handler &cgh) {
88
75
sycl::accessor accessMat{bufMat, cgh, sycl::read_write};
89
- cgh.parallel_for (
90
- nd_range<2 >({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * SG_SZ},
91
- {1 , 1 * SG_SZ}),
92
- [=](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
76
+ cgh.parallel_for <kernel_name>(
77
+ nd_range<2 >({NUM_ROWS / SUB_ROWS, NUM_COLS / SUB_COLS * sg_size},
78
+ {1 , 1 * sg_size}),
79
+ [=](nd_item<2 > spmd_item)
80
+ #ifdef SG_SZ
81
+ [[intel::reqd_sub_group_size (SG_SZ)]]
82
+ #endif
83
+ {
93
84
const auto global_idx = spmd_item.get_global_id (0 );
94
85
const auto global_idy = spmd_item.get_global_id (1 );
95
86
const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
@@ -105,85 +96,103 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
105
96
sg, sub_mat,
106
97
accessMat.template get_multi_ptr <access::decorated::no>() +
107
98
(sg_startx * SUB_ROWS) * NUM_COLS +
108
- sg_starty / SG_SZ * SUB_COLS,
99
+ sg_starty / sg_size * SUB_COLS,
109
100
NUM_COLS, layout::row_major);
110
101
}); // parallel for
111
102
}).wait ();
112
103
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access (read_only), ref);
113
104
}
114
105
106
+ // Avoid same kernel name for different types
107
+ template <typename T, class name > class ewops_a {};
115
108
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
116
109
void test_ewops_a () {
117
110
118
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
111
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add > >(
119
112
T (5.0 ), T (2.0 ), 7.0 , [](auto l, auto r) { return l + r; });
120
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
113
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_sub > >(
121
114
T (5.0 ), T (2.0 ), 3.0 , [](auto l, auto r) { return l - r; });
122
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
115
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_mul > >(
123
116
T (5.0 ), T (2.0 ), 10.0 , [](auto l, auto r) { return l * r; });
124
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
117
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_div > >(
125
118
T (5.0 ), T (2.0 ), 2.5 , [](auto l, auto r) { return l / r; });
126
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
119
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_logical > >(
127
120
T (5.0 ), T (5.0 ), 5.0 , [](auto l, auto r) { return l == r ? l : T (1.0 ); });
128
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
121
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_eq > >(
129
122
T (5.0 ), T (4.0 ), 4.0 , [](auto l, auto r) { return l == r ? l : r; });
130
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
123
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ne > >(
131
124
T (5.0 ), T (5.0 ), 1.0 , [](auto l, auto r) { return l != r ? l : T (1.0 ); });
132
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
125
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_gt > >(
133
126
T (5.0 ), T (2.0 ), 3.0 ,
134
127
[](auto l, auto r) { return l > r ? T (3.0 ) : T (2.0 ); });
135
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
128
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_lt > >(
136
129
T (5.0 ), T (2.0 ), 2.0 ,
137
130
[](auto l, auto r) { return l < r ? T (3.0 ) : T (2.0 ); });
138
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
131
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ge > >(
139
132
T (5.0 ), T (2.0 ), 3.0 ,
140
133
[](auto l, auto r) { return l >= r ? T (3.0 ) : T (2.0 ); });
141
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS>(
134
+ verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_le > >(
142
135
T (5.0 ), T (2.0 ), 2.0 ,
143
136
[](auto l, auto r) { return l <= r ? T (3.0 ) : T (2.0 ); });
144
137
}
145
-
138
+ // Avoid same kernel name for different types and numbers of columns
139
+ template <typename T, size_t COLS, class name > class ewops_c {};
146
140
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
147
141
void test_ewops_c () {
148
142
149
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
143
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_add > >(
150
144
T (5.0 ), T (2.0 ), 7.0 , [](auto l, auto r) { return l + r; });
151
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
145
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_sub > >(
152
146
T (5.0 ), T (2.0 ), 3.0 , [](auto l, auto r) { return l - r; });
153
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
147
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_mul > >(
154
148
T (5.0 ), T (2.0 ), 10.0 , [](auto l, auto r) { return l * r; });
155
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
149
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_div > >(
156
150
T (5.0 ), T (2.0 ), 2.5 , [](auto l, auto r) { return l / r; });
157
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
151
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
152
+ ewops_c<T, SCOLS, class c_logical >>(
158
153
T (5.0 ), T (5.0 ), 5.0 , [](auto l, auto r) { return l == r ? l : T (1.0 ); });
159
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
154
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_eq > >(
160
155
T (5.0 ), T (4.0 ), 4.0 , [](auto l, auto r) { return l == r ? l : r; });
161
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
156
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ne > >(
162
157
T (5.0 ), T (5.0 ), 1.0 , [](auto l, auto r) { return l != r ? l : T (1.0 ); });
163
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
158
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_gt > >(
164
159
T (5.0 ), T (2.0 ), 3.0 ,
165
160
[](auto l, auto r) { return l > r ? T (3.0 ) : T (2.0 ); });
166
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
161
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_lt > >(
167
162
T (5.0 ), T (2.0 ), 2.0 ,
168
163
[](auto l, auto r) { return l < r ? T (3.0 ) : T (2.0 ); });
169
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
164
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ge > >(
170
165
T (5.0 ), T (2.0 ), 3.0 ,
171
166
[](auto l, auto r) { return l >= r ? T (3.0 ) : T (2.0 ); });
172
- verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS>(
167
+ verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_le > >(
173
168
T (5.0 ), T (2.0 ), 2.0 ,
174
169
[](auto l, auto r) { return l <= r ? T (3.0 ) : T (2.0 ); });
175
170
}
176
171
177
172
int main () {
178
173
static constexpr size_t TM = 8 ;
179
- static constexpr size_t TK = 16 ;
180
174
181
175
static constexpr size_t MATRIX_M = TM * 2 ;
182
- static constexpr size_t MATRIX_N = TN * 2 ;
183
- static constexpr size_t MATRIX_K = TK * 2 ;
184
-
185
- test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, TK>();
186
- test_ewops_c<float , MATRIX_M, MATRIX_N, TM, TN>();
187
-
176
+ static constexpr size_t MATRIX_N = 32 ;
177
+ static constexpr size_t MATRIX_K = 32 ;
178
+ queue q;
179
+ std::vector<combination> combinations =
180
+ q.get_device ()
181
+ .get_info <sycl::ext::oneapi::experimental::info::device::
182
+ matrix_combinations>();
183
+ for (unsigned int i = 0 ; i < combinations.size (); i++) {
184
+ if (combinations[i].atype == matrix_type::bf16 ) {
185
+ if (combinations[i].nsize == 0 || combinations[i].nsize == 16 ) {
186
+ test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16 >();
187
+ test_ewops_c<float , MATRIX_M, MATRIX_N, TM, 16 >();
188
+ break ;
189
+ }
190
+ if (combinations[i].nsize == 8 ) {
191
+ test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16 >();
192
+ test_ewops_c<float , MATRIX_M, MATRIX_N, TM, 8 >();
193
+ break ;
194
+ }
195
+ }
196
+ }
188
197
return 0 ;
189
198
}
0 commit comments