@@ -165,10 +165,6 @@ builtin.module {
     %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0.000000e+00 : f32
-    // CHECK: %[[mask:.*]] = vector.create_mask
-    // CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
-    // CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
-    // CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>

     // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
     // CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
@@ -199,3 +195,64 @@ builtin.module {
     transform.apply_cse to %top_level_func_2 : !transform.any_op
   }
 }
+
+// -----
+
+// 3D vector.transfer_read with a mask.
+builtin.module {
+  // CHECK-LABEL: @read_3d_with_mask(
+  // CHECK-SAME: %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index, %[[a:.*]]: memref<1024x1024x1024xf32>
+  func.func @read_3d_with_mask(%sz0: index, %sz1: index, %sz2: index, %a: memref<1024x1024x1024xf32>) {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[c1:.*]] = arith.constant 1 : index
+    // CHECK: %[[c2:.*]] = arith.constant 2 : index
+    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32
+
+    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
+    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c0]], %[[sz1]]
+    // CHECK: %[[cond0:.*]] = arith.andi %[[cmpi1]], %[[cmpi0]]
+    // CHECK: %[[s0:.*]] = arith.select %[[cond0]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
+
+    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c1]], %[[sz1]]
+    // CHECK: %[[cond1:.*]] = arith.andi %[[cmpi2]], %[[cmpi0]]
+    // CHECK: %[[s1:.*]] = arith.select %[[cond1]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
+
+    // CHECK: %[[cmpi3:.*]] = arith.cmpi slt, %[[c2]], %[[sz1]]
+    // CHECK: %[[cond2:.*]] = arith.andi %[[cmpi3]], %[[cmpi0]]
+    // CHECK: %[[s2:.*]] = arith.select %[[cond2]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
+
+    // CHECK: %[[cmpi4:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
+    // CHECK: %[[cond3:.*]] = arith.andi %[[cmpi1]], %[[cmpi4]]
+    // CHECK: %[[s3:.*]] = arith.select %[[cond3]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s3]] {bypassL1}
+
+    // CHECK: %[[cond4:.*]] = arith.andi %[[cmpi2]], %[[cmpi4]]
+    // CHECK: %[[s4:.*]] = arith.select %[[cond4]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s4]] {bypassL1}
+
+    // CHECK: %[[cond5:.*]] = arith.andi %[[cmpi3]], %[[cmpi4]]
+    // CHECK: %[[s5:.*]] = arith.select %[[cond5]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s5]] {bypassL1}
+    %mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
+    %1 = vector.transfer_read %a[%c0, %c0, %c0], %cst_0, %mask {in_bounds = [true, true, true]} : memref<1024x1024x1024xf32>, vector<2x3x4xf32>
+    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%variant_op: !transform.any_op):
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %top_level_func {
+      transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+    } : !transform.any_op
+    transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
+    %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_cse to %top_level_func_2 : !transform.any_op
+  }
+}