-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor][NFC] Simplify SubsetInsertionOpInterface
implementation
#69999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor][NFC] Simplify SubsetInsertionOpInterface
implementation
#69999
Conversation
`tensor.insert_slice` and `tensor.parallel_insert_slice` can share the same implementation.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor Author: Matthias Springer (matthias-springer) Changes
Full diff: https://github.com/llvm/llvm-project/pull/69999.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index f4f46d54d78e59f..85f7796096a42ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,105 +17,58 @@ using namespace mlir::tensor;
namespace {
-/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
-/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
-/// equivalence of tensors.
template <typename OpTy>
-bool isSubsetEquivalentToInsertSliceLikeOp(
- OpTy insertSliceOp, Value candidate,
- function_ref<bool(Value, Value)> equivalenceFn) {
- // Look for a matching tensor.extract_slice op.
- auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
- if (!extractSliceOp)
- return false;
- if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
- return false;
- return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
- isEqualConstantIntOrValue);
-}
-
-template <typename OpTy>
-Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
- OpTy insertSliceOp) {
- auto extractOp = b.create<tensor::ExtractSliceOp>(
- loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides());
- return extractOp.getResult();
-}
-
-template <typename OpTy>
-SmallVector<Value>
-getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
- SmallVector<Value> neededValues;
- // Collect all values that are needed to construct the replacement op.
- neededValues.append(insertSliceOp.getOffsets().begin(),
- insertSliceOp.getOffsets().end());
- neededValues.append(insertSliceOp.getSizes().begin(),
- insertSliceOp.getSizes().end());
- neededValues.append(insertSliceOp.getStrides().begin(),
- insertSliceOp.getStrides().end());
- neededValues.push_back(insertSliceOp.getDest());
- return neededValues;
-}
-
-struct InsertSliceOpInterface
- : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
- tensor::InsertSliceOp> {
- OpOperand &getSourceOperand(Operation *op) const {
- return cast<tensor::InsertSliceOp>(op).getSourceMutable();
- }
-
- bool
- isEquivalentSubset(Operation *op, Value candidate,
- function_ref<bool(Value, Value)> equivalenceFn) const {
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
- return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
- equivalenceFn);
- }
-
- Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
- Location loc) const {
- return buildSubsetExtractionOfInsertSliceLikeOp(
- builder, loc, cast<tensor::InsertSliceOp>(op));
- }
-
- SmallVector<Value>
- getValuesNeededToBuildSubsetExtraction(Operation *op) const {
- return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
- cast<tensor::InsertSliceOp>(op));
- }
-};
-
-struct ParallelInsertSliceOpInterface
+struct InsertSliceLikeOpInterface
: public SubsetInsertionOpInterface::ExternalModel<
- ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
+ InsertSliceLikeOpInterface<OpTy>, OpTy> {
OpOperand &getSourceOperand(Operation *op) const {
- return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
+ return cast<OpTy>(op).getSourceMutable();
}
OpOperand &getDestinationOperand(Operation *op) const {
- return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
+ return cast<OpTy>(op).getDestMutable();
}
+ /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+ /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+ /// equivalence of tensors.
bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
- auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
- return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
- equivalenceFn);
+ auto insertSliceOp = cast<OpTy>(op);
+ // Look for a matching tensor.extract_slice op.
+ auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractSliceOp)
+ return false;
+ if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+ return false;
+ return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+ isEqualConstantIntOrValue);
}
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
- return buildSubsetExtractionOfInsertSliceLikeOp(
- builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
+ auto insertSliceOp = cast<OpTy>(op);
+ auto extractOp = builder.create<tensor::ExtractSliceOp>(
+ loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+ insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+ insertSliceOp.getMixedStrides());
+ return extractOp.getResult();
}
SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
- return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
- cast<tensor::ParallelInsertSliceOp>(op));
+ auto insertSliceOp = cast<OpTy>(op);
+ SmallVector<Value> neededValues;
+ // Collect all values that are needed to construct the replacement op.
+ neededValues.append(insertSliceOp.getOffsets().begin(),
+ insertSliceOp.getOffsets().end());
+ neededValues.append(insertSliceOp.getSizes().begin(),
+ insertSliceOp.getSizes().end());
+ neededValues.append(insertSliceOp.getStrides().begin(),
+ insertSliceOp.getStrides().end());
+ neededValues.push_back(insertSliceOp.getDest());
+ return neededValues;
}
};
@@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
- InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
- ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+ InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
*ctx);
+ ParallelInsertSliceOp::attachInterface<
+ InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
});
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks!
tensor.insert_slice
andtensor.parallel_insert_slice
can share the same implementation.