Commit 8cd8b50

[mlir][Vector] Move mask materialization patterns to greedy rewrite (llvm#119973)
The mask materialization patterns used during the `VectorToLLVM` conversion are rewrite patterns. They should run as part of the greedy pattern rewrite, not the dialect conversion. (Rewrite patterns and conversion patterns are not generally compatible.) The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.
1 parent a1f5fe8 commit 8cd8b50
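
To make the motivation concrete, here is a minimal sketch of the two drivers the commit message contrasts: ordinary rewrite patterns (such as the mask materialization patterns) are applied by the greedy pattern rewrite driver, while conversion patterns are applied by the dialect conversion driver against a conversion target and type converter. This is not the pass's actual code (that diff follows below); the helper and its name lowerVectorOpsSketch are hypothetical, and the header paths and signatures are the upstream MLIR ones around the time of this commit and may differ in other versions.

// Hedged sketch only: illustrates greedy rewrite vs. dialect conversion.
// The helper is hypothetical; the populate*/apply* APIs are upstream MLIR.
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult lowerVectorOpsSketch(Operation *op,
                                          bool force32BitVectorIndices) {
  MLIRContext *ctx = op->getContext();

  // Phase 1: rewrite patterns, now including mask materialization, run under
  // the greedy driver, which also applies folding and DCE.
  {
    RewritePatternSet patterns(ctx);
    populateVectorMaskMaterializationPatterns(patterns,
                                              force32BitVectorIndices);
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      return failure();
  }

  // Phase 2: conversion patterns run under the dialect conversion driver,
  // against an explicit conversion target and type converter.
  LLVMTypeConverter converter(ctx);
  RewritePatternSet patterns(ctx);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  LLVMConversionTarget target(*ctx);
  return applyPartialConversion(op, target, std::move(patterns));
}

The actual change in ConvertVectorToLLVMPass.cpp below does exactly this split: it moves the populateVectorMaskMaterializationPatterns call from the conversion phase into the greedy-rewrite phase.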

4 files changed (+44 / -51 lines)

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 4 additions & 3 deletions
@@ -61,8 +61,8 @@ struct ConvertVectorToLLVMPass
 } // namespace
 
 void ConvertVectorToLLVMPass::runOnOperation() {
-  // Perform progressive lowering of operations on slices and
-  // all contraction operations. Also applies folding and DCE.
+  // Perform progressive lowering of operations on slices and all contraction
+  // operations. Also materializes masks, applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -76,14 +76,15 @@ void ConvertVectorToLLVMPass::runOnOperation() {
        VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
+    populateVectorMaskMaterializationPatterns(patterns,
+                                              force32BitVectorIndices);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
   // Convert to the LLVM IR dialect.
   LowerToLLVMOptions options(&getContext());
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
-  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(

mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 // CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
 // CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
 // CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi32>
-// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi32>
 // CMP32: return %[[T4]] : vector<11xi1>
 
 // CMP64-LABEL: @genbool_var_1d(
@@ -16,7 +16,7 @@
 // CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
 // CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
 // CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
-// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
 // CMP64: return %[[T4]] : vector<11xi1>
 
 func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
@@ -3097,7 +3097,7 @@ func.func @create_mask_0d(%num_elems : index) -> vector<i1> {
 // CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK: %[[BOUNDS:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK: %[[BOUNDS_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOUNDS]] : vector<1xi32> to vector<i32>
-// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS_CAST]] : vector<i32>
+// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS_CAST]], %[[INDICES]] : vector<i32>
 // CHECK: return %[[RESULT]] : vector<i1>
 
 // -----
@@ -3113,7 +3113,7 @@ func.func @create_mask_1d(%num_elems : index) -> vector<4xi1> {
 // CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK: %[[BOUNDS_INSERT:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK: %[[BOUNDS:.*]] = llvm.shufflevector %[[BOUNDS_INSERT]]
-// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS]] : vector<4xi32>
+// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS]], %[[INDICES]] : vector<4xi32>
 // CHECK: return %[[RESULT]] : vector<4xi1>
 
 // -----

mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir

Lines changed: 36 additions & 44 deletions
@@ -14,30 +14,28 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
 // CHECK-LABEL: func @transfer_read_write_1d
 // CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
-// CHECK: %[[C7:.*]] = arith.constant 7.0
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
-// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
+// 1. Create pass-through vector.
+// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<17xf32>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = arith.constant dense
+// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense
 // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 3. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
+//
+// 4. Create bound vector to compute in-bound mask:
 // [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] :
 // CMP32-SAME: index to i32
 // CMP64-SAME: index to i64
 // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 // CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]>
+// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 // CMP64-SAME: : vector<17xi64>
 //
-// 4. Create pass-through vector.
-// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
-//
 // 5. Bitcast to vector form.
 // CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -48,28 +46,23 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
 // CHECK-SAME: -> vector<17xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 // CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex_b:.*]] = arith.constant dense
-// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
+// 2. Create bound vector to compute in-bound mask:
 // [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]]
 // CMP32-SAME: index to i32
 // CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
 // CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
-// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
-// CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]>
+// CHECK: %[[mask_b:.*]] = arith.cmpi sgt, %[[boundVect_b]],
+// CHECK-SAME: %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 //
-// 4. Bitcast to vector form.
+// 3. Bitcast to vector form.
 // CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
 //
-// 5. Rewrite as a masked write.
+// 4. Rewrite as a masked write.
 // CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
 // CHECK-SAME: {alignment = 4 : i32} :
 // CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
@@ -87,27 +80,25 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 // CHECK-LABEL: func @transfer_read_write_1d_scalable
 // CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
-// CHECK: %[[C7:.*]] = arith.constant 7.0
+// 1. Create pass-through vector.
+// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<[17]xf32>
 //
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// 2. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 // CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 3. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 // CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 4. Create bound vector to compute in-bound mask:
 // [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 // CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
 // CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
 // CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]>
 //
-// 4. Create pass-through vector.
-// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
-//
 // 5. Bitcast to vector form.
 // CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -118,8 +109,7 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 // CHECK-SAME: -> vector<[17]xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 // CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
@@ -197,23 +187,23 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
 }
 // CHECK-LABEL: func @transfer_read_2d_to_1d
 // CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
+//
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
 //
 // Compute the in-bound index (dim - offset)
 // CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
 //
-// Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = arith.constant dense
-// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
-//
 // Create bound vector to compute in-bound mask:
 // [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 // CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 // CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 // CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]]
 
 func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
   %f7 = arith.constant 7.0: f32
@@ -255,12 +245,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %bas
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
 // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
 //
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -280,12 +271,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32,
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
 // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
 //
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -330,10 +322,10 @@ func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index)
 
 // CHECK-LABEL: func @transfer_read_write_1d_mask
 // CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
-// CHECK: %[[cmpi:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi:.*]] = arith.cmpi sgt
 // CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
 // CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
-// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi_1:.*]] = arith.cmpi sgt
 // CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
 // CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
 // CHECK: return %[[r]]
