Skip to content

Commit 4dbaef6

Browse files
[mlir][Linalg] Avoid doing op replacement in linalg::dropUnitDims. (#105749)
It is better to do the replacement in the caller. This avoids the footgun if the caller needs the original operation. Instead return the produced operation and replacement values. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent a2a5508 commit 4dbaef6

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,13 @@ struct ControlDropUnitDims {
488488
return SmallVector<unsigned>{};
489489
};
490490
};
491-
LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
492-
const ControlDropUnitDims &options);
491+
struct DropUnitDimsResult {
492+
linalg::GenericOp resultOp;
493+
SmallVector<Value> replacements;
494+
};
495+
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
496+
GenericOp genericOp,
497+
const ControlDropUnitDims &options);
493498

494499
/// Fuse two `linalg.generic` operations that have a producer-consumer
495500
/// relationship captured through `fusedOperand`. The method expects

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,9 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
386386
return info;
387387
}
388388

389-
LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
390-
const ControlDropUnitDims &options) {
389+
FailureOr<DropUnitDimsResult>
390+
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
391+
const ControlDropUnitDims &options) {
391392
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
392393
if (indexingMaps.empty())
393394
return failure();
@@ -545,8 +546,7 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
545546
resultReplacements.push_back(expandedValue);
546547
}
547548

548-
rewriter.replaceOp(genericOp, resultReplacements);
549-
return success();
549+
return DropUnitDimsResult{replacementOp, resultReplacements};
550550
}
551551

552552
namespace {
@@ -557,7 +557,13 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {
557557

558558
LogicalResult matchAndRewrite(GenericOp genericOp,
559559
PatternRewriter &rewriter) const override {
560-
return dropUnitDims(rewriter, genericOp, options);
560+
FailureOr<DropUnitDimsResult> result =
561+
dropUnitDims(rewriter, genericOp, options);
562+
if (failed(result)) {
563+
return failure();
564+
}
565+
rewriter.replaceOp(genericOp, result->replacements);
566+
return success();
561567
}
562568

563569
private:

mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
2525
linalg::GenericOp genericOp) {
2626
linalg::ControlDropUnitDims options;
2727
options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
28-
return linalg::dropUnitDims(rewriter, genericOp, options);
28+
FailureOr<linalg::DropUnitDimsResult> result =
29+
linalg::dropUnitDims(rewriter, genericOp, options);
30+
if (failed(result)) {
31+
return failure();
32+
}
33+
rewriter.replaceOp(genericOp, result->replacements);
34+
return success();
2935
}
3036

3137
struct TestLinalgDropUnitDims

0 commit comments

Comments
 (0)