-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add RewriterBase to the C API #98962
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Fehr Mathieu (math-fehr) ChangesThis exposes most of the The missing operations are the ones taking The Python bindings for these methods and classes are not implemented. Patch is 49.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98962.diff 7 Files Affected:
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index bed93045f4b50..09f8a72a0c599 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -33,10 +33,263 @@ extern "C" {
}; \
typedef struct name name
+DEFINE_C_API_STRUCT(MlirRewriterBase, void);
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+/// Get the MLIR context referenced by the rewriter.
+MLIR_CAPI_EXPORTED MlirContext
+mlirRewriterBaseGetContext(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Insertion points methods
+
+// They do not include functions using Block::iterator or Region::iterator, as
+// they are not exposed by the C API yet. This includes methods using
+// `InsertPoint` directly.
+
+/// Reset the insertion point to no location. Creating an operation without a
+/// set insertion point is an error, but this can still be useful when the
+/// current insertion point a builder refers to is being removed.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter);
+
+/// Sets the insertion point to the specified operation, which will cause
+/// subsequent insertions to go right before it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+ MlirOperation op);
+
+/// Sets the insertion point to the node after the specified operation, which
+/// will cause subsequent insertions to go right after it.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+ MlirOperation op);
+
+/// Sets the insertion point to the node after the specified value. If value
+/// has a defining operation, sets the insertion point to the node after such
+/// defining operation. This will cause subsequent insertions to go right
+/// after it. Otherwise, value is a BlockArgument. Sets the insertion point to
+/// the start of its block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+ MlirValue value);
+
+/// Sets the insertion point to the start of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+ MlirBlock block);
+
+/// Sets the insertion point to the end of the specified block.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+ MlirBlock block);
+
+/// Return the block the current insertion point belongs to. Note that the
+/// insertion point is not necessarily the end of the block.
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
+
+/// Returns the current block of the rewriter.
+MLIR_CAPI_EXPORTED MlirBlock
+mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+/// Add new block with 'argTypes' arguments and set the insertion point to the
+/// end of it. The block is placed before 'insertBefore'. `locs` contains the
+/// locations of the inserted arguments, and should match the size of
+/// `argTypes`.
+MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore(
+ MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes,
+ MlirType const *argTypes, MlirLocation const *locations);
+
+/// Insert the given operation at the current insertion point and return it.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op);
+
+// The IRMapper is not yet exposed in the CAPI
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op);
+
+// The IRMapper is not yet exposed in the CAPI
+MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions(
+ MlirRewriterBase rewriter, MlirOperation op);
+
+// The IRMapper is not yet exposed in the CAPI, nor Region::iterator.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+ MlirBlock before);
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+/// Move the blocks that belong to "region" before the given position in
+/// another region "parent". The two regions must be different. The caller
+/// is responsible for creating or updating the operation transferring flow
+/// of control to the region and passing it the correct block arguments.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region,
+ MlirBlock before);
+
+/// Replace the results of the given (original) operation with the specified
+/// list of values (replacements). The result types of the given op and the
+/// replacements must match. The original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op,
+ intptr_t nValues, MlirValue const *values);
+
+/// Replace the results of the given (original) operation with the specified
+/// new op (replacement). The result types of the two ops must match. The
+/// original op is erased.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+ MlirOperation op, MlirOperation newOp);
+
+/// Erases an operation that is known to have no uses.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter,
+ MlirOperation op);
+
+/// Erases a block along with all operations inside it.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter,
+ MlirBlock block);
+
+/// Inline the operations of block 'source' before the operation 'op'. The
+/// source block will be deleted and must have no uses. 'argValues' is used to
+/// replace the block arguments of 'source'
+///
+/// The source block must have no successors. Otherwise, the resulting IR
+/// would have unreachable operations.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source,
+ MlirOperation op, intptr_t nArgValues,
+ MlirValue const *argValues);
+
+/// Inline the operations of block 'source' into the end of block 'dest'. The
+/// source block will be deleted and must have no uses. 'argValues' is used to
+/// replace the block arguments of 'source'
+///
+/// The dest block must have no successors. Otherwise, the resulting IR would
+/// have unreachable operation.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter,
+ MlirBlock source,
+ MlirBlock dest,
+ intptr_t nArgValues,
+ MlirValue const *argValues);
+
+// splitBlock is not implemented as Block::iterator is not exposed by the CAPI
+
+/// Unlink this operation from its current block and insert it right before
+/// `existingOp` which may be in the same or another block in the same
+/// function.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter,
+ MlirOperation op,
+ MlirOperation existingOp);
+
+/// Unlink this operation from its current block and insert it right after
+/// `existingOp` which may be in the same or another block in the same
+/// function.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter,
+ MlirOperation op,
+ MlirOperation existingOp);
+
+/// Unlink this block and insert it right before `existingBlock`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
+ MlirBlock existingBlock);
+
+/// This method is used to notify the rewriter that an in-place operation
+/// modification is about to happen. A call to this function *must* be
+/// followed by a call to either `finalizeOpModification` or
+/// `cancelOpModification`. This is a minor efficiency win (it avoids creating
+/// a new operation and removing the old one) but also often allows simpler
+/// code in the client.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
+ MlirOperation op);
+
+/// This method is used to signal the end of an in-place modification of the
+/// given operation. This can only be called on operations that were provided
+/// to a call to `startOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
+ MlirOperation op);
+
+/// This method cancels a pending in-place modification. This can only be
+/// called on operations that were provided to a call to
+/// `startOpModification`.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
+ MlirOperation op);
+
+/// Find uses of `from` and replace them with `to`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from,
+ MlirValue to);
+
+/// Find uses of `from` and replace them with `to`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllValueRangeUsesWith(
+ MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from,
+ MlirValue const *to);
+
+/// Find uses of `from` and replace them with `to`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced)
+/// and that the `from` operation is about to be replaced.
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
+ MlirOperation from, intptr_t nTo,
+ MlirValue const *to);
+
+/// Find uses of `from` and replace them with `to`. Also notify the listener
+/// about every in-place op modification (for every use that was replaced)
+/// and that the `from` operation is about to be replaced.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllOpUsesWithOperation(
+ MlirRewriterBase rewriter, MlirOperation from, MlirOperation to);
+
+/// Find uses of `from` within `block` and replace them with `to`. Also notify
+/// the listener about every in-place op modification (for every use that was
+/// replaced). The optional `allUsesReplaced` flag is set to "true" if all
+/// uses were replaced.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpUsesWithinBlock(
+ MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues,
+ MlirValue const *newValues, MlirBlock block);
+
+/// Find uses of `from` and replace them with `to` except if the user is
+/// `exceptedUser`. Also notify the listener about every in-place op
+/// modification (for every use that was replaced).
+MLIR_CAPI_EXPORTED void
+mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from,
+ MlirValue to, MlirOperation exceptedUser);
+
+//===----------------------------------------------------------------------===//
+/// IRRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Create an IRRewriter and transfer ownership to the caller.
+MLIR_CAPI_EXPORTED MlirRewriterBase mlirIRRewriterCreate(MlirContext context);
+
+/// Create an IRRewriter and transfer ownership to the caller. Additionally
+/// set the insertion point before the operation.
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirIRRewriterCreateFromOp(MlirOperation op);
+
+/// Takes an IRRewriter owned by the caller and destroys it. It is the
+/// responsibility of the user to only pass an IRRewriter class.
+MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
+
+//===----------------------------------------------------------------------===//
+/// FrozenRewritePatternSet API
+//===----------------------------------------------------------------------===//
+
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet op);
@@ -47,6 +300,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
new file mode 100644
index 0000000000000..0e6dcb2477626
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -0,0 +1,23 @@
+//===- Rewrite.h - C API Utils for Core MLIR classes ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// rewrite patterns. This file should not be included from C++ code other than
+// C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_REWRITE_H
+#define MLIR_CAPI_REWRITE_H
+
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/PatternMatch.h"
+
+DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase);
+
+#endif // MLIR_CAPIREWRITER_H
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 0de1958398f63..7f3c833df0910 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -7,15 +7,260 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Rewrite.h"
+
#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+//===----------------------------------------------------------------------===//
+/// RewriterBase API inherited from OpBuilder
+//===----------------------------------------------------------------------===//
+
+MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
+ return wrap(unwrap(rewriter)->getContext());
+}
+
+//===----------------------------------------------------------------------===//
+/// Insertion points methods
+
+void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
+ unwrap(rewriter)->clearInsertionPoint();
+}
+
+void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
+ MlirOperation op) {
+ unwrap(rewriter)->setInsertionPoint(unwrap(op));
+}
+
+void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
+ MlirOperation op) {
+ unwrap(rewriter)->setInsertionPointAfter(unwrap(op));
+}
+
+void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
+ MlirValue value) {
+ unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value));
+}
+
+void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
+ MlirBlock block) {
+ unwrap(rewriter)->setInsertionPointToStart(unwrap(block));
+}
+
+void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
+ MlirBlock block) {
+ unwrap(rewriter)->setInsertionPointToEnd(unwrap(block));
+}
+
+MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
+ return wrap(unwrap(rewriter)->getInsertionBlock());
+}
+
+MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
+ return wrap(unwrap(rewriter)->getBlock());
+}
+
+//===----------------------------------------------------------------------===//
+/// Block and operation creation/insertion/cloning
+
+MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
+ MlirBlock insertBefore,
+ intptr_t nArgTypes,
+ MlirType const *argTypes,
+ MlirLocation const *locations) {
+ SmallVector<Type, 4> args;
+ ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args);
+ SmallVector<Location, 4> locs;
+ ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs);
+ return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs,
+ unwrappedLocs));
+}
+
+MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
+ MlirOperation op) {
+ return wrap(unwrap(rewriter)->insert(unwrap(op)));
+}
+
+// Other methods of OpBuilder
+
+// The IRMapper is not yet exposed in the CAPI
+MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
+ MlirOperation op) {
+ return wrap(unwrap(rewriter)->clone(*unwrap(op)));
+}
+
+// The IRMapper is not yet exposed in the CAPI
+MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
+ MlirOperation op) {
+ return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
+}
+
+// The IRMapper is not yet exposed in the CAPI, nor Region::iterator.
+void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
+ MlirRegion region, MlirBlock before) {
+
+ unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewriterBase API
+//===----------------------------------------------------------------------===//
+
+// Region::iterator is not yet exposed in the CAPI.
+void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
+ MlirRegion region, MlirBlock before) {
+ unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before));
+}
+
+void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
+ MlirOperation op, intptr_t nValues,
+ MlirValue const *values) {
+ SmallVector<Value, 4> vals;
+ ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals);
+ unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals);
+}
+
+void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+ MlirOperation op,
+ MlirOperation ne...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo small nit. Pretty cool btw - I guess now we can drive rewriter from whatever language has C FFI (which I've gotten asked about numerous times...)
Summary: This exposes most of the `RewriterBase` methods to the C API. This allows to manipulate both the `IRRewriter` and the `PatternRewriter`. The `IRRewriter` can be created from the C API, while the `PatternRewriter` cannot. The missing operations are the ones taking `Block::iterator` and `Region::iterator` as parameters, as they are not exposed by the C API yet AFAIK. The Python bindings for these methods and classes are not implemented. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250937
This exposes most of the
RewriterBase
methods to the C API.This allows to manipulate both the
IRRewriter
and thePatternRewriter
. TheIRRewriter
can be created from the C API, while thePatternRewriter
cannot.The missing operations are the ones taking
Block::iterator
andRegion::iterator
asparameters, as they are not exposed by the C API yet AFAIK.
The Python bindings for these methods and classes are not implemented.