Skip to content

Commit 446c0a0

Browse files
committed
format
Signed-off-by: JackAKirk <[email protected]>
1 parent ee1208e commit 446c0a0

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,7 @@ void load_multiplicand_cuda(
259259
sycl::ext::oneapi::experimental::matrix::joint_matrix<
260260
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
261261
multi_ptr<T, Space> src, size_t stride) {
262-
if constexpr (std::is_same_v<
263-
S, sycl::ext::oneapi::experimental::bfloat16>) {
262+
if constexpr (std::is_same_v<S, sycl::ext::oneapi::experimental::bfloat16>) {
264263
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
265264
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
266265
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -355,8 +354,8 @@ void load_multiplicand_cuda(
355354
__hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
356355
}
357356

358-
} else if constexpr (std::is_same_v<S, sycl::ext::oneapi::experimental::matrix::
359-
precision::tf32>) {
357+
} else if constexpr (std::is_same_v<S, sycl::ext::oneapi::experimental::
358+
matrix::precision::tf32>) {
360359
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
361360
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
362361
if constexpr (NumRows == 16 && NumCols == 8) {

sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ int main() {
154154
sycl::sub_group sg = item.get_sub_group();
155155

156156
joint_matrix<float, use::accumulator, 8, 32> sub_c;
157-
joint_matrix<bfloat16, use::a, 8, 16, layout::col_major> sub_a;
158-
joint_matrix<bfloat16, use::b, 16, 32, layout::col_major> sub_b;
157+
joint_matrix<bfloat16, use::a, 8, 16, layout::col_major> sub_a;
158+
joint_matrix<bfloat16, use::b, 16, 32, layout::col_major> sub_b;
159159

160160
// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
161161
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,

sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ int main() {
4242
joint_matrix<uint8_t, use::b, 16, 16, layout::row_major> sub_b;
4343

4444
// CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accC, i32 16)
45-
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::row_major);
45+
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
46+
layout::row_major);
4647
// CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
4748
joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
4849
// CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.u8.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
4950
joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
5051
// CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 %11, i32 %12, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8)
5152
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
5253
// CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_accD, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16)
53-
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::row_major);
54+
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
55+
layout::row_major);
5456
});
5557

5658
cgh.parallel_for<class col_col_m16n16k16>(

0 commit comments

Comments
 (0)