Skip to content

Commit 3b001db

Browse files
authored
[mlir][vector] Update tests for xfer permutation lowering (1/N) (#123076)
1. Remove `%c0 = arith.constant 0 : index` from test functions. This extra Op is not needed (the index can be passed as an argument), so this is just noise. 2. Replace `%cst_0` with `%pad` to communicate what the underlying SSA value is intended for. 3. Unify some comments.
1 parent b92cc78 commit 3b001db

File tree

1 file changed

+77
-73
lines changed

1 file changed

+77
-73
lines changed

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 77 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,43 @@
1313

1414
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map
1515
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
16-
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
16+
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>
1717
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
1818
// CHECK: vector.transfer_write
1919
// CHECK-NOT: permutation_map
2020
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
2121
func.func @xfer_write_transposing_permutation_map(
2222
%vec: vector<4x8xi16>,
23-
%mem: memref<2x2x8x4xi16>) {
23+
%mem: memref<2x2x8x4xi16>,
24+
%idx: index) {
2425

25-
%c0 = arith.constant 0 : index
26-
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
26+
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
2727
in_bounds = [true, true],
2828
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
2929
} : vector<4x8xi16>, memref<2x2x8x4xi16>
3030

3131
return
3232
}
3333

34-
// Even with out-of-bounds, it is safe to apply this pattern
34+
// Even with out-of-bounds accesses, it is safe to apply this pattern
35+
3536
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
3637
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
37-
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
38-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
38+
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>,
39+
// CHECK-SAME: %[[IDX:.*]]: index) {
3940
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
4041
// Expect the in_bounds attribute to be preserved. Since we don't print it when
4142
// all flags are "false", it should not appear in the output.
4243
// CHECK-NOT: in_bounds
4344
// CHECK: vector.transfer_write
4445
// CHECK-NOT: permutation_map
45-
// CHECK-SAME: %[[TR]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
46+
// CHECK-SAME: %[[TR]], %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
4647
func.func @xfer_write_transposing_permutation_map_out_of_bounds(
4748
%vec: vector<4x8xi16>,
48-
%mem: memref<2x2x?x?xi16>) {
49+
%mem: memref<2x2x?x?xi16>,
50+
%idx: index) {
4951

50-
%c0 = arith.constant 0 : index
51-
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
52+
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
5253
in_bounds = [false, false],
5354
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
5455
} : vector<4x8xi16>, memref<2x2x?x?xi16>
@@ -59,18 +60,19 @@ func.func @xfer_write_transposing_permutation_map_out_of_bounds(
5960
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
6061
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
6162
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
62-
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) {
63+
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>
6364
// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
6465
// CHECK: vector.transfer_write
6566
// CHECK-NOT: permutation_map
6667
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
6768
func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
6869
%vec: vector<4x[8]xi16>,
6970
%mem: memref<2x2x?x4xi16>,
70-
%mask: vector<[8]x4xi1>) {
71+
%mask: vector<[8]x4xi1>,
72+
%idx: index) {
7173

7274
%c0 = arith.constant 0 : index
73-
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0], %mask {
75+
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx], %mask {
7476
in_bounds = [true, true],
7577
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
7678
} : vector<4x[8]xi16>, memref<2x2x?x4xi16>
@@ -79,16 +81,18 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
7981
}
8082

8183
// Masked version is not supported
84+
8285
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked
8386
// CHECK-NOT: vector.transpose
8487
func.func @xfer_write_transposing_permutation_map_masked(
8588
%vec: vector<4x8xi16>,
8689
%mem: memref<2x2x8x4xi16>,
87-
%mask: vector<8x4xi1>) {
90+
%mask: vector<8x4xi1>,
91+
%idx: index) {
8892

8993
%c0 = arith.constant 0 : index
9094
vector.mask %mask {
91-
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
95+
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
9296
in_bounds = [true, true],
9397
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
9498
} : vector<4x8xi16>, memref<2x2x8x4xi16>
@@ -128,7 +132,8 @@ func.func @xfer_write_non_transposing_permutation_map(
128132
return
129133
}
130134

131-
// Even with out-of-bounds, it is safe to apply this pattern
135+
// Even with out-of-bounds accesses, it is safe to apply this pattern
136+
132137
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
133138
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
134139
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
@@ -157,8 +162,7 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
157162
// CHECK: func.func @permutation_with_mask_xfer_write_scalable(
158163
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
159164
// CHECK-SAME: %[[MEM:.*]]: memref<1x4x?x1xi16>,
160-
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
161-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
165+
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
162166
// CHECK: %[[BC_1:.*]] = vector.broadcast %[[VEC]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
163167
// CHECK: %[[BC_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
164168
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BC_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
@@ -167,18 +171,19 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
167171
func.func @permutation_with_mask_xfer_write_scalable(
168172
%vec: vector<4x[8]xi16>,
169173
%mem: memref<1x4x?x1xi16>,
170-
%mask: vector<4x[8]xi1>){
174+
%mask: vector<4x[8]xi1>,
175+
%idx: index){
171176

172-
%c0 = arith.constant 0 : index
173-
vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0], %mask {
177+
vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx], %mask {
174178
in_bounds = [true, true],
175179
permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
176180
} : vector<4x[8]xi16>, memref<1x4x?x1xi16>
177181

178182
return
179183
}
180184

181-
// transfer_write in MaskOp case not supported.
185+
// Masked version is not supported
186+
182187
// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
183188
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
184189
// CHECK-SAME: %[[VEC:.*]]: vector<16xf32>,
@@ -204,18 +209,19 @@ func.func @masked_permutation_xfer_write_fixed_width(
204209
// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
205210
// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
206211
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>,
207-
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
212+
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
208213
// CHECK-SAME: -> tensor<?x?x?x?xf32> {
209214
// CHECK-NOT: vector.transpose
210215
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
211216
func.func @masked_permutation_xfer_write_scalable(
212217
%vec: vector<4x[8]xi16>,
213218
%dest: tensor<?x?x?x?xf32>,
214-
%mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
219+
%mask: vector<4x[8]xi1>,
220+
%idx: index) -> tensor<?x?x?x?xf32> {
215221

216222
%c0 = arith.constant 0 : index
217223
%res = vector.mask %mask {
218-
vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0] {
224+
vector.transfer_write %vec, %dest[%idx, %idx, %idx, %idx] {
219225
in_bounds = [true, true],
220226
permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
221227
} : vector<4x[8]xi16>, tensor<?x?x?x?xf32>
@@ -224,22 +230,23 @@ func.func @masked_permutation_xfer_write_scalable(
224230
return %res : tensor<?x?x?x?xf32>
225231
}
226232

227-
// transfer_write in MaskOp case not supported.
233+
// Masked version is not supported
234+
228235
// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
229236
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>
230237
// CHECK-SAME: %[[VEC:.*]]: vector<14x8x16xf32>
231-
// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
238+
// CHECK-SAME: %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
232239
// CHECK-NOT: vector.broadcast
233240
// CHECK: vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
234241
func.func @masked_non_permutation_xfer_write_fixed_width(
235242
%dest : tensor<?x?x?x?xf32>,
236243
%vec : vector<14x8x16xf32>,
237-
%dim : index) -> tensor<?x?x?x?xf32> {
244+
%dim : index,
245+
%idx: index) -> tensor<?x?x?x?xf32> {
238246

239-
%c0 = arith.constant 0 : index
240247
%mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
241248
%res = vector.mask %mask {
242-
vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0] {
249+
vector.transfer_write %vec, %dest[%idx, %idx, %idx, %idx] {
243250
in_bounds = [false, false, true],
244251
permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
245252
} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
@@ -259,25 +266,23 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
259266

260267
// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_fixed_width(
261268
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
262-
// CHECK-SAME: %[[IDX_1:.*]]: index,
263-
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x4x2xf32> {
264-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
269+
// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
265270
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
266-
// CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x4xi1>
267-
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
271+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x4xi1>
272+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
268273
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
269274
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
270275
// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
271276
func.func @permutation_with_mask_xfer_read_fixed_width(
272277
%mem: memref<?x?xf32>,
273278
%dim_1: index,
274-
%dim_2: index) -> (vector<8x4x2xf32>) {
279+
%dim_2: index,
280+
%idx: index) -> (vector<8x4x2xf32>) {
275281

276-
%c0 = arith.constant 0 : index
277-
%cst_0 = arith.constant 0.000000e+00 : f32
282+
%pad = arith.constant 0.000000e+00 : f32
278283

279284
%mask = vector.create_mask %dim_2, %dim_1 : vector<2x4xi1>
280-
%res = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask {
285+
%res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
281286
in_bounds = [true, true, true],
282287
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
283288
} : memref<?x?xf32>, vector<8x4x2xf32>
@@ -287,46 +292,45 @@ func.func @permutation_with_mask_xfer_read_fixed_width(
287292

288293
// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable(
289294
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
290-
// CHECK-SAME: %[[IDX_1:.*]]: index,
291-
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
292-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
293-
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
294-
// CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x[4]xi1>
295-
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
295+
// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
296+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
297+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
298+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
296299
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
297300
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
298301
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
299302
func.func @permutation_with_mask_xfer_read_scalable(
300303
%mem: memref<?x?xf32>,
301304
%dim_1: index,
302-
%dim_2: index) -> (vector<8x[4]x2xf32>) {
305+
%dim_2: index,
306+
%idx: index) -> (vector<8x[4]x2xf32>) {
303307

304-
%c0 = arith.constant 0 : index
305-
%cst_0 = arith.constant 0.000000e+00 : f32
308+
%pad = arith.constant 0.000000e+00 : f32
306309

307310
%mask = vector.create_mask %dim_2, %dim_1 : vector<2x[4]xi1>
308-
%res = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask {
311+
%res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
309312
in_bounds = [true, true, true],
310313
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
311314
} : memref<?x?xf32>, vector<8x[4]x2xf32>
312315

313316
return %res : vector<8x[4]x2xf32>
314317
}
315318

316-
// transfer_read in MaskOp case not supported.
319+
// Masked version is not supported
320+
317321
// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
318322
// CHECK-SAME: %[[DEST:.*]]: tensor<?x1xf32>,
319323
// CHECK-SAME: %[[MASK:.*]]: vector<4x1xi1>
320324
// CHECK-NOT: vector.transpose
321325
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
322326
func.func @masked_permutation_xfer_read_fixed_width(
323327
%dest: tensor<?x1xf32>,
324-
%mask : vector<4x1xi1>) {
328+
%mask : vector<4x1xi1>,
329+
%idx: index) {
325330

326-
%cst = arith.constant 0.000000e+00 : f32
327-
%c0 = arith.constant 0 : index
331+
%pad = arith.constant 0.000000e+00 : f32
328332
%3 = vector.mask %mask {
329-
vector.transfer_read %dest[%c0, %c0], %cst {
333+
vector.transfer_read %dest[%idx, %idx], %pad {
330334
permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>
331335
} : tensor<?x1xf32>, vector<1x4x4xf32>
332336
} : vector<4x1xi1> -> vector<1x4x4xf32>
@@ -337,18 +341,18 @@ func.func @masked_permutation_xfer_read_fixed_width(
337341

338342
// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
339343
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
340-
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
344+
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>
341345
// CHECK-NOT: vector.transpose
342346
// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
343347
func.func @masked_permutation_xfer_read_scalable(
344348
%dest: tensor<?x?xf32>,
345-
%mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
349+
%mask : vector<2x[4]xi1>,
350+
%idx: index) -> vector<8x[4]x2xf32> {
346351

347-
%c0 = arith.constant 0 : index
348-
%cst_0 = arith.constant 0.000000e+00 : f32
352+
%pad = arith.constant 0.000000e+00 : f32
349353

350354
%res = vector.mask %mask {
351-
vector.transfer_read %dest[%c0, %c0], %cst_0 {
355+
vector.transfer_read %dest[%idx, %idx], %pad {
352356
in_bounds = [true, true, true],
353357
permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
354358
} : tensor<?x?xf32>, vector<8x[4]x2xf32>
@@ -377,41 +381,41 @@ module attributes {transform.with_named_sequence} {
377381

378382
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
379383
// CHECK: func.func @transfer_read_reduce_rank_scalable(
380-
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
381-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
382-
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
384+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
385+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
383386
// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
384387
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
385388
func.func @transfer_read_reduce_rank_scalable(
386-
%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
389+
%mem: memref<?x?x?x?xf32>, %idx: index) -> vector<8x[4]x2x3xf32> {
387390

388-
%c0 = arith.constant 0 : index
389-
%cst_0 = arith.constant 0.000000e+00 : f32
391+
%pad = arith.constant 0.000000e+00 : f32
390392

391-
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
393+
%res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
392394
in_bounds = [true, true, true, true],
393395
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
394396
} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
395397

396398
return %res : vector<8x[4]x2x3xf32>
397399
}
398400

399-
// Masked case not supported.
401+
// Masked version is not supported
402+
400403
// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
401404
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
402-
// CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
405+
// CHECK-SAME: %[[DIM:.*]]: index,
406+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
403407
// CHECK-NOT: vector.broadcast
404408
// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
405409
func.func @masked_transfer_read_reduce_rank(
406410
%mem: memref<?x?x?x?xf32>,
407-
%dim: index) -> vector<8x[4]x2x3xf32> {
411+
%dim: index,
412+
%idx: index) -> vector<8x[4]x2x3xf32> {
408413

409-
%c0 = arith.constant 0 : index
410-
%cst_0 = arith.constant 0.000000e+00 : f32
414+
%pad = arith.constant 0.000000e+00 : f32
411415
%mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
412416

413417
%res = vector.mask %mask {
414-
vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
418+
vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
415419
in_bounds = [true, true, true, true],
416420
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
417421
} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>

0 commit comments

Comments
 (0)