Skip to content

Commit 986287e

Browse files
[mlir][SparseTensor] Fix invalid API usage in patterns (#74690)
Rewrite patterns must return `success` if the IR was modified. This commit fixes sparse tensor tests such as `SparseTensor/sparse_fusion.mlir`, `SparseTensor/CPU/sparse_reduce_custom.mlir`, `SparseTensor/CPU/sparse_semiring_select.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.
1 parent cdd81e3 commit 986287e

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
422422
if (!controlFn(&opOperand))
423423
continue;
424424

425-
// Find the producer of the operand.
426-
FailureOr<ElementwiseOpFusionResult> fusionResult =
427-
fuseElementwiseOps(rewriter, &opOperand);
428-
if (failed(fusionResult))
429-
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
430425
Operation *producer = opOperand.get().getDefiningOp();
431426

432427
// Do not fuse a sparse-in/dense-out operation, as the
@@ -435,6 +430,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
435430
!sparse_tensor::hasAnySparseResult(producer))
436431
return failure();
437432

433+
// Find the producer of the operand.
434+
FailureOr<ElementwiseOpFusionResult> fusionResult =
435+
fuseElementwiseOps(rewriter, &opOperand);
436+
if (failed(fusionResult))
437+
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
438+
438439
// Perform the fusion.
439440
for (auto [origVal, replacement] : fusionResult->replacements) {
440441
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,22 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
3838
LogicalResult matchAndRewrite(SourceOp op,
3939
PatternRewriter &rewriter) const override {
4040
Location loc = op.getLoc();
41+
4142
// Demaps non-trivial inputs.
43+
bool changed = false;
4244
SmallVector<Value> deMappedIns(op->getOperands());
43-
for (Value &in : deMappedIns)
44-
if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
45+
for (Value &in : deMappedIns) {
46+
if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
4547
in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
48+
changed = true;
49+
}
50+
}
4651

4752
// CRTP call.
4853
OpAdaptor adaptor(deMappedIns, op);
49-
return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
50-
rewriter);
54+
LogicalResult status =
55+
static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
56+
return changed ? success() : status;
5157
}
5258
};
5359

@@ -452,11 +458,13 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
452458
}
453459

454460
// Marks the GenericOp to avoid recursive matching.
455-
linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
461+
rewriter.updateRootInPlace(linalgOp, [&]() {
462+
linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
463+
});
456464

457465
// Already sorted.
458466
if (order.isIdentity())
459-
return failure();
467+
return success();
460468

461469
assert(order.isPermutation());
462470
// `order` is orignial loop -> sorted loop map

0 commit comments

Comments
 (0)