@@ -38,163 +38,137 @@ int main() {
         [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
-          joint_matrix<float, use::accumulator, 16, 16>
-              sub_c;
-
-          joint_matrix<bfloat16, use::a, 16, 16,
-                       layout::row_major>
-              sub_a;
-
-          joint_matrix<bfloat16, use::b, 16, 16,
-                       layout::row_major>
-              sub_b;
+          joint_matrix<float, use::accumulator, 16, 16> sub_c;
+          joint_matrix<bfloat16, use::a, 16, 16, layout::row_major> sub_a;
+          joint_matrix<bfloat16, use::b, 16, 16, layout::row_major> sub_b;
 
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::row_major);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
           // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
           // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
           sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::row_major);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
         });
 
     cgh.parallel_for<class col_col_m16n16k16>(
         nd_range<2>({1, 32}, {1, 32}),
         [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
-          joint_matrix<float, use::accumulator, 16, 16>
-              sub_c;
-
-          joint_matrix<bfloat16, use::a, 16, 16,
-                       layout::col_major>
-              sub_a;
-
-          joint_matrix<bfloat16, use::b, 16, 16,
-                       layout::col_major>
-              sub_b;
+          joint_matrix<float, use::accumulator, 16, 16> sub_c;
+          joint_matrix<bfloat16, use::a, 16, 16, layout::col_major> sub_a;
+          joint_matrix<bfloat16, use::b, 16, 16, layout::col_major> sub_b;
 
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::col_major);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
           // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
           // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
           sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::col_major);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
         });
 
     cgh.parallel_for<class row_row_m32n8k16>(
         nd_range<2>({1, 32}, {1, 32}),
         [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
-          joint_matrix<float, use::accumulator, 32, 8>
-              sub_c;
-
-          joint_matrix<bfloat16, use::a, 32, 16,
-                       layout::row_major>
-              sub_a;
-
-          joint_matrix<bfloat16, use::b, 16, 8, layout::row_major>
-              sub_b;
+          joint_matrix<float, use::accumulator, 32, 8> sub_c;
+          joint_matrix<bfloat16, use::a, 32, 16, layout::row_major> sub_a;
+          joint_matrix<bfloat16, use::b, 16, 8, layout::row_major> sub_b;
 
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::row_major);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
           // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
           // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
           sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::row_major);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
         });
 
     cgh.parallel_for<class col_col_m32n8k16>(
         nd_range<2>({1, 32}, {1, 32}),
         [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
-          joint_matrix<float, use::accumulator, 32, 8>
-              sub_c;
-
-          joint_matrix<bfloat16, use::a, 32, 16,
-                       layout::col_major>
-              sub_a;
-
-          joint_matrix<bfloat16, use::b, 16, 8, layout::col_major>
-              sub_b;
+          joint_matrix<float, use::accumulator, 32, 8> sub_c;
+          joint_matrix<bfloat16, use::a, 32, 16, layout::col_major> sub_a;
+          joint_matrix<bfloat16, use::b, 16, 8, layout::col_major> sub_b;
 
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::col_major);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
           // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
           // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
           sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::col_major);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
         });
 
     cgh.parallel_for<class row_row_m8n32k16>(
         nd_range<2>({1, 32}, {1, 32}),
         [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
-          joint_matrix<float, use::accumulator, 8, 32>
-              sub_c;
-
-          joint_matrix<bfloat16, use::a, 8, 16, layout::row_major>
-              sub_a;
-
-          joint_matrix<bfloat16, use::b, 16, 32,
-                       layout::row_major>
-              sub_b;
+          joint_matrix<float, use::accumulator, 8, 32> sub_c;
+          joint_matrix<bfloat16, use::a, 8, 16, layout::row_major> sub_a;
+          joint_matrix<bfloat16, use::b, 16, 32, layout::row_major> sub_b;
 
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::row_major);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::row_major);
           // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
           // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
           sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::row_major);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::row_major);
         });
 
     cgh.parallel_for<class col_col_m8n32k16>(
         nd_range<2>({1, 32}, {1, 32}),
         [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
           sycl::sub_group sg = item.get_sub_group();
 
-          joint_matrix<float, use::accumulator, 8, 32>
-              sub_c;
-
-          joint_matrix<bfloat16, use::a, 8, 16, layout::col_major>
-              sub_a;
-
-          joint_matrix<bfloat16, use::b, 16, 32,
-                       layout::col_major>
-              sub_b;
+          joint_matrix<float, use::accumulator, 8, 32> sub_c;
+          joint_matrix<bfloat16, use::a, 8, 16, layout::col_major> sub_a;
+          joint_matrix<bfloat16, use::b, 16, 32, layout::col_major> sub_b;
 
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
-          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, layout::col_major);
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride,
+                            layout::col_major);
           // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
           // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
           joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
           // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
           sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
           // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
-          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, layout::col_major);
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride,
+                             layout::col_major);
         });
   });
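For context, every kernel in this diff exercises the same pattern: one sub-group performs a warp-level matrix multiply-accumulate (D = A * B + C) through the sycl_ext_oneapi_matrix extension, and the CHECK lines verify that each load, mad, and store lowers to the corresponding @llvm.nvvm.wmma.* intrinsic on the CUDA backend. Below is a minimal, self-contained sketch of the row-major m16n16k16 bf16 case, following the joint_matrix API revision used in this test. The host-side setup (buffers, the accA/accB/accC/accD accessors, stride) is not part of this hunk, so that code here is illustrative only; it also assumes a CUDA backend device with sm_80 or newer, and the bfloat16 namespace has moved between DPC++ releases.

#include <cassert>
#include <vector>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
// Assumption: bfloat16 lived under the experimental namespace in the
// toolchain this test targets; newer releases expose sycl::ext::oneapi::bfloat16.
using sycl::ext::oneapi::experimental::bfloat16;

int main() {
  constexpr size_t M = 16, N = 16, K = 16;
  queue q;

  std::vector<bfloat16> A(M * K, bfloat16(1.0f));
  std::vector<bfloat16> B(K * N, bfloat16(1.0f));
  std::vector<float> C(M * N, 0.0f), D(M * N, 0.0f);

  {
    buffer<bfloat16, 1> bufA(A.data(), range<1>(M * K));
    buffer<bfloat16, 1> bufB(B.data(), range<1>(K * N));
    buffer<float, 1> bufC(C.data(), range<1>(M * N));
    buffer<float, 1> bufD(D.data(), range<1>(M * N));

    q.submit([&](handler &cgh) {
      accessor accA{bufA, cgh, read_only};
      accessor accB{bufB, cgh, read_only};
      accessor accC{bufC, cgh, read_only};
      accessor accD{bufD, cgh, write_only};

      // One work-group of 32 work-items, i.e. exactly one full sub-group,
      // matching the test's nd_range<2>({1, 32}, {1, 32}).
      cgh.parallel_for<class wmma_sketch>(
          nd_range<2>({1, 32}, {1, 32}), [=](nd_item<2> item) {
            sycl::sub_group sg = item.get_sub_group();

            joint_matrix<float, use::accumulator, M, N> sub_c;
            joint_matrix<bfloat16, use::a, M, K, layout::row_major> sub_a;
            joint_matrix<bfloat16, use::b, K, N, layout::row_major> sub_b;

            // The accumulator's layout is a runtime argument to load/store;
            // for A and B it is baked into the joint_matrix type, so the
            // load calls take only a stride (row length here).
            joint_matrix_load(sg, sub_c, accC.get_pointer(), N,
                              layout::row_major);
            joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
            joint_matrix_load(sg, sub_b, accB.get_pointer(), N);

            // Lowers to a single @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16
            // call on the CUDA backend: sub_c = sub_a * sub_b + sub_c.
            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);

            joint_matrix_store(sg, sub_c, accD.get_pointer(), N,
                               layout::row_major);
          });
    });
  } // buffer destruction writes the results back to the host vectors

  // With all-ones inputs and a zero accumulator, each D element is K = 16.
  assert(D[0] == 16.0f);
  return 0;
}

The kernel name (wmma_sketch) and the all-ones test data are made up for the example; everything else mirrors the calls shown in the diff. The other five kernels differ only in the tile shape (m32n8k16, m8n32k16) and in using layout::col_major, which selects the .col variants of the same intrinsics.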