Skip to content

Commit 3941355

Browse files
committed
[mlir][vector] Support 0-D vector when eliding single element reduction
ElideSingleElementReduction causes assertion failure when we give 0-D vector. It's possible to fold the case by using vector.extractelement op instead. It's originally reported in llvm#60193. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D143242
1 parent 981218e commit 3941355

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -530,13 +530,19 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
530530
if (maskableOp.isMasked())
531531
return failure();
532532

533-
if (reductionOp.getVectorType().getDimSize(0) != 1)
533+
auto vectorType = reductionOp.getVectorType();
534+
if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
534535
return failure();
535536

536537
Location loc = reductionOp.getLoc();
537-
Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
538-
reductionOp.getVector(),
539-
rewriter.getI64ArrayAttr(0));
538+
Value result;
539+
if (vectorType.getRank() == 0) {
540+
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
541+
} else {
542+
result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
543+
reductionOp.getVector(),
544+
rewriter.getI64ArrayAttr(0));
545+
}
540546

541547
if (Value acc = reductionOp.getAcc())
542548
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,3 +2157,13 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
21572157
%1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
21582158
return %1 : f32
21592159
}
2160+
2161+
// -----
2162+
2163+
// CHECK-LABEL: func.func @fold_0d_vector_reduction
2164+
func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
2165+
// CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
2166+
// CHECK-NEXT: return %[[RES]] : f32
2167+
%0 = vector.reduction <add>, %arg0 : vector<f32> into f32
2168+
return %0 : f32
2169+
}

0 commit comments

Comments
 (0)