@@ -51,34 +51,45 @@ struct joint_matrix<
51
51
} // namespace experimental::matrix
52
52
53
53
namespace detail {
54
- using namespace experimental ;
55
54
56
- template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
57
- matrix::matrix_layout Layout, access::address_space Space,
58
- typename Cond = void >
55
+ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use MT,
56
+ size_t NumRows, size_t NumCols,
57
+ sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
58
+ access::address_space Space, typename Cond = void >
59
59
struct joint_matrix_load_impl {
60
- void load (matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
60
+ void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
61
+ T, MT, NumRows, NumCols, Layout> &res,
61
62
multi_ptr<T, Space> src, size_t stride);
62
63
};
63
64
64
- template <matrix::matrix_layout Layout> constexpr int get_layout_id ();
65
+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout>
66
+ constexpr int get_layout_id ();
65
67
66
- template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
68
+ template <>
69
+ constexpr int get_layout_id<
70
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
67
71
return 0 ;
68
72
}
69
73
70
- template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
74
+ template <>
75
+ constexpr int get_layout_id<
76
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
71
77
return 1 ;
72
78
}
73
79
74
- template <matrix::matrix_layout Layout, access::address_space Space>
80
+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
81
+ access::address_space Space>
75
82
struct joint_matrix_load_impl <
76
- double , matrix::matrix_use::a, 8 , 4 , Layout, Space,
77
- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
78
- Layout == matrix::matrix_layout::col_major>> {
79
- void
80
- load (matrix::joint_matrix<double , matrix::matrix_use::a, 8 , 4 , Layout> &res,
81
- multi_ptr<double , Space> src, size_t stride) {
83
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8 , 4 ,
84
+ Layout, Space,
85
+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
86
+ matrix::matrix_layout::row_major ||
87
+ Layout == sycl::ext::oneapi::experimental::
88
+ matrix::matrix_layout::col_major>> {
89
+ void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
90
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::a,
91
+ 8 , 4 , Layout> &res,
92
+ multi_ptr<double , Space> src, size_t stride) {
82
93
83
94
#ifdef __NVPTX__
84
95
#ifdef __SYCL_DEVICE_ONLY__
@@ -88,14 +99,19 @@ struct joint_matrix_load_impl<
88
99
}
89
100
};
90
101
91
- template <matrix::matrix_layout Layout, access::address_space Space>
102
+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
103
+ access::address_space Space>
92
104
struct joint_matrix_load_impl <
93
- double , matrix::matrix_use::b, 4 , 8 , Layout, Space,
94
- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
95
- Layout == matrix::matrix_layout::col_major>> {
96
- void
97
- load (matrix::joint_matrix<double , matrix::matrix_use::b, 4 , 8 , Layout> &res,
98
- multi_ptr<double , Space> src, size_t stride) {
105
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4 , 8 ,
106
+ Layout, Space,
107
+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
108
+ matrix::matrix_layout::row_major ||
109
+ Layout == sycl::ext::oneapi::experimental::
110
+ matrix::matrix_layout::col_major>> {
111
+ void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
112
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::b,
113
+ 4 , 8 , Layout> &res,
114
+ multi_ptr<double , Space> src, size_t stride) {
99
115
#ifdef __NVPTX__
100
116
#ifdef __SYCL_DEVICE_ONLY__
101
117
__dmma_m8n8k4_ld_b (res.data , src.get (), stride, get_layout_id<Layout>());
@@ -104,14 +120,21 @@ struct joint_matrix_load_impl<
104
120
}
105
121
};
106
122
107
- template <matrix::matrix_layout Layout, access::address_space Space>
123
+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
124
+ access::address_space Space>
108
125
struct joint_matrix_load_impl <
109
- double , matrix::matrix_use::accumulator, 8 , 8 , Layout, Space,
110
- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
111
- Layout == matrix::matrix_layout::col_major>> {
112
- void load (matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 ,
113
- Layout> &res,
114
- multi_ptr<double , Space> src, size_t stride) {
126
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
127
+ 8 , Layout, Space,
128
+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
129
+ matrix::matrix_layout::row_major ||
130
+ Layout == sycl::ext::oneapi::experimental::
131
+ matrix::matrix_layout::col_major>> {
132
+ void
133
+ load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
134
+ double ,
135
+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
136
+ 8 , Layout> &res,
137
+ multi_ptr<double , Space> src, size_t stride) {
115
138
116
139
#ifdef __NVPTX__
117
140
#ifdef __SYCL_DEVICE_ONLY__
@@ -122,22 +145,30 @@ struct joint_matrix_load_impl<
122
145
};
123
146
124
147
template <typename T, size_t NumRows, size_t NumCols,
125
- matrix::matrix_layout Layout, access::address_space Space ,
126
- typename Cond = void >
148
+ sycl::ext::oneapi::experimental:: matrix::matrix_layout Layout,
149
+ access::address_space Space, typename Cond = void >
127
150
struct joint_matrix_store_impl {
128
- void store (matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
129
- NumCols, Layout> &src,
130
- multi_ptr<T, Space> dst, size_t stride);
151
+ void
152
+ store (sycl::ext::oneapi::experimental::matrix::joint_matrix<
153
+ T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
154
+ NumRows, NumCols, Layout> &src,
155
+ multi_ptr<T, Space> dst, size_t stride);
131
156
};
132
157
133
- template <matrix::matrix_layout Layout, access::address_space Space>
158
+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
159
+ access::address_space Space>
134
160
struct joint_matrix_store_impl <
135
161
double , 8 , 8 , Layout, Space,
136
- typename std::enable_if_t <Layout == matrix::matrix_layout::row_major ||
137
- Layout == matrix::matrix_layout::col_major>> {
138
- void store (matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 ,
139
- Layout> &src,
140
- multi_ptr<double , Space> dst, size_t stride) {
162
+ typename std::enable_if_t <Layout == sycl::ext::oneapi::experimental::
163
+ matrix::matrix_layout::row_major ||
164
+ Layout == sycl::ext::oneapi::experimental::
165
+ matrix::matrix_layout::col_major>> {
166
+ void
167
+ store (sycl::ext::oneapi::experimental::matrix::joint_matrix<
168
+ double ,
169
+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
170
+ 8 , Layout> &src,
171
+ multi_ptr<double , Space> dst, size_t stride) {
141
172
142
173
#ifdef __NVPTX__
143
174
#ifdef __SYCL_DEVICE_ONLY__
@@ -149,60 +180,98 @@ struct joint_matrix_store_impl<
149
180
};
150
181
151
182
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
152
- matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
153
- matrix::matrix_layout LayoutC, typename Cond = void >
183
+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
184
+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
185
+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
186
+ typename Cond = void >
154
187
struct joint_matrix_mad_impl {
155
- matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
156
- mad (matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
157
- matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
158
- matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
188
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
189
+ T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
190
+ N, LayoutC>
191
+ mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
192
+ T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
193
+ LayoutA>
194
+ A,
195
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
196
+ T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
197
+ LayoutB>
198
+ B,
199
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
200
+ T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
201
+ M, N, LayoutC>
159
202
C);
160
203
};
161
204
162
- template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
205
+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
206
+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
163
207
constexpr int get_layout_pair_id ();
164
208
165
209
template <>
166
- constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
167
- matrix::matrix_layout::row_major>() {
210
+ constexpr int get_layout_pair_id<
211
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
212
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
168
213
return 0 ;
169
214
}
170
215
171
216
template <>
172
- constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
173
- matrix::matrix_layout::col_major>() {
217
+ constexpr int get_layout_pair_id<
218
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
219
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
174
220
return 1 ;
175
221
}
176
222
177
223
template <>
178
- constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
179
- matrix::matrix_layout::row_major>() {
224
+ constexpr int get_layout_pair_id<
225
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
226
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
180
227
return 2 ;
181
228
}
182
229
183
230
template <>
184
- constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
185
- matrix::matrix_layout::col_major>() {
231
+ constexpr int get_layout_pair_id<
232
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
233
+ sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
186
234
return 3 ;
187
235
}
188
236
189
- template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
190
- matrix::matrix_layout LayoutC>
237
+ template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
238
+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
239
+ sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC>
191
240
struct joint_matrix_mad_impl <
192
241
double , double , 8 , 4 , 8 , LayoutA, LayoutB, LayoutC,
193
- typename std::enable_if_t <(LayoutA == matrix::matrix_layout::row_major ||
194
- LayoutA == matrix::matrix_layout::col_major) &&
195
- (LayoutB == matrix::matrix_layout::row_major ||
196
- LayoutB == matrix::matrix_layout::col_major) &&
197
- (LayoutC == matrix::matrix_layout::row_major ||
198
- LayoutC == matrix::matrix_layout::col_major)>> {
199
- matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 , LayoutC>
200
- mad (matrix::joint_matrix<double , matrix::matrix_use::a, 8 , 4 , LayoutA> A,
201
- matrix::joint_matrix<double , matrix::matrix_use::b, 4 , 8 , LayoutB> B,
202
- matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 ,
203
- LayoutC>
242
+ typename std::enable_if_t <
243
+ (LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
244
+ row_major ||
245
+ LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
246
+ col_major) &&
247
+ (LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
248
+ row_major ||
249
+ LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
250
+ col_major) &&
251
+ (LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
252
+ row_major ||
253
+ LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
254
+ col_major)>> {
255
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
256
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
257
+ 8 , 8 , LayoutC>
258
+ mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
259
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8 , 4 ,
260
+ LayoutA>
261
+ A,
262
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
263
+ double , sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4 , 8 ,
264
+ LayoutB>
265
+ B,
266
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
267
+ double ,
268
+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 ,
269
+ 8 , LayoutC>
204
270
C) {
205
- matrix::joint_matrix<double , matrix::matrix_use::accumulator, 8 , 8 , LayoutC>
271
+ sycl::ext::oneapi::experimental::matrix::joint_matrix<
272
+ double ,
273
+ sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8 , 8 ,
274
+ LayoutC>
206
275
D;
207
276
208
277
#ifdef __NVPTX__
@@ -225,8 +294,9 @@ template <typename Group, typename T, matrix_use MT, size_t NumRows,
225
294
void joint_matrix_load (
226
295
Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
227
296
multi_ptr<T, Space> src, size_t stride) {
228
- detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load (
229
- res, src, stride);
297
+ sycl::ext::oneapi::detail::joint_matrix_load_impl<T, MT, NumRows, NumCols,
298
+ Layout, Space>{}
299
+ .load (res, src, stride);
230
300
}
231
301
232
302
template <typename Group, typename T, size_t NumRows, size_t NumCols,
@@ -235,8 +305,9 @@ void joint_matrix_store(Group sg,
235
305
joint_matrix<T, matrix_use::accumulator, NumRows,
236
306
NumCols, Layout, Group> &src,
237
307
multi_ptr<T, Space> dst, size_t stride) {
238
- detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store (
239
- src, dst, stride);
308
+ sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
309
+ Layout, Space>{}
310
+ .store (src, dst, stride);
240
311
}
241
312
242
313
template <typename Group, typename T1, typename T2, std::size_t M,
@@ -247,8 +318,8 @@ joint_matrix_mad(
247
318
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
248
319
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
249
320
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
250
- return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
251
- LayoutC>{}
321
+ return sycl::ext::oneapi:: detail::joint_matrix_mad_impl<
322
+ T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
252
323
.mad (A, B, C);
253
324
}
254
325
0 commit comments