Skip to content

[mlir] Do not bufferize parallel_insert_slice dest to read for full slices #112761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 18, 2024

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Oct 17, 2024

In the insert_slice bufferization interface implementation, the destination tensor is not considered read if the full tensor is overwritten by the slice. This PR adds the same check for tensor.parallel_insert_slice.

Adds two new StaticValueUtils:

  • isAllConstantIntValue checks if an array of OpFoldResult are all equal to a passed int64_t value.
  • areConstantIntValues checks if an array of OpFoldResult are all equal to a passed array of int64_t values.

fixes #112435

@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

In the insert_slice bufferization interface implementation, the destination tensor is not considered read if the full tensor is overwritten by the slice. This PR adds the same check for tensor.parallel_insert_slice.


Full diff: https://github.com/llvm/llvm-project/pull/112761.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+33-27)
  • (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+15)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 87464ccb71720d..def4ee93854a1a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 
@@ -636,6 +637,34 @@ struct InsertOpInterface
   }
 };
 
+template <typename InsertOpTy>
+static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
+                                      OpOperand &opOperand) {
+  RankedTensorType destType = insertSliceOp.getDestType();
+
+  // The source is always read.
+  if (opOperand == insertSliceOp.getSourceMutable())
+    return true;
+
+  // For the destination, it depends...
+  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
+
+  // Dest is not read if it is entirely overwritten. E.g.:
+  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+  bool allOffsetsZero =
+      llvm::all_of(insertSliceOp.getMixedOffsets(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+  bool sizesMatchDestSizes = llvm::all_of(
+      llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
+        return getConstantIntValue(it.value()) ==
+               destType.getDimSize(it.index());
+      });
+  bool allStridesOne =
+      llvm::all_of(insertSliceOp.getMixedStrides(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+}
+
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 ///
@@ -646,32 +675,8 @@ struct InsertSliceOpInterface
                                                      tensor::InsertSliceOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    RankedTensorType destType = insertSliceOp.getDestType();
-
-    // The source is always read.
-    if (opOperand == insertSliceOp.getSourceMutable())
-      return true;
-
-    // For the destination, it depends...
-    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
-
-    // Dest is not read if it is entirely overwritten. E.g.:
-    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
-    bool allOffsetsZero =
-        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 0);
-        });
-    bool sizesMatchDestSizes = llvm::all_of(
-        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
-          return getConstantIntValue(it.value()) ==
-                 destType.getDimSize(it.index());
-        });
-    bool allStridesOne =
-        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 1);
-        });
-    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
+                                     opOperand);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -931,7 +936,8 @@ struct ParallelInsertSliceOpInterface
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return true;
+    return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
+                                     opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index e2169fe1404c82..dc4306b8316ab7 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
 
 // -----
 
+// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
+// CHECK-NOT:     memref.alloc()
+func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  return %3 : tensor<2xf32>
+}
+
+// -----
+
 // This test case could bufferize in-place with a better analysis. However, it
 // is simpler to let the canonicalizer fold away the tensor.insert_slice.
 

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me.

[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
bool sizesMatchDestSizes = llvm::all_of(
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
return getConstantIntValue(it.value()) ==
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its a bit convoluted way of checking if size is static and equal... Could we not use OpFoldResult to do that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a vector version of getConstantIntValue. Perhaps we can just use it?

Something like:

std::optional<SmallVector<int64_t>> cstSizes = getConstantIntValues(insertSliceOp.getMixedSizes());
bool sizesMatchDestSizes = (cstSizes == destType.getShape());

/// If all ofrs are constant integers or IntegerAttrs, return the integers.
std::optional<SmallVector<int64_t>>
getConstantIntValues(ArrayRef<OpFoldResult> ofrs);

====

I'm not pretty sure about the below approach, but perhaps you can give it try.

The other approach is using getStaticSizes() and check if it is static shape or not. So it could be:

bool sizesMatchDestSizes = (insertSliceOp.getStaticSizes() == destType.getShape() && destType.hasStaticShape());

Copy link
Member

@matthias-springer matthias-springer Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getConstantIntValue works for dynamic SSA values that fold to a constant (e.g., %c1) and for OpFoldResult, so I usually prefer that one over just checking the static offsets/sizes/strides.

Comment on lines 654 to 656
bool allOffsetsZero =
llvm::all_of(insertSliceOp.getMixedOffsets(),
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can replace the lambda with isZeroIndex, can you give it a try?

[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
bool sizesMatchDestSizes = llvm::all_of(
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
return getConstantIntValue(it.value()) ==
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a vector version of getConstantIntValue. Perhaps we can just use it?

Something like:

std::optional<SmallVector<int64_t>> cstSizes = getConstantIntValues(insertSliceOp.getMixedSizes());
bool sizesMatchDestSizes = (cstSizes == destType.getShape());

/// If all ofrs are constant integers or IntegerAttrs, return the integers.
std::optional<SmallVector<int64_t>>
getConstantIntValues(ArrayRef<OpFoldResult> ofrs);

====

I'm not pretty sure about the below approach, but perhaps you can give it try.

The other approach is using getStaticSizes() and check if it is static shape or not. So it could be:

bool sizesMatchDestSizes = (insertSliceOp.getStaticSizes() == destType.getShape() && destType.hasStaticShape());

Comment on lines 662 to 664
bool allStridesOne =
llvm::all_of(insertSliceOp.getMixedStrides(),
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can consider adding a new method to StaticValueUtils.h, which takes an ArrayRef and check if the values are all value or not.

 bool isConstantIntValue(ArrayRef<OpFoldResult ofr>, int64_t value); 
// or name it to isConstantIntValueArray

/// Return true if `ofr` is constant integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opted to add two new StaticValueUtils areConstantIntValues and areAllConstantIntValue. I have wished such functions were available downstream on occasion as well, so I think they are both nice to have. WDYT?

@Max191 Max191 force-pushed the parallel-insert-bufferize-dest-to-read branch from 5a0fa0e to 9f02751 Compare October 18, 2024 14:36
@Max191 Max191 requested a review from hanhanW October 18, 2024 14:38
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks much better to me, thanks!

Signed-off-by: Max Dawkins <[email protected]>
@Max191 Max191 merged commit 98e838a into llvm:main Oct 18, 2024
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Op needs producer extract_slice op to bufferize in place
5 participants