
Commit ee1208e

format.

Signed-off-by: JackAKirk <[email protected]>
1 parent 49147d3 commit ee1208e

10 files changed: +245 -366 lines

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores-legacy.hpp

Lines changed: 0 additions & 1 deletion
@@ -790,4 +790,3 @@ inline __SYCL_ALWAYS_INLINE float round_to_tf32(float a) {
 } // namespace ext
 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
 } // namespace sycl
-

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp

Lines changed: 10 additions & 13 deletions
@@ -1,11 +1,11 @@
 
-//===--- matrix-tensorcores.hpp - tensor cores matrix ext impl --*- C++ -*---===//
+//===-------- matrix-tensorcores.hpp - matrix ext impl ---*- C++ -*-------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// ===----------------------------------------------------------------------=== //
+// ===-------------------------------------------------------------------=== //
 
 #pragma once
 #include <sycl/ext/oneapi/experimental/bfloat16.hpp>
@@ -185,11 +185,9 @@ void load_accumulator_layoutT(
 __imma_m16n16k16_ld_c(destptr, src.get(), stride,
 get_layout_id<Layout>());
 } else if constexpr (NumRows == 8 && NumCols == 32) {
-__imma_m8n32k16_ld_c(destptr, src.get(), stride,
-get_layout_id<Layout>());
+__imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id<Layout>());
 } else if constexpr (NumRows == 32 && NumCols == 8) {
-__imma_m32n8k16_ld_c(destptr, src.get(), stride,
-get_layout_id<Layout>());
+__imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id<Layout>());
 }
 } else if constexpr (std::is_same_v<S, float>) {
 auto dstptr = reinterpret_cast<float *>(&res.wi_marray);
@@ -549,8 +547,8 @@ void joint_matrix_mad_cuda(
 get_layout_pair_id<LayoutA, LayoutB>(), 0);
 }
 } else if constexpr (std::is_same_v<Tm, uint16_t> ||
-std::is_same_v<Tm, sycl::ext::oneapi::experimental::
-bfloat16>) {
+std::is_same_v<
+Tm, sycl::ext::oneapi::experimental::bfloat16>) {
 __mma_bf16_m16n16k16_mma_f32(
 reinterpret_cast<float *>(&D.wi_marray),
 reinterpret_cast<const int32_t *>(&A.wi_marray),
@@ -586,8 +584,8 @@ void joint_matrix_mad_cuda(
 get_layout_pair_id<LayoutA, LayoutB>(), 0);
 }
 } else if constexpr (std::is_same_v<Tm, uint16_t> ||
-std::is_same_v<Tm, sycl::ext::oneapi::experimental::
-bfloat16>) {
+std::is_same_v<
+Tm, sycl::ext::oneapi::experimental::bfloat16>) {
 __mma_bf16_m8n32k16_mma_f32(
 reinterpret_cast<float *>(&D.wi_marray),
 reinterpret_cast<const int32_t *>(&A.wi_marray),
@@ -609,8 +607,8 @@ void joint_matrix_mad_cuda(
 get_layout_pair_id<LayoutA, LayoutB>(), 0);
 }
 } else if constexpr (std::is_same_v<Tm, uint16_t> ||
-std::is_same_v<Tm, sycl::ext::oneapi::experimental::
-bfloat16>) {
+std::is_same_v<
+Tm, sycl::ext::oneapi::experimental::bfloat16>) {
 __mma_bf16_m32n8k16_mma_f32(
 reinterpret_cast<float *>(&D.wi_marray),
 reinterpret_cast<const int32_t *>(&A.wi_marray),
@@ -653,4 +651,3 @@ void joint_matrix_mad_cuda(
 } // namespace ext
 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
 } // namespace sycl
-
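Reader's note (not part of the commit): the load_accumulator_layoutT and joint_matrix_mad_cuda hunks above dispatch to a WMMA builtin at compile time with `if constexpr` on the matrix shape and element type, so clang-format only reflows the one surviving call per specialization. Below is a minimal, self-contained sketch of that dispatch pattern; `ld_c_sketch` and the printed strings are illustrative placeholders, not the real `__imma_*` builtins.

// Sketch only: compile-time shape dispatch in the style of
// load_accumulator_layoutT. Only one branch is instantiated per
// specialization; the others are discarded at compile time.
#include <cstddef>
#include <cstdio>

template <std::size_t NumRows, std::size_t NumCols> void ld_c_sketch() {
  if constexpr (NumRows == 16 && NumCols == 16)
    std::puts("would call __imma_m16n16k16_ld_c");
  else if constexpr (NumRows == 8 && NumCols == 32)
    std::puts("would call __imma_m8n32k16_ld_c");
  else if constexpr (NumRows == 32 && NumCols == 8)
    std::puts("would call __imma_m32n8k16_ld_c");
}

int main() {
  ld_c_sketch<8, 32>(); // only the m8n32k16 branch is instantiated
  return 0;
}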

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 1 addition & 2 deletions
@@ -21,7 +21,7 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
 inline __SYCL_ALWAYS_INLINE void
 joint_matrix_fill(Group sg,
 joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
-const T2& v) {
+const T2 &v) {
 std::ignore = sg;
 #if defined(__SYCL_DEVICE_ONLY__)
 #if defined(__NVPTX__)
@@ -177,4 +177,3 @@ float round_to_tf32(float &a) {
 } // namespace ext
 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
 } // namespace sycl
-
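For context on the joint_matrix_fill signature touched above: the fill value is taken as `const T2 &`, a deduced type separate from the element type T, so callers can pass any value convertible to T without first constructing a T. The sketch below illustrates only that calling pattern; `matrix_like` and `fill_sketch` are hypothetical names, not the extension's implementation.

// Sketch only: a separately deduced value type T2 bound by const reference,
// converted to the element type T inside the fill loop.
#include <array>
#include <cstddef>

template <typename T, std::size_t N> struct matrix_like {
  std::array<T, N> data{};
};

template <typename T, std::size_t N, typename T2>
void fill_sketch(matrix_like<T, N> &m, const T2 &v) {
  for (auto &e : m.data)
    e = static_cast<T>(v); // convert the deduced value type to the element type
}

int main() {
  matrix_like<float, 8> m;
  fill_sketch(m, 2); // int literal binds to `const T2 &` and is converted to float
  return 0;
}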

sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp

Lines changed: 42 additions & 68 deletions
@@ -38,163 +38,137 @@ int main() {
 [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
 sycl::sub_group sg = item.get_sub_group();
 
-joint_matrix<float, use::accumulator, 16, 16>
-sub_c;
-
-joint_matrix<bfloat16, use::a, 16, 16,
-layout::row_major>
-sub_a;
-
-joint_matrix<bfloat16, use::b, 16, 16,
-layout::row_major>
-sub_b;
+joint_matrix<float, use::accumulator, 16, 16> sub_c;
+joint_matrix<bfloat16, use::a, 16, 16, layout::row_major> sub_a;
+joint_matrix<bfloat16, use::b, 16, 16, layout::row_major> sub_b;
 
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::row_major);
+joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+layout::row_major);
 // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
 // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
 sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
 // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16)
-joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::row_major);
+joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+layout::row_major);
 });
 
 cgh.parallel_for<class col_col_m16n16k16>(
 nd_range<2>({1, 32}, {1, 32}),
 [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
 sycl::sub_group sg = item.get_sub_group();
 
-joint_matrix<float, use::accumulator, 16, 16>
-sub_c;
-
-joint_matrix<bfloat16, use::a, 16, 16,
-layout::col_major>
-sub_a;
-
-joint_matrix<bfloat16, use::b, 16, 16,
-layout::col_major>
-sub_b;
+joint_matrix<float, use::accumulator, 16, 16> sub_c;
+joint_matrix<bfloat16, use::a, 16, 16, layout::col_major> sub_a;
+joint_matrix<bfloat16, use::b, 16, 16, layout::col_major> sub_b;
 
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::col_major);
+joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+layout::col_major);
 // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
 // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
 sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
 // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16)
-joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::col_major);
+joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+layout::col_major);
 });
 
 cgh.parallel_for<class row_row_m32n8k16>(
 nd_range<2>({1, 32}, {1, 32}),
 [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
 sycl::sub_group sg = item.get_sub_group();
 
-joint_matrix<float, use::accumulator, 32, 8>
-sub_c;
-
-joint_matrix<bfloat16, use::a, 32, 16,
-layout::row_major>
-sub_a;
-
-joint_matrix<bfloat16, use::b, 16, 8, layout::row_major>
-sub_b;
+joint_matrix<float, use::accumulator, 32, 8> sub_c;
+joint_matrix<bfloat16, use::a, 32, 16, layout::row_major> sub_a;
+joint_matrix<bfloat16, use::b, 16, 8, layout::row_major> sub_b;
 
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::row_major);
+joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+layout::row_major);
 // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
 // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
 sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
 // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::row_major);
+joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+layout::row_major);
 });
 
 cgh.parallel_for<class col_col_m32n8k16>(
 nd_range<2>({1, 32}, {1, 32}),
 [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
 sycl::sub_group sg = item.get_sub_group();
 
-joint_matrix<float, use::accumulator, 32, 8>
-sub_c;
-
-joint_matrix<bfloat16, use::a, 32, 16,
-layout::col_major>
-sub_a;
-
-joint_matrix<bfloat16, use::b, 16, 8, layout::col_major>
-sub_b;
+joint_matrix<float, use::accumulator, 32, 8> sub_c;
+joint_matrix<bfloat16, use::a, 32, 16, layout::col_major> sub_a;
+joint_matrix<bfloat16, use::b, 16, 8, layout::col_major> sub_b;
 
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::col_major);
+joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+layout::col_major);
 // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
 // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
 sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
 // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::col_major);
+joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+layout::col_major);
 });
 
 cgh.parallel_for<class row_row_m8n32k16>(
 nd_range<2>({1, 32}, {1, 32}),
 [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
 sycl::sub_group sg = item.get_sub_group();
 
-joint_matrix<float, use::accumulator, 8, 32>
-sub_c;
-
-joint_matrix<bfloat16, use::a, 8, 16, layout::row_major>
-sub_a;
-
-joint_matrix<bfloat16, use::b, 16, 32,
-layout::row_major>
-sub_b;
+joint_matrix<float, use::accumulator, 8, 32> sub_c;
+joint_matrix<bfloat16, use::a, 8, 16, layout::row_major> sub_a;
+joint_matrix<bfloat16, use::b, 16, 32, layout::row_major> sub_b;
 
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::row_major);
+joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+layout::row_major);
 // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
 // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
 sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
 // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::row_major);
+joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+layout::row_major);
 });
 
 cgh.parallel_for<class col_col_m8n32k16>(
 nd_range<2>({1, 32}, {1, 32}),
 [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
 sycl::sub_group sg = item.get_sub_group();
 
-joint_matrix<float, use::accumulator, 8, 32>
-sub_c;
-
-joint_matrix<bfloat16, use::a, 8, 16, layout::col_major>
-sub_a;
-
-joint_matrix<bfloat16, use::b, 16, 32,
-layout::col_major>
-sub_b;
+joint_matrix<float, use::accumulator, 8, 32> sub_c;
+joint_matrix<bfloat16, use::a, 8, 16, layout::col_major> sub_a;
+joint_matrix<bfloat16, use::b, 16, 32, layout::col_major> sub_b;
 
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::col_major);
+joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+layout::col_major);
 // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
 // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
 joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
 // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
 sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
 // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::col_major);
+joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+layout::col_major);
 });
 });
 