Skip to content

Commit 8ce6f7d

Browse files
[mlir][Linalg] NFC - Fail gracefully instead of crashing in SplitReduction
1 parent c55e6af commit 8ce6f7d

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include <utility>
1514
#include <optional>
15+
#include <utility>
1616

1717
#include "mlir/Analysis/SliceAnalysis.h"
1818
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -42,15 +42,16 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
4242

4343
SmallVector<unsigned> dims;
4444
op.getReductionDims(dims);
45-
assert(dims.size() == 1);
45+
46+
if (dims.size() != 1)
47+
return b.notifyMatchFailure(op, "needs a single reduction dimension");
4648
unsigned reductionDim = dims[0];
4749
if (control.innerParallel) {
4850
insertSplitDimension = reductionDim + 1;
4951
}
5052
SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
5153
int64_t reductionDimSize = loopRanges[reductionDim];
52-
if (reductionDimSize == ShapedType::kDynamic ||
53-
reductionDimSize % ratio != 0)
54+
if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
5455
return b.notifyMatchFailure(
5556
op, "Reduction dimension not divisible by split ratio");
5657
if (op.getNumDpsInits() != 1)
@@ -85,19 +86,22 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
8586
if (control.innerParallel) {
8687
newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
8788
newShape.push_back(ratio); // parallel (insert)
88-
exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
89+
exprs.push_back(
90+
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
8991
exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
9092
} else {
9193
newShape.push_back(ratio); // parallel (insert)
9294
newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
9395
exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
94-
exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
96+
exprs.push_back(
97+
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
9598
}
9699
reassociation.push_back({index++, index++});
97100
continue;
98101
}
99102
newShape.push_back(op.getShape(operand)[idx]);
100-
exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
103+
exprs.push_back(
104+
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
101105
reassociation.push_back({index++});
102106
}
103107
newMaps.push_back(

0 commit comments

Comments
 (0)