|
11 | 11 | //
|
12 | 12 | //===----------------------------------------------------------------------===//
|
13 | 13 |
|
14 |
| -#include <utility> |
15 | 14 | #include <optional>
|
| 15 | +#include <utility> |
16 | 16 |
|
17 | 17 | #include "mlir/Analysis/SliceAnalysis.h"
|
18 | 18 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
@@ -42,15 +42,16 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
|
42 | 42 |
|
43 | 43 | SmallVector<unsigned> dims;
|
44 | 44 | op.getReductionDims(dims);
|
45 |
| - assert(dims.size() == 1); |
| 45 | + |
| 46 | + if (dims.size() != 1) |
| 47 | + return b.notifyMatchFailure(op, "needs a single reduction dimension"); |
46 | 48 | unsigned reductionDim = dims[0];
|
47 | 49 | if (control.innerParallel) {
|
48 | 50 | insertSplitDimension = reductionDim + 1;
|
49 | 51 | }
|
50 | 52 | SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
|
51 | 53 | int64_t reductionDimSize = loopRanges[reductionDim];
|
52 |
| - if (reductionDimSize == ShapedType::kDynamic || |
53 |
| - reductionDimSize % ratio != 0) |
| 54 | + if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0) |
54 | 55 | return b.notifyMatchFailure(
|
55 | 56 | op, "Reduction dimension not divisible by split ratio");
|
56 | 57 | if (op.getNumDpsInits() != 1)
|
@@ -85,19 +86,22 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
|
85 | 86 | if (control.innerParallel) {
|
86 | 87 | newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
|
87 | 88 | newShape.push_back(ratio); // parallel (insert)
|
88 |
| - exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1)); |
| 89 | + exprs.push_back( |
| 90 | + b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
89 | 91 | exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
|
90 | 92 | } else {
|
91 | 93 | newShape.push_back(ratio); // parallel (insert)
|
92 | 94 | newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
|
93 | 95 | exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
|
94 |
| - exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1)); |
| 96 | + exprs.push_back( |
| 97 | + b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
95 | 98 | }
|
96 | 99 | reassociation.push_back({index++, index++});
|
97 | 100 | continue;
|
98 | 101 | }
|
99 | 102 | newShape.push_back(op.getShape(operand)[idx]);
|
100 |
| - exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
| 103 | + exprs.push_back( |
| 104 | + b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
101 | 105 | reassociation.push_back({index++});
|
102 | 106 | }
|
103 | 107 | newMaps.push_back(
|
|
0 commit comments