@@ -170,7 +170,7 @@ get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
170
170
return 1 ;
171
171
}
172
172
173
- template <sycl::ext::oneapi::experimental::matrix::layout LayoutL , typename S,
173
+ template <sycl::ext::oneapi::experimental::matrix::layout Layout , typename S,
174
174
typename T, size_t NumRows, size_t NumCols,
175
175
access::address_space Space>
176
176
void load_accumulator_layoutT (
@@ -183,42 +183,42 @@ void load_accumulator_layoutT(
183
183
auto destptr = reinterpret_cast <int32_t *>(&res.wi_marray );
184
184
if constexpr (NumRows == 16 && NumCols == 16 ) {
185
185
__imma_m16n16k16_ld_c (destptr, src.get (), stride,
186
- get_layout_id<LayoutL >());
186
+ get_layout_id<Layout >());
187
187
} else if constexpr (NumRows == 8 && NumCols == 32 ) {
188
188
__imma_m8n32k16_ld_c (destptr, src.get (), stride,
189
- get_layout_id<LayoutL >());
189
+ get_layout_id<Layout >());
190
190
} else if constexpr (NumRows == 32 && NumCols == 8 ) {
191
191
__imma_m32n8k16_ld_c (destptr, src.get (), stride,
192
- get_layout_id<LayoutL >());
192
+ get_layout_id<Layout >());
193
193
}
194
194
} else if constexpr (std::is_same_v<S, float >) {
195
195
auto dstptr = reinterpret_cast <float *>(&res.wi_marray );
196
196
if constexpr (NumRows == 16 && NumCols == 16 ) {
197
197
__hmma_m16n16k16_ld_c_f32 (dstptr, src.get (), stride,
198
- get_layout_id<LayoutL >());
198
+ get_layout_id<Layout >());
199
199
} else if constexpr (NumRows == 8 && NumCols == 32 ) {
200
200
__hmma_m8n32k16_ld_c_f32 (dstptr, src.get (), stride,
201
- get_layout_id<LayoutL >());
201
+ get_layout_id<Layout >());
202
202
} else if constexpr (NumRows == 32 && NumCols == 8 ) {
203
203
__hmma_m32n8k16_ld_c_f32 (dstptr, src.get (), stride,
204
- get_layout_id<LayoutL >());
204
+ get_layout_id<Layout >());
205
205
}
206
206
} else if constexpr (std::is_same_v<S, half>) {
207
- auto tileptr = reinterpret_cast <int32_t const *>(src.get ());
207
+ auto tileptr = reinterpret_cast <const int32_t *>(src.get ());
208
208
auto dstptr = reinterpret_cast <int32_t *>(&res.wi_marray );
209
209
if constexpr (NumRows == 32 && NumCols == 8 ) {
210
210
__hmma_m32n8k16_ld_c_f16 (dstptr, tileptr, stride,
211
- get_layout_id<LayoutL >());
211
+ get_layout_id<Layout >());
212
212
} else if constexpr (NumRows == 8 && NumCols == 32 ) {
213
213
__hmma_m8n32k16_ld_c_f16 (dstptr, tileptr, stride,
214
- get_layout_id<LayoutL >());
214
+ get_layout_id<Layout >());
215
215
} else if constexpr (NumRows == 16 && NumCols == 16 ) {
216
216
__hmma_m16n16k16_ld_c_f16 (dstptr, tileptr, stride,
217
- get_layout_id<LayoutL >());
217
+ get_layout_id<Layout >());
218
218
}
219
219
} else if constexpr (std::is_same_v<S, double >) {
220
220
__dmma_m8n8k4_ld_c (reinterpret_cast <double *>(&res.wi_marray ), src.get (),
221
- stride, get_layout_id<LayoutL >());
221
+ stride, get_layout_id<Layout >());
222
222
}
223
223
};
224
224
@@ -230,8 +230,8 @@ void load_accumulator_cuda(
230
230
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
231
231
sycl::sub_group> &res,
232
232
multi_ptr<T, Space> src, size_t stride,
233
- sycl::ext::oneapi::experimental::matrix::layout LayoutAcc ) {
234
- switch (LayoutAcc ) {
233
+ sycl::ext::oneapi::experimental::matrix::layout Layout ) {
234
+ switch (Layout ) {
235
235
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
236
236
load_accumulator_layoutT<
237
237
sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src,
@@ -379,9 +379,9 @@ void load_multiplicand_cuda(
379
379
}
380
380
}
381
381
382
- template <sycl::ext::oneapi::experimental::matrix::layout LayoutL , typename T,
382
+ template <sycl::ext::oneapi::experimental::matrix::layout Layout , typename T,
383
383
size_t NumRows, size_t NumCols, access::address_space Space>
384
- void storeLayoutT (
384
+ void store_layoutT (
385
385
sycl::ext::oneapi::experimental::matrix::joint_matrix<
386
386
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
387
387
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
@@ -391,48 +391,48 @@ void storeLayoutT(
391
391
if constexpr (std::is_same_v<T, float >) {
392
392
__hmma_m16n16k16_st_c_f32 (dst.get (),
393
393
reinterpret_cast <float *>(&src.wi_marray ),
394
- stride, get_layout_id<LayoutL >());
394
+ stride, get_layout_id<Layout >());
395
395
} else if constexpr (std::is_same_v<T, int32_t >) {
396
396
__imma_m16n16k16_st_c_i32 (dst.get (),
397
397
reinterpret_cast <int32_t *>(&src.wi_marray ),
398
- stride, get_layout_id<LayoutL >());
398
+ stride, get_layout_id<Layout >());
399
399
} else if constexpr (std::is_same_v<T, half>) {
400
400
__hmma_m16n16k16_st_c_f16 (reinterpret_cast <int32_t *>(dst.get ()),
401
401
reinterpret_cast <int32_t *>(&src.wi_marray ),
402
- stride, get_layout_id<LayoutL >());
402
+ stride, get_layout_id<Layout >());
403
403
}
404
404
} else if constexpr (NumRows == 8 && NumCols == 32 ) {
405
405
if constexpr (std::is_same_v<T, float >) {
406
406
__hmma_m8n32k16_st_c_f32 (dst.get (),
407
407
reinterpret_cast <float *>(&src.wi_marray ),
408
- stride, get_layout_id<LayoutL >());
408
+ stride, get_layout_id<Layout >());
409
409
} else if constexpr (std::is_same_v<T, int32_t >) {
410
410
__imma_m8n32k16_st_c_i32 (dst.get (),
411
411
reinterpret_cast <int32_t *>(&src.wi_marray ),
412
- stride, get_layout_id<LayoutL >());
412
+ stride, get_layout_id<Layout >());
413
413
} else if constexpr (std::is_same_v<T, half>) {
414
414
__hmma_m8n32k16_st_c_f16 (reinterpret_cast <int32_t *>(dst.get ()),
415
415
reinterpret_cast <int32_t *>(&src.wi_marray ),
416
- stride, get_layout_id<LayoutL >());
416
+ stride, get_layout_id<Layout >());
417
417
}
418
418
} else if constexpr (NumRows == 32 && NumCols == 8 ) {
419
419
if constexpr (std::is_same_v<T, float >) {
420
420
__hmma_m32n8k16_st_c_f32 (dst.get (),
421
421
reinterpret_cast <float *>(&src.wi_marray ),
422
- stride, get_layout_id<LayoutL >());
422
+ stride, get_layout_id<Layout >());
423
423
} else if constexpr (std::is_same_v<T, int32_t >) {
424
424
__imma_m32n8k16_st_c_i32 (dst.get (),
425
425
reinterpret_cast <int32_t *>(&src.wi_marray ),
426
- stride, get_layout_id<LayoutL >());
426
+ stride, get_layout_id<Layout >());
427
427
} else if constexpr (std::is_same_v<T, half>) {
428
428
__hmma_m32n8k16_st_c_f16 (reinterpret_cast <int32_t *>(dst.get ()),
429
429
reinterpret_cast <int32_t *>(&src.wi_marray ),
430
- stride, get_layout_id<LayoutL >());
430
+ stride, get_layout_id<Layout >());
431
431
}
432
432
} else if constexpr (std::is_same_v<T, double >) {
433
433
__dmma_m8n8k4_st_c_f64 (dst.get (),
434
434
reinterpret_cast <double *>(&src.wi_marray ), stride,
435
- get_layout_id<LayoutL >());
435
+ get_layout_id<Layout >());
436
436
}
437
437
}
438
438
@@ -444,14 +444,14 @@ void joint_matrix_store_cuda(
444
444
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
445
445
sycl::sub_group> &src,
446
446
multi_ptr<T, Space> dst, size_t stride,
447
- sycl::ext::oneapi::experimental::matrix::layout LayoutAcc ) {
448
- switch (LayoutAcc ) {
447
+ sycl::ext::oneapi::experimental::matrix::layout Layout ) {
448
+ switch (Layout ) {
449
449
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
450
- storeLayoutT <sycl::ext::oneapi::experimental::matrix::layout::row_major>(
450
+ store_layoutT <sycl::ext::oneapi::experimental::matrix::layout::row_major>(
451
451
src, dst, stride);
452
452
break ;
453
453
case sycl::ext::oneapi::experimental::matrix::layout::col_major:
454
- storeLayoutT <sycl::ext::oneapi::experimental::matrix::layout::col_major>(
454
+ store_layoutT <sycl::ext::oneapi::experimental::matrix::layout::col_major>(
455
455
src, dst, stride);
456
456
break ;
457
457
default :
0 commit comments