-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Fix tensor::PackOp fold() handling of padding value #87296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
We can't just check if it is a splat constant or not. We should also check if the value match.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor Author: Han-Chung Wang (hanhanW) ChangesWe can't just check if it is a splat constant or not. We should also check if the value match. Full diff: https://github.com/llvm/llvm-project/pull/87296.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 38a9ad60bb7948..8dc1ef67ce65c5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1069,9 +1069,11 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// Try to remove a tensor operation if it would only reshape a constant.
/// Removes the op and replaces the constant with a new constant of the result
/// shape.
-static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
- TensorType result) {
- if (source && source.isSplat() && result.hasStaticShape())
+static OpFoldResult
+reshapeConstantSource(DenseElementsAttr source, TensorType result,
+ std::optional<Attribute> cst = std::nullopt) {
+ if (source && source.isSplat() && result.hasStaticShape() &&
+ (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
return source.resizeSplat(result);
return {};
@@ -4143,9 +4145,12 @@ bool PackOp::isLikePad() {
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+ std::optional<Attribute> paddingValue;
+ if (adaptor.getPaddingValue())
+ paddingValue = adaptor.getPaddingValue();
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
- getResult().getType()))
+ adaptor.getDestType(), paddingValue))
return reshapedSource;
return {};
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9ab54fe9c133db..ac365c9d297e88 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -830,6 +830,39 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
// -----
+// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
+// CHECK-NOT: tensor.pack
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+ %pad = arith.constant 1.000000e-01 : f32
+ %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+ %0 = tensor.pack %cst
+ padding_value(%pad : f32)
+ outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
+// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+// CHECK: tensor.pack
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+ %pad = arith.constant 0.0 : f32
+ %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+ %0 = tensor.pack %cst
+ padding_value(%pad : f32)
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32]
+ into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<31250x1200x16x1xf32>
|
@@ -4143,9 +4145,12 @@ bool PackOp::isLikePad() { | |||
} | |||
|
|||
OpFoldResult PackOp::fold(FoldAdaptor adaptor) { | |||
std::optional<Attribute> paddingValue; | |||
if (adaptor.getPaddingValue()) | |||
paddingValue = adaptor.getPaddingValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: we should avoid the double call to getPaddingValue()
here: this accessor is far from free (there is a loop hidden behind it)
if (source && source.isSplat() && result.hasStaticShape()) | ||
static OpFoldResult | ||
reshapeConstantSource(DenseElementsAttr source, TensorType result, | ||
std::optional<Attribute> cst = std::nullopt) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the documentation for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good to me after addressing Mehdi's comments.
FYI: @rengolin |
We can't just check if it is a splat constant or not. We should also check if the value match.