Skip to content

Commit 3dbeadb

Browse files
committed
updated device code tests.
Used consistent naming convention in impl. Signed-off-by: JackAKirk <[email protected]>
1 parent b9a051f commit 3dbeadb

9 files changed

+226
-260
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
170170
return 1;
171171
}
172172

173-
template <sycl::ext::oneapi::experimental::matrix::layout LayoutL, typename S,
173+
template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename S,
174174
typename T, size_t NumRows, size_t NumCols,
175175
access::address_space Space>
176176
void load_accumulator_layoutT(
@@ -183,42 +183,42 @@ void load_accumulator_layoutT(
183183
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
184184
if constexpr (NumRows == 16 && NumCols == 16) {
185185
__imma_m16n16k16_ld_c(destptr, src.get(), stride,
186-
get_layout_id<LayoutL>());
186+
get_layout_id<Layout>());
187187
} else if constexpr (NumRows == 8 && NumCols == 32) {
188188
__imma_m8n32k16_ld_c(destptr, src.get(), stride,
189-
get_layout_id<LayoutL>());
189+
get_layout_id<Layout>());
190190
} else if constexpr (NumRows == 32 && NumCols == 8) {
191191
__imma_m32n8k16_ld_c(destptr, src.get(), stride,
192-
get_layout_id<LayoutL>());
192+
get_layout_id<Layout>());
193193
}
194194
} else if constexpr (std::is_same_v<S, float>) {
195195
auto dstptr = reinterpret_cast<float *>(&res.wi_marray);
196196
if constexpr (NumRows == 16 && NumCols == 16) {
197197
__hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride,
198-
get_layout_id<LayoutL>());
198+
get_layout_id<Layout>());
199199
} else if constexpr (NumRows == 8 && NumCols == 32) {
200200
__hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride,
201-
get_layout_id<LayoutL>());
201+
get_layout_id<Layout>());
202202
} else if constexpr (NumRows == 32 && NumCols == 8) {
203203
__hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride,
204-
get_layout_id<LayoutL>());
204+
get_layout_id<Layout>());
205205
}
206206
} else if constexpr (std::is_same_v<S, half>) {
207-
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
207+
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
208208
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
209209
if constexpr (NumRows == 32 && NumCols == 8) {
210210
__hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride,
211-
get_layout_id<LayoutL>());
211+
get_layout_id<Layout>());
212212
} else if constexpr (NumRows == 8 && NumCols == 32) {
213213
__hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride,
214-
get_layout_id<LayoutL>());
214+
get_layout_id<Layout>());
215215
} else if constexpr (NumRows == 16 && NumCols == 16) {
216216
__hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride,
217-
get_layout_id<LayoutL>());
217+
get_layout_id<Layout>());
218218
}
219219
} else if constexpr (std::is_same_v<S, double>) {
220220
__dmma_m8n8k4_ld_c(reinterpret_cast<double *>(&res.wi_marray), src.get(),
221-
stride, get_layout_id<LayoutL>());
221+
stride, get_layout_id<Layout>());
222222
}
223223
};
224224

@@ -230,8 +230,8 @@ void load_accumulator_cuda(
230230
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
231231
sycl::sub_group> &res,
232232
multi_ptr<T, Space> src, size_t stride,
233-
sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
234-
switch (LayoutAcc) {
233+
sycl::ext::oneapi::experimental::matrix::layout Layout) {
234+
switch (Layout) {
235235
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
236236
load_accumulator_layoutT<
237237
sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src,
@@ -379,9 +379,9 @@ void load_multiplicand_cuda(
379379
}
380380
}
381381

382-
template <sycl::ext::oneapi::experimental::matrix::layout LayoutL, typename T,
382+
template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
383383
size_t NumRows, size_t NumCols, access::address_space Space>
384-
void storeLayoutT(
384+
void store_layoutT(
385385
sycl::ext::oneapi::experimental::matrix::joint_matrix<
386386
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
387387
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
@@ -391,48 +391,48 @@ void storeLayoutT(
391391
if constexpr (std::is_same_v<T, float>) {
392392
__hmma_m16n16k16_st_c_f32(dst.get(),
393393
reinterpret_cast<float *>(&src.wi_marray),
394-
stride, get_layout_id<LayoutL>());
394+
stride, get_layout_id<Layout>());
395395
} else if constexpr (std::is_same_v<T, int32_t>) {
396396
__imma_m16n16k16_st_c_i32(dst.get(),
397397
reinterpret_cast<int32_t *>(&src.wi_marray),
398-
stride, get_layout_id<LayoutL>());
398+
stride, get_layout_id<Layout>());
399399
} else if constexpr (std::is_same_v<T, half>) {
400400
__hmma_m16n16k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
401401
reinterpret_cast<int32_t *>(&src.wi_marray),
402-
stride, get_layout_id<LayoutL>());
402+
stride, get_layout_id<Layout>());
403403
}
404404
} else if constexpr (NumRows == 8 && NumCols == 32) {
405405
if constexpr (std::is_same_v<T, float>) {
406406
__hmma_m8n32k16_st_c_f32(dst.get(),
407407
reinterpret_cast<float *>(&src.wi_marray),
408-
stride, get_layout_id<LayoutL>());
408+
stride, get_layout_id<Layout>());
409409
} else if constexpr (std::is_same_v<T, int32_t>) {
410410
__imma_m8n32k16_st_c_i32(dst.get(),
411411
reinterpret_cast<int32_t *>(&src.wi_marray),
412-
stride, get_layout_id<LayoutL>());
412+
stride, get_layout_id<Layout>());
413413
} else if constexpr (std::is_same_v<T, half>) {
414414
__hmma_m8n32k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
415415
reinterpret_cast<int32_t *>(&src.wi_marray),
416-
stride, get_layout_id<LayoutL>());
416+
stride, get_layout_id<Layout>());
417417
}
418418
} else if constexpr (NumRows == 32 && NumCols == 8) {
419419
if constexpr (std::is_same_v<T, float>) {
420420
__hmma_m32n8k16_st_c_f32(dst.get(),
421421
reinterpret_cast<float *>(&src.wi_marray),
422-
stride, get_layout_id<LayoutL>());
422+
stride, get_layout_id<Layout>());
423423
} else if constexpr (std::is_same_v<T, int32_t>) {
424424
__imma_m32n8k16_st_c_i32(dst.get(),
425425
reinterpret_cast<int32_t *>(&src.wi_marray),
426-
stride, get_layout_id<LayoutL>());
426+
stride, get_layout_id<Layout>());
427427
} else if constexpr (std::is_same_v<T, half>) {
428428
__hmma_m32n8k16_st_c_f16(reinterpret_cast<int32_t *>(dst.get()),
429429
reinterpret_cast<int32_t *>(&src.wi_marray),
430-
stride, get_layout_id<LayoutL>());
430+
stride, get_layout_id<Layout>());
431431
}
432432
} else if constexpr (std::is_same_v<T, double>) {
433433
__dmma_m8n8k4_st_c_f64(dst.get(),
434434
reinterpret_cast<double *>(&src.wi_marray), stride,
435-
get_layout_id<LayoutL>());
435+
get_layout_id<Layout>());
436436
}
437437
}
438438

@@ -444,14 +444,14 @@ void joint_matrix_store_cuda(
444444
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic,
445445
sycl::sub_group> &src,
446446
multi_ptr<T, Space> dst, size_t stride,
447-
sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
448-
switch (LayoutAcc) {
447+
sycl::ext::oneapi::experimental::matrix::layout Layout) {
448+
switch (Layout) {
449449
case sycl::ext::oneapi::experimental::matrix::layout::row_major:
450-
storeLayoutT<sycl::ext::oneapi::experimental::matrix::layout::row_major>(
450+
store_layoutT<sycl::ext::oneapi::experimental::matrix::layout::row_major>(
451451
src, dst, stride);
452452
break;
453453
case sycl::ext::oneapi::experimental::matrix::layout::col_major:
454-
storeLayoutT<sycl::ext::oneapi::experimental::matrix::layout::col_major>(
454+
store_layoutT<sycl::ext::oneapi::experimental::matrix::layout::col_major>(
455455
src, dst, stride);
456456
break;
457457
default:

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ void joint_matrix_load(
4747
sycl::ext::oneapi::experimental::matrix::layout::dynamic,
4848
Group> &res,
4949
multi_ptr<T, Space> src, size_t stride,
50-
sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
50+
sycl::ext::oneapi::experimental::matrix::layout Layout) {
5151
#if defined(__SYCL_DEVICE_ONLY__)
5252
#if defined(__NVPTX__)
53-
sycl::ext::oneapi::detail::load_accumulator_cuda(res, src, stride, LayoutAcc);
53+
sycl::ext::oneapi::detail::load_accumulator_cuda(res, src, stride, Layout);
5454
#endif // defined(__NVPTX__)
5555
#else
5656
std::ignore = sg;
@@ -100,12 +100,12 @@ void joint_matrix_store(
100100
sycl::ext::oneapi::experimental::matrix::layout::dynamic,
101101
Group> &src,
102102
multi_ptr<T, Space> dst, size_t stride,
103-
sycl::ext::oneapi::experimental::matrix::layout LayoutAcc) {
103+
sycl::ext::oneapi::experimental::matrix::layout Layout) {
104104
#if defined(__SYCL_DEVICE_ONLY__)
105105
#if defined(__NVPTX__)
106106
sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
107107
Space>(src, dst, stride,
108-
LayoutAcc);
108+
Layout);
109109
#endif // defined(__NVPTX__)
110110
#else
111111
std::ignore = sg;

0 commit comments

Comments
 (0)