Skip to content

[mlir][Transforms] Support moveOpBefore/After in dialect conversion #81240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {

/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
virtual void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator);
void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);

/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
Expand All @@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {

/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);
void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);

/// Unlink this block and insert it right before `existingBlock`.
void moveBlockBefore(Block *block, Block *anotherBlock);
Expand Down
9 changes: 2 additions & 7 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter {

/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
/// and not nested regions. Updates to regions will still require
/// notification through other more specific hooks above.
void startOpModification(Operation *op) override;

/// PatternRewriter hook for updating the given operation in-place.
Expand All @@ -761,11 +761,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;

void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) override;
void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) override;

std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};

Expand Down
74 changes: 59 additions & 15 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,8 @@ class IRRewrite {
InlineBlock,
MoveBlock,
SplitBlock,
BlockTypeConversion
BlockTypeConversion,
MoveOperation
};

virtual ~IRRewrite() = default;
Expand Down Expand Up @@ -982,6 +983,54 @@ class BlockTypeConversionRewrite : public BlockRewrite {
// `ArgConverter::applyRewrites`. This should be done in the "commit" method.
void rollback() override;
};

/// An operation rewrite.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on the role of the class, the context where it's used?

The "Action" name for this whole section is not great by the way, since the concept of "Actions" is now core to MLIR tracing...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the class to OperationRewrite. I added more documentation to the superclass IRRewrite in the dependent PR.

class OperationRewrite : public IRRewrite {
public:
/// Return the operation that this rewrite operates on.
Operation *getOperation() const { return op; }

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
rewrite->getKind() <= Kind::MoveOperation;
}

protected:
OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
Operation *op)
: IRRewrite(kind, rewriterImpl), op(op) {}

// The operation that this rewrite operates on.
Operation *op;
};

/// Moving of an operation. This rewrite is immediately reflected in the IR.
class MoveOperationRewrite : public OperationRewrite {
public:
MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Operation *op, Block *block, Operation *insertBeforeOp)
: OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
insertBeforeOp(insertBeforeOp) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::MoveOperation;
}

void rollback() override {
// Move the operation back to its original position.
Block::iterator before =
insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
block->getOperations().splice(before, op->getBlock()->getOperations(), op);
}

private:
// The block in which this operation was previously contained.
Block *block;

// The original successor of this operation before it was moved. "nullptr" if
// this operation was the only operation in the region.
Operation *insertBeforeOp;
};
} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1478,12 +1527,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(

void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
createdOps.push_back(op);
if (!previous.isSet()) {
// This is a newly created op.
createdOps.push_back(op);
return;
}
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
? nullptr
: &*previous.getPoint();
appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
}

void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
Expand Down Expand Up @@ -1722,18 +1778,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}

void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}

void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}

detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,17 @@ module {
return
}
}

// -----

// CHECK-LABEL: func @test_move_op_before_rollback()
func.func @test_move_op_before_rollback() {
// CHECK: "test.one_region_op"()
// CHECK: "test.hoist_me"()
"test.one_region_op"() ({
// expected-remark @below{{'test.hoist_me' is not legalizable}}
%0 = "test.hoist_me"() : () -> (i32)
"test.valid"(%0) : (i32) -> ()
}) : () -> ()
"test.return"() : () -> ()
}
20 changes: 18 additions & 2 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};

/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
/// This is to test the rollback logic.
struct TestUndoMoveOpBefore : public ConversionPattern {
TestUndoMoveOpBefore(MLIRContext *ctx)
: ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.moveOpBefore(op, op->getParentOp());
// Replace with an illegal op to ensure the conversion fails.
rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
return success();
}
};

/// A rewrite pattern that tests the undo mechanism when erasing a block.
struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase(MLIRContext *ctx)
Expand Down Expand Up @@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
TestCreateUnregisteredOp>(&getContext());
TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
Expand All @@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
TerminatorOp>();
TerminatorOp, OneRegionOp>();
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
Expand Down