[mlir][tensor] Fold when source is const #71643

Merged: 2 commits merged into llvm:main from rh/fold-constant-splat on Nov 9, 2023

Conversation

rikhuijzer (Member) commented on Nov 8, 2023

Fixes #60656.

This patch implements a basic fold for various reshape/resize tensor operations. Specifically, the folding removes tensor reshape/resize ops when they are applied to a constant tensor. For example, the following function:

func.func @main(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}

will be changed into the following with mlir-opt -canonicalize:

func.func @main(%arg0: tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
  return %cst : tensor<8x16x8x32xf32>
}
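
The fold only fires when the source constant is a splat and the result type is fully static: a splat DenseElementsAttr stores a single value plus a type, so the constant can be rebuilt at the result shape without touching any element data. Below is a minimal sketch of that attribute-level step using standard MLIR builder APIs (the helper and its name are illustrative, not code from this patch), mirroring the tensor<64x128xf32> to tensor<8x16x8x32xf32> example above:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Retype a splat constant: only the single stored value and the type are
// kept, which is what makes this fold essentially free.
static mlir::DenseElementsAttr repackSplat(mlir::MLIRContext &ctx) {
  mlir::Builder b(&ctx);
  // dense<1.000000e-01> : tensor<64x128xf32>
  auto srcType = mlir::RankedTensorType::get({64, 128}, b.getF32Type());
  float value = 0.1f;
  auto splat =
      mlir::DenseElementsAttr::get(srcType, llvm::ArrayRef<float>(value));
  // Rebuild the splat at the packed result type tensor<8x16x8x32xf32>;
  // resizeSplat only requires the element types to match.
  auto dstType = mlir::RankedTensorType::get({8, 16, 8, 32}, b.getF32Type());
  return splat.resizeSplat(dstType);
}
```

Non-splat constants are not folded here; doing so would require actually re-laying-out the element data rather than just retyping the attribute.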

As a side-note, this patch is essentially an extension of f79f430. That commit implemented the folding for tensor.extract_slice, whereas this patch also implements it for tensor.gather, tensor.reshape, tensor.pack, and tensor.unpack.

llvmbot (Member) commented on Nov 8, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Rik Huijzer (rikhuijzer)

Changes

Fixes #60656.

This patch implements a basic fold for various reshape/resize tensor operations. Specifically, the folding removes tensor reshape/resize ops when they are applied to a constant tensor. For example, the following function:

func.func @main(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}

will be changed into the following with mlir-opt -canonicalize:

func.func @main(%arg0: tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
  return %cst : tensor<8x16x8x32xf32>
}

As a side-note, this patch is essentially an extension of f79f430. That commit implemented the folding for tensor.extract_slice, whereas this patch also implements it for tensor.gather, tensor.reshape, tensor.pack, and tensor.unpack.


Full diff: https://github.com/llvm/llvm-project/pull/71643.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+6)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+47-6)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+48)
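
For context (a simplified view of the generated code, not shown in the diff): `let hasFolder = 1;` in the ODS changes below asks TableGen to declare a fold hook on each of these single-result ops, and the new code in TensorOps.cpp supplies the definitions. The declared hook looks roughly like this:

```cpp
// Declaration sketch only (not the literal TableGen output). Returning a
// non-null OpFoldResult replaces the op's single result; returning {} means
// "no fold applies" and the op is left in place.
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
```
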
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 21e1f87bfa53709..c184971e478195e 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -659,6 +659,7 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
     }
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -986,6 +987,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [
     $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1867,6 +1869,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
   }];
 
   let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1948,6 +1952,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
   }];
 
   let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6fc45379111fc34..c33dd603cb02899 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -834,6 +834,16 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ReplaceEmptyTensorStaticShapeDims>(context);
 }
 
+/// Try to remove a tensor operation if it would only reshape a constant.
+/// Removes the op and replaces the constant with a new constant of the result shape.
+static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
+                                  TensorType result) {
+  if (source && source.isSplat() && result.hasStaticShape())
+    return source.resizeSplat(result);
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
@@ -1089,6 +1099,14 @@ LogicalResult GatherOp::verify() {
   return success();
 }
 
+OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // InsertOp
 //===----------------------------------------------------------------------===//
@@ -1367,6 +1385,14 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
+OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
@@ -2153,12 +2179,10 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
 }
 
 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
-  if (auto splat =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    auto resultType = llvm::cast<ShapedType>(getResult().getType());
-    if (resultType.hasStaticShape())
-      return splat.resizeSplat(resultType);
-  }
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->getSource();
@@ -3823,6 +3847,14 @@ bool PackOp::isLikePad() {
   return isLikePadUnPad(*this, packedTensorType);
 }
 
+OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
@@ -3951,6 +3983,15 @@ bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }
+
+OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1078ee3b59a4306..ea8c17640d7c143 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -672,6 +672,30 @@ func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8x
 
 // -----
 
+// CHECK-LABEL: func @fold_gather_constant_splat
+//   CHECK-NOT: tensor.gather
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
+func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
+  %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
+    (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
+  return %0 : tensor<1x2x 1x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_reshape_constant_splat
+//   CHECK-NOT: tensor.reshape
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
+func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
+  %0 = tensor.reshape %cst(%shape)
+             : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_constant_splat
 //   CHECK-NOT: tensor.extract_slice
 //       CHECK: arith.constant dense<42> : tensor<4x4xi32>
@@ -683,6 +707,30 @@ func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_pack_constant_splat
+//   CHECK-NOT: tensor.pack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
+  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_unpack_constant_splat
+//   CHECK-NOT: tensor.unpack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
+func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
+  %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
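
For completeness, a minimal sketch (not part of this patch; typical MLIR API usage assumed) of driving these folds from C++ rather than the `mlir-opt` command line. The canonicalizer pass applies op folders, including the ones added here, together with canonicalization patterns until a fixed point is reached:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext context;
  context.loadDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
                      mlir::tensor::TensorDialect>();

  // The tensor.pack example from the PR description.
  llvm::StringRef ir = R"mlir(
    func.func @main(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
      %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
      %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
        inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
      return %0 : tensor<8x16x8x32xf32>
    }
  )mlir";

  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, &context);
  if (!module)
    return 1;

  // In-process equivalent of `mlir-opt -canonicalize`.
  mlir::PassManager pm(&context);
  pm.addPass(mlir::createCanonicalizerPass());
  if (mlir::failed(pm.run(*module)))
    return 1;

  // The pack of the splat constant should now be folded into a constant of
  // the packed type, matching the CHECK lines in the tests above.
  module->print(llvm::outs());
  return 0;
}
```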

MaheshRavishankar (Contributor) left a comment

Nice! Thanks!

rikhuijzer (Member, Author) commented on Nov 9, 2023

There are some build failures on main currently and I'll merge once those are green again. Thank you for the review @MaheshRavishankar! Much appreciated.

@rikhuijzer rikhuijzer merged commit d0da3d8 into llvm:main Nov 9, 2023
@rikhuijzer rikhuijzer deleted the rh/fold-constant-splat branch November 9, 2023 19:36
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023

Successfully merging this pull request may close these issues:

  • [MLIR] tensor.pack with a constant arguments should fold (#60656)

3 participants