[mlir] Do not bufferize parallel_insert_slice dest to read for full slices #112761
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir
Author: None (Max191)
Changes
In the insert_slice bufferization interface implementation, the destination tensor is not considered read if the full tensor is overwritten by the slice. This PR adds the same check for tensor.parallel_insert_slice.
Full diff: https://github.com/llvm/llvm-project/pull/112761.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 87464ccb71720d..def4ee93854a1a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -636,6 +637,34 @@ struct InsertOpInterface
   }
 };
+template <typename InsertOpTy>
+static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
+                                      OpOperand &opOperand) {
+  RankedTensorType destType = insertSliceOp.getDestType();
+
+  // The source is always read.
+  if (opOperand == insertSliceOp.getSourceMutable())
+    return true;
+
+  // For the destination, it depends...
+  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
+
+  // Dest is not read if it is entirely overwritten. E.g.:
+  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+  bool allOffsetsZero =
+      llvm::all_of(insertSliceOp.getMixedOffsets(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+  bool sizesMatchDestSizes = llvm::all_of(
+      llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
+        return getConstantIntValue(it.value()) ==
+               destType.getDimSize(it.index());
+      });
+  bool allStridesOne =
+      llvm::all_of(insertSliceOp.getMixedStrides(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+}
+
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
@@ -646,32 +675,8 @@ struct InsertSliceOpInterface
                                                     tensor::InsertSliceOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    RankedTensorType destType = insertSliceOp.getDestType();
-
-    // The source is always read.
-    if (opOperand == insertSliceOp.getSourceMutable())
-      return true;
-
-    // For the destination, it depends...
-    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
-
-    // Dest is not read if it is entirely overwritten. E.g.:
-    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
-    bool allOffsetsZero =
-        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 0);
-        });
-    bool sizesMatchDestSizes = llvm::all_of(
-        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
-          return getConstantIntValue(it.value()) ==
-                 destType.getDimSize(it.index());
-        });
-    bool allStridesOne =
-        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 1);
-        });
-    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
+                                     opOperand);
   }
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -931,7 +936,8 @@ struct ParallelInsertSliceOpInterface
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return true;
+    return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
+                                     opOperand);
   }
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index e2169fe1404c82..dc4306b8316ab7 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
 // -----
+// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
+// CHECK-NOT: memref.alloc()
+func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  return %3 : tensor<2xf32>
+}
+
+// -----
+
 // This test case could bufferize in-place with a better analysis. However, it
 // is simpler to let the canonicalizer fold away the tensor.insert_slice.
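To make the condition in insertSliceOpRequiresRead concrete, here is a minimal standalone sketch of the same decision outside of MLIR. Everything in it is a hypothetical stand-in: MaybeConst plays the role of OpFoldResult (std::nullopt means "only known at runtime"), SliceInsert models the offsets/sizes/strides plus a static destination shape, and destRequiresRead is an illustrative name. It is not the code added by this PR, only a model of its logic.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for OpFoldResult: a known constant, or std::nullopt
// for a value that is only known at runtime.
using MaybeConst = std::optional<int64_t>;

// Simplified, hypothetical model of an insert_slice-like op.
struct SliceInsert {
  std::vector<MaybeConst> offsets, sizes, strides;
  std::vector<int64_t> destShape; // assumed fully static in this sketch
};

// Returns true if the destination must be treated as read, i.e. the slice
// does not provably overwrite every element of the destination.
bool destRequiresRead(const SliceInsert &op) {
  assert(op.offsets.size() == op.destShape.size() &&
         op.sizes.size() == op.destShape.size() &&
         op.strides.size() == op.destShape.size() && "rank mismatch");
  // Builds a predicate that holds only for a known constant equal to `v`.
  auto isConst = [](int64_t v) {
    return [v](const MaybeConst &c) { return c.has_value() && *c == v; };
  };
  bool allOffsetsZero =
      std::all_of(op.offsets.begin(), op.offsets.end(), isConst(0));
  bool allStridesOne =
      std::all_of(op.strides.begin(), op.strides.end(), isConst(1));
  // A dynamic size (std::nullopt) never compares equal to a dest dimension.
  bool sizesMatchDest = true;
  for (std::size_t i = 0; i < op.sizes.size(); ++i)
    sizesMatchDest = sizesMatchDest && op.sizes[i] == op.destShape[i];
  return !(allOffsetsZero && sizesMatchDest && allStridesOne);
}
```

With this model, the new test case (offset 0, size 2, stride 1 into a 2-element destination) is a full overwrite, so the destination is not read. Any partial or dynamically sized slice conservatively counts as a read.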
This looks good to me.
                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
  bool sizesMatchDestSizes = llvm::all_of(
      llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
        return getConstantIntValue(it.value()) ==
It's a somewhat convoluted way of checking whether the size is static and equal... Could we not use OpFoldResult to do that?
There is a vector version of getConstantIntValue. Perhaps we can just use it? Something like:
  std::optional<SmallVector<int64_t>> cstSizes = getConstantIntValues(insertSliceOp.getMixedSizes());
  bool sizesMatchDestSizes = (cstSizes == destType.getShape());
llvm-project/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h, lines 89 to 91 in 8c62bf5:
  /// If all ofrs are constant integers or IntegerAttrs, return the integers.
  std::optional<SmallVector<int64_t>>
  getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
====
I'm not entirely sure about the approach below, but perhaps you can give it a try.
The other approach is to use getStaticSizes() and check whether the shape is static. So it could be:
  bool sizesMatchDestSizes = (insertSliceOp.getStaticSizes() == destType.getShape() && destType.hasStaticShape());
getConstantIntValue works for dynamic SSA values that fold to a constant (e.g., %c1) and for OpFoldResult, so I usually prefer that one over just checking the static offsets/sizes/strides.
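The distinction drawn in this comment can be illustrated with a small standalone model. The types below are hypothetical stand-ins, not MLIR classes: Value models an SSA value that may be defined by a constant op, FoldResult models OpFoldResult as "static attribute or SSA value", constantOf mimics the described behavior of getConstantIntValue, and isStaticOnly mimics a check that only inspects static attributes.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <variant>

// Hypothetical model of an SSA value: it may or may not be defined by a
// constant op (such as %c1 = arith.constant 1 : index).
struct Value {
  std::optional<int64_t> definingConstant;
};

// Hypothetical model of OpFoldResult: either a static attribute (an integer
// here) or a reference to an SSA value.
using FoldResult = std::variant<int64_t, const Value *>;

// A constant is recovered both from a static attribute and from an SSA value
// that is defined by a constant.
std::optional<int64_t> constantOf(const FoldResult &ofr) {
  if (const int64_t *attr = std::get_if<int64_t>(&ofr))
    return *attr;
  const Value *v = std::get<const Value *>(ofr);
  assert(v && "expected a non-null value");
  return v->definingConstant;
}

// A check that only looks at static attributes misses the SSA constant case.
bool isStaticOnly(const FoldResult &ofr, int64_t expected) {
  const int64_t *attr = std::get_if<int64_t>(&ofr);
  return attr && *attr == expected;
}
```

In this model an operand written as %c1 is recognized as the constant 1 by constantOf but not by isStaticOnly, which is the reason given above for preferring the OpFoldResult-based helper.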
  bool allOffsetsZero =
      llvm::all_of(insertSliceOp.getMixedOffsets(),
                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
I think you can replace the lambda with isZeroIndex. Can you give it a try?
  bool allStridesOne =
      llvm::all_of(insertSliceOp.getMixedStrides(),
                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
We can consider adding a new method to StaticValueUtils.h that takes an ArrayRef and checks whether the values are all equal to value.
  bool isConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
  // or name it isConstantIntValueArray
llvm-project/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h, lines 93 to 94 in 8c62bf5:
  /// Return true if `ofr` is constant integer equal to `value`.
  bool isConstantIntValue(OpFoldResult ofr, int64_t value);
I opted to add two new StaticValueUtils, areConstantIntValues and areAllConstantIntValue. I have occasionally wished such functions were available downstream as well, so I think they are both nice to have. WDYT?
…lices Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: Max Dawkins <[email protected]>
Force-pushed from 5a0fa0e to 9f02751.
Looks much better to me, thanks!
Signed-off-by: Max Dawkins <[email protected]>
In the insert_slice bufferization interface implementation, the destination tensor is not considered read if the full tensor is overwritten by the slice. This PR adds the same check for tensor.parallel_insert_slice.
Adds two new StaticValueUtils:
- areAllConstantIntValue checks if an array of OpFoldResult are all equal to a passed int64_t value.
- areConstantIntValues checks if an array of OpFoldResult are all equal to a passed array of int64_t values.
Fixes #112435
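As a rough illustration of what the two utilities described above compute, here is a standalone sketch that does not depend on MLIR. MaybeConst is a hypothetical stand-in for OpFoldResult (std::nullopt means the value is not a known constant), and allConstantEqualTo / constantsEqualTo are illustrative names; the real helpers in StaticValueUtils.h operate on ArrayRef<OpFoldResult> and may differ in details such as how mismatched lengths are handled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for OpFoldResult.
using MaybeConst = std::optional<int64_t>;

// True if every entry is a known constant equal to `value`.
bool allConstantEqualTo(const std::vector<MaybeConst> &ofrs, int64_t value) {
  return std::all_of(ofrs.begin(), ofrs.end(),
                     [value](const MaybeConst &c) { return c == value; });
}

// True if both ranges have the same length and entry i is a known constant
// equal to values[i]. Treating a length mismatch as "not equal" is an
// assumption of this sketch.
bool constantsEqualTo(const std::vector<MaybeConst> &ofrs,
                      const std::vector<int64_t> &values) {
  return std::equal(ofrs.begin(), ofrs.end(), values.begin(), values.end(),
                    [](const MaybeConst &c, int64_t v) { return c == v; });
}
```

Under these assumptions, the full-overwrite condition from the description reduces to three calls: offsets all equal to 0, strides all equal to 1, and sizes equal to the destination shape.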