[mlir][tensor] Fold when source is const #71643
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Rik Huijzer (rikhuijzer)

Changes

Fixes #60656.

This patch implements a basic fold for various reshape/resize tensor operations. Specifically, the folding removes tensor reshape/resize ops when they are applied to a constant tensor. For example, the following function:

```mlir
func.func @main(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
      inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}
```

will be changed into the following with `mlir-opt -canonicalize`:

```mlir
func.func @main(%arg0: tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
  return %cst : tensor<8x16x8x32xf32>
}
```
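For reference, the rewrite can be reproduced with the canonicalizer directly; a minimal invocation, assuming a built `mlir-opt` and the snippet saved as `pack.mlir` (the file name is only for illustration):

```sh
# Run the canonicalizer, which invokes the new fold methods.
mlir-opt -canonicalize pack.mlir
```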
As a side-note, this patch is essentially an extension of f79f430. That commit implemented the same folding for `tensor.extract_slice`.

Full diff: https://github.com/llvm/llvm-project/pull/71643.diff

3 Files Affected:
- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
- mlir/test/Dialect/Tensor/canonicalize.mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 21e1f87bfa53709..c184971e478195e 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -659,6 +659,7 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
}
}];
let hasVerifier = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -986,6 +987,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [
$source `(` $shape `)` attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -1867,6 +1869,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
}];
let hasCanonicalizeMethod = 1;
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -1948,6 +1952,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
}];
let hasCanonicalizeMethod = 1;
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6fc45379111fc34..c33dd603cb02899 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -834,6 +834,16 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
ReplaceEmptyTensorStaticShapeDims>(context);
}
+/// Try to remove a tensor operation if it would only reshape a constant.
+/// Removes the op and replaces the constant with a new constant of the result shape.
+static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
+ TensorType result) {
+ if (source && source.isSplat() && result.hasStaticShape())
+ return source.resizeSplat(result);
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
@@ -1089,6 +1099,14 @@ LogicalResult GatherOp::verify() {
return success();
}
+OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
@@ -1367,6 +1385,14 @@ LogicalResult ReshapeOp::verify() {
return success();
}
+OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//
@@ -2153,12 +2179,10 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
}
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
- if (auto splat =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
- auto resultType = llvm::cast<ShapedType>(getResult().getType());
- if (resultType.hasStaticShape())
- return splat.resizeSplat(resultType);
- }
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->getSource();
@@ -3823,6 +3847,14 @@ bool PackOp::isLikePad() {
return isLikePadUnPad(*this, packedTensorType);
}
+OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//
@@ -3951,6 +3983,15 @@ bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
+
+OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1078ee3b59a4306..ea8c17640d7c143 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -672,6 +672,30 @@ func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8x
// -----
+// CHECK-LABEL: func @fold_gather_constant_splat
+// CHECK-NOT: tensor.gather
+// CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
+func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
+ %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
+ (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ return %0 : tensor<1x2x1x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_reshape_constant_splat
+// CHECK-NOT: tensor.reshape
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
+func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
+ %0 = tensor.reshape %cst(%shape)
+ : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>
@@ -683,6 +707,30 @@ func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
// -----
+// CHECK-LABEL: func @fold_pack_constant_splat
+// CHECK-NOT: tensor.pack
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
+ %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_unpack_constant_splat
+// CHECK-NOT: tensor.unpack
+// CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
+func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
+ %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+ return %0 : tensor<128x256xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
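Note that the shared `reshapeConstantSource` helper only applies when the source is a splat constant and the result type is statically shaped; non-splat constants are left untouched. For instance, a hypothetical input like the following (not one of the patch's tests) would not fold:

```mlir
func.func @no_fold_non_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
  // Non-splat constant, so ReshapeOp::fold leaves the op in place.
  %cst = arith.constant dense<[[1.0], [2.0], [3.0], [4.0]]> : tensor<4x1xf32>
  %0 = tensor.reshape %cst(%shape)
      : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
```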
Nice! Thanks!
There are some build failures on