Skip to content

Commit 412fb21

Browse files
committed
[mlir][Vector] Handle 0-rank case in fold instead of RewriterPattern
1 parent 24a8e18 commit 412fb21

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16781678
return source;
16791679

16801680
unsigned extractResultRank = getRank(extractOp.getType());
1681-
if (extractResultRank >= broadcastSrcRank)
1681+
if (extractResultRank > broadcastSrcRank)
16821682
return Value();
16831683
// Check that the dimension of the result haven't been broadcasted.
16841684
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2159,13 +2159,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21592159
// folding patterns.
21602160
if (extractResultRank < broadcastSrcRank)
21612161
return failure();
2162+
// For scalar result, the input can only be a rank-0 vector, which will
2163+
// be handled by the folder.
2164+
if (extractResultRank == 0)
2165+
return failure();
21622166

2163-
// Special case if broadcast src is a 0D vector.
2164-
if (extractResultRank == 0) {
2165-
assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2166-
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2167-
return success();
2168-
}
21692167
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
21702168
extractOp, extractOp.getType(), source);
21712169
return success();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
736736

737737
// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
738738
// CHECK-SAME: %[[A:.*]]: vector<f32>
739-
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
739+
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
740740
// CHECK: return %[[B]] : f32
741741
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
742742
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
28342834
%3 = vector.extract %2[] : f32 from vector<f32>
28352835

28362836
// Broadcast 0D to 3D and extract scalar.
2837-
// CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
2837+
// CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
28382838
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
28392839
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
28402840

0 commit comments

Comments
 (0)