Skip to content

Commit f3fa54a

Browse files
authored
[mlir][Vector] Handle 0-rank case in fold instead of RewriterPattern (#130168)
For vector.extract, the folder always canonicalizes to a vector.extract operation, while the rewrite pattern canonicalizes to a vector.broadcast except in the case of 0-rank vectors. Remove this special casing, and instead handle the 0-rank vector case in the folder.
1 parent dc28e0d commit f3fa54a

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,7 +1675,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16751675
return source;
16761676

16771677
unsigned extractResultRank = getRank(extractOp.getType());
1678-
if (extractResultRank >= broadcastSrcRank)
1678+
if (extractResultRank > broadcastSrcRank)
16791679
return Value();
16801680
// Check that the dimension of the result haven't been broadcasted.
16811681
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2156,13 +2156,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21562156
// folding patterns.
21572157
if (extractResultRank < broadcastSrcRank)
21582158
return failure();
2159+
// For scalar result, the input can only be a rank-0 vector, which will
2160+
// be handled by the folder.
2161+
if (extractResultRank == 0)
2162+
return failure();
21592163

2160-
// Special case if broadcast src is a 0D vector.
2161-
if (extractResultRank == 0) {
2162-
assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2163-
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2164-
return success();
2165-
}
21662164
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
21672165
extractOp, extractOp.getType(), source);
21682166
return success();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
736736

737737
// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
738738
// CHECK-SAME: %[[A:.*]]: vector<f32>
739-
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
739+
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
740740
// CHECK: return %[[B]] : f32
741741
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
742742
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
28342834
%3 = vector.extract %2[] : f32 from vector<f32>
28352835

28362836
// Broadcast 0D to 3D and extract scalar.
2837-
// CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
2837+
// CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
28382838
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
28392839
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
28402840

0 commit comments

Comments
 (0)