Skip to content

[mlir][tensor][NFC] Simplify SubsetInsertionOpInterface implementation #69999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

tensor.insert_slice and tensor.parallel_insert_slice can share the same implementation.

`tensor.insert_slice` and `tensor.parallel_insert_slice` can share the same implementation.
@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Matthias Springer (matthias-springer)

Changes

tensor.insert_slice and tensor.parallel_insert_slice can share the same implementation.


Full diff: https://github.com/llvm/llvm-project/pull/69999.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (+36-82)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index f4f46d54d78e59f..85f7796096a42ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,105 +17,58 @@ using namespace mlir::tensor;
 
 namespace {
 
-/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
-/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
-/// equivalence of tensors.
 template <typename OpTy>
-bool isSubsetEquivalentToInsertSliceLikeOp(
-    OpTy insertSliceOp, Value candidate,
-    function_ref<bool(Value, Value)> equivalenceFn) {
-  // Look for a matching tensor.extract_slice op.
-  auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
-  if (!extractSliceOp)
-    return false;
-  if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
-    return false;
-  return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
-                                    isEqualConstantIntOrValue);
-}
-
-template <typename OpTy>
-Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
-                                               OpTy insertSliceOp) {
-  auto extractOp = b.create<tensor::ExtractSliceOp>(
-      loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
-      insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-      insertSliceOp.getMixedStrides());
-  return extractOp.getResult();
-}
-
-template <typename OpTy>
-SmallVector<Value>
-getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
-  SmallVector<Value> neededValues;
-  // Collect all values that are needed to construct the replacement op.
-  neededValues.append(insertSliceOp.getOffsets().begin(),
-                      insertSliceOp.getOffsets().end());
-  neededValues.append(insertSliceOp.getSizes().begin(),
-                      insertSliceOp.getSizes().end());
-  neededValues.append(insertSliceOp.getStrides().begin(),
-                      insertSliceOp.getStrides().end());
-  neededValues.push_back(insertSliceOp.getDest());
-  return neededValues;
-}
-
-struct InsertSliceOpInterface
-    : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
-                                                       tensor::InsertSliceOp> {
-  OpOperand &getSourceOperand(Operation *op) const {
-    return cast<tensor::InsertSliceOp>(op).getSourceMutable();
-  }
-
-  bool
-  isEquivalentSubset(Operation *op, Value candidate,
-                     function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
-                                                 equivalenceFn);
-  }
-
-  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
-                              Location loc) const {
-    return buildSubsetExtractionOfInsertSliceLikeOp(
-        builder, loc, cast<tensor::InsertSliceOp>(op));
-  }
-
-  SmallVector<Value>
-  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
-    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
-        cast<tensor::InsertSliceOp>(op));
-  }
-};
-
-struct ParallelInsertSliceOpInterface
+struct InsertSliceLikeOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<
-          ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
+          InsertSliceLikeOpInterface<OpTy>, OpTy> {
   OpOperand &getSourceOperand(Operation *op) const {
-    return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
+    return cast<OpTy>(op).getSourceMutable();
   }
 
   OpOperand &getDestinationOperand(Operation *op) const {
-    return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
+    return cast<OpTy>(op).getDestMutable();
   }
 
+  /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
+  /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
+  /// equivalence of tensors.
   bool
   isEquivalentSubset(Operation *op, Value candidate,
                      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
-    return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
-                                                 equivalenceFn);
+    auto insertSliceOp = cast<OpTy>(op);
+    // Look for a matching tensor.extract_slice op.
+    auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractSliceOp)
+      return false;
+    if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
+      return false;
+    return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                      isEqualConstantIntOrValue);
   }
 
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                               Location loc) const {
-    return buildSubsetExtractionOfInsertSliceLikeOp(
-        builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
+    auto insertSliceOp = cast<OpTy>(op);
+    auto extractOp = builder.create<tensor::ExtractSliceOp>(
+        loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
+    return extractOp.getResult();
   }
 
   SmallVector<Value>
   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
-    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
-        cast<tensor::ParallelInsertSliceOp>(op));
+    auto insertSliceOp = cast<OpTy>(op);
+    SmallVector<Value> neededValues;
+    // Collect all values that are needed to construct the replacement op.
+    neededValues.append(insertSliceOp.getOffsets().begin(),
+                        insertSliceOp.getOffsets().end());
+    neededValues.append(insertSliceOp.getSizes().begin(),
+                        insertSliceOp.getSizes().end());
+    neededValues.append(insertSliceOp.getStrides().begin(),
+                        insertSliceOp.getStrides().end());
+    neededValues.push_back(insertSliceOp.getDest());
+    return neededValues;
   }
 };
 
@@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
-    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
-    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+    InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
         *ctx);
+    ParallelInsertSliceOp::attachInterface<
+        InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
   });
 }

Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!

@matthias-springer matthias-springer merged commit 2e3c62b into llvm:main Oct 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants