Skip to content

Commit 204eb70

Browse files
authored
[mlir][Vector] Canonicalize empty vector.mask into arith.select (#140976)
This PR adds a missing canonicalization for empty `vector.mask` ops with a passthru value:
```
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
```
becomes:
```
%0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>
```
1 parent 1bdec97 commit 204eb70

File tree

3 files changed

+66
-8
lines changed

3 files changed

+66
-8
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
25592559
Location loc);
25602560
}];
25612561

2562+
let hasCanonicalizer = 1;
25622563
let hasFolder = 1;
25632564
let hasCustomAssemblyFormat = 1;
25642565
let hasVerifier = 1;

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6661,6 +6661,9 @@ LogicalResult MaskOp::verify() {
66616661
///
66626662
/// %0 = user_op %a : vector<8xf32>
66636663
///
6664+
/// Empty `vector.mask` ops with a passthru operand are handled by the
6665+
/// canonicalizer, as that rewrite requires creating new operations.
6666+
66646667
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
66656668
SmallVectorImpl<OpFoldResult> &results) {
66666669
if (!maskOp.isEmpty() || maskOp.hasPassthru())
@@ -6696,6 +6699,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66966699
return success();
66976700
}
66986701

6702+
/// Canonialize empty `vector.mask` operations that can't be handled in
6703+
/// `VectorMask::fold` as they require creating new operations.
6704+
///
6705+
/// Example 1: Empty `vector.mask` with passthru operand.
6706+
///
6707+
/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
6708+
/// vector<8xi1> -> vector<8xf32>
6709+
///
6710+
/// becomes:
6711+
///
6712+
/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
6713+
///
6714+
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
6715+
using OpRewritePattern::OpRewritePattern;
6716+
6717+
LogicalResult matchAndRewrite(MaskOp maskOp,
6718+
PatternRewriter &rewriter) const override {
6719+
if (!maskOp.isEmpty())
6720+
return failure();
6721+
6722+
if (!maskOp.hasPassthru())
6723+
return failure();
6724+
6725+
Block *block = maskOp.getMaskBlock();
6726+
auto terminator = cast<vector::YieldOp>(block->front());
6727+
assert(terminator.getNumOperands() == 1 &&
6728+
"expected one result when passthru is provided");
6729+
6730+
rewriter.replaceOpWithNewOp<arith::SelectOp>(
6731+
maskOp, maskOp.getResultTypes(), maskOp.getMask(),
6732+
terminator.getOperand(0), maskOp.getPassthru());
6733+
6734+
return success();
6735+
}
6736+
};
6737+
6738+
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6739+
MLIRContext *context) {
6740+
results.add<CanonializeEmptyMaskOp>(context);
6741+
}
6742+
66996743
// MaskingOpInterface definitions.
67006744

67016745
/// Returns the operation masked by this 'vector.mask'.

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
719719
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
720720
// CHECK-SAME: %[[A:.*]]: f32
721721
// CHECK: return %[[A]] : f32
722-
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
722+
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
723723
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
724724
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
725725
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
731731
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
732732
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
733733
// CHECK: return %[[A]] : vector<4xf32>
734-
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
734+
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
735735
%idx0 : index, %idx1 : index) -> vector<4xf32> {
736736
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
737737
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
744744
// CHECK-SAME: %[[A:.*]]: vector<f32>
745745
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
746746
// CHECK: return %[[B]] : f32
747-
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
747+
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
748748
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
749749
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
750750
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
780780
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
781781
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
782782
// CHECK: return %[[R]] : f32
783-
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
783+
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
784784
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
785785
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
786786
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
795795
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
796796
// CHECK: return %[[B]] : vector<4xf32>
797797
// rank(extract_output) < rank(broadcast_input)
798-
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
798+
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
799799
%idx0 : index, %idx1 : index) -> vector<4xf32> {
800800
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
801801
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
808808
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
809809
// CHECK: return %[[B]] : vector<4xf32>
810810
// rank(extract_output) > rank(broadcast_input)
811-
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
811+
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
812812
-> vector<4xf32> {
813813
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
814814
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
822822
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
823823
// CHECK: return %[[R]] : vector<8xf32>
824824
// rank(extract_output) == rank(broadcast_input)
825-
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
825+
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
826826
-> vector<8xf32> {
827827
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
828828
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
11691169
return %broadcast : vector<4x6xi8>
11701170
}
11711171

1172-
// -----
1172+
// -----
11731173

11741174
// CHECK-LABEL: broadcast_splat_constant
11751175
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
27562756

27572757
// -----
27582758

2759+
// An empty `vector.mask` with a passthru operand is expected to canonicalize
// into an `arith.select` between the yielded value and the passthru value.
// CHECK-LABEL: func @empty_vector_mask_with_passthru
2760+
// CHECK-SAME: %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
2761+
func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
2762+
%passthru : vector<8xf32>) -> vector<8xf32> {
2763+
// CHECK-NOT: vector.mask
2764+
// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
2765+
// CHECK: return %[[SEL]] : vector<8xf32>
2766+
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
2767+
return %0 : vector<8xf32>
2768+
}
2769+
2770+
// -----
2771+
27592772
// CHECK-LABEL: func @all_true_vector_mask
27602773
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
27612774
func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {

0 commit comments

Comments
 (0)