[mlir][Linalg] Move linalg.fill -> linalg.pack pattern into fill canonicalization patterns. #66002

Merged
MaheshRavishankar merged 2 commits into llvm:main on Sep 11, 2023

Conversation

MaheshRavishankar
Contributor

This pattern fits better with the other canonicalization patterns that exist for linalg.fill.
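
For reference, a minimal sketch of the rewrite this pattern performs, mirroring the new canonicalization test (%cst, %dest, and %packed_dest are assumed to be defined earlier, e.g. a constant and two tensor.empty ops):

  // Before: fill a tensor, then pack the filled result.
  %fill = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
  %pack = tensor.pack %fill inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %packed_dest : tensor<384x512xf32> -> tensor<24x32x16x16xf32>

  // After: fill the packed destination directly; the pack op is folded away.
  %0 = linalg.fill ins(%cst : f32) outs(%packed_dest : tensor<24x32x16x16xf32>) -> tensor<24x32x16x16xf32>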

[mlir][Linalg] Move `linalg.fill` -> `linalg.pack` pattern into `fill` canonicalization patterns.

This pattern fits better with the other canonicalization patterns that
exist for `linalg.fill`.
@MaheshRavishankar MaheshRavishankar requested a review from a team as a code owner September 11, 2023 19:56
@llvmbot llvmbot added the mlir:core (MLIR Core Infrastructure), mlir:linalg, and mlir labels on Sep 11, 2023
@llvmbot
Member

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-mlir

Changes

This pattern fits better with the other canonicalization patterns that exist for linalg.fill.

Full diff: https://github.com/llvm/llvm-project/pull/66002.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+57-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (-60)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+44)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (-44)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bcc36d6bd3e95e5..e05a82855c66bd0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -737,7 +737,8 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
 
   LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    // See if tensor input of tensor.extract op is the result of a linalg.fill op.
+    // See if tensor input of tensor.extract op is the result of a linalg.fill
+    // op.
     auto fillOp = extractOp.getTensor().getDefiningOp<FillOp>();
     if (!fillOp)
       return failure();
@@ -751,15 +752,65 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Folds pack(fill) into a single fill op if
+///   1. The pack op does not have padding value, or
+///   2. The filled value and padding value are the same.
+static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
+                                                tensor::PackOp packOp) {
+  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
+  if (!fillOp)
+    return failure();
+
+  if (auto paddingValue = packOp.getPaddingValue())
+    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
+      return failure();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(fillOp);
+
+  Value packOpDest = packOp.getDest();
+  if (!packOpDest.hasOneUse())
+    return failure();
+  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
+    packOpDest = tensor::PackOp::createDestinationTensor(
+        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
+        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        packOp.getOuterDimsPerm());
+  } else {
+    DominanceInfo dom(fillOp);
+    if (!dom.properlyDominates(packOpDest, fillOp))
+      return failure();
+  }
+
+  Value fillDest = packOpDest;
+  return clone(rewriter, fillOp, packOpDest.getType(),
+               {fillOp.value(), fillDest});
+}
+
+/// Wrapper pattern that applies foldFillPackIntoFillOp method.
+struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
+public:
+  FoldFillWithPack(MLIRContext *context)
+      : OpRewritePattern<tensor::PackOp>(context) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
+    if (failed(fillOp))
+      return failure();
+    rewriter.replaceOp(packOp, fillOp.value().result());
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results
-      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPad,
-           FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill>(context);
+  results.add<FoldFillWithCopy, FoldFillWithPack, FoldFillWithPad,
+              FoldFillWithTensorExtract,
+              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+              FoldInsertPadIntoFill>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 1ddd8b144c60e85..95a20f2369f9e07 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -448,46 +448,6 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        *packInfo);
 }
 
-/// Folds pack(fill) into a single fill op if
-///   1. The pack op does not have padding value, or
-///   2. The filled value and padding value are the same.
-static FailureOr<FillOp>
-foldFillPackIntoFillOp(RewriterBase &rewriter, tensor::PackOp packOp,
-                       ControlPropagationFn controlFn) {
-  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
-  if (!fillOp)
-    return failure();
-
-  // User controlled propagation function.
-  if (!controlFn(fillOp))
-    return failure();
-
-  if (auto paddingValue = packOp.getPaddingValue())
-    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
-      return failure();
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(fillOp);
-
-  Value packOpDest = packOp.getDest();
-  if (!packOpDest.hasOneUse())
-    return failure();
-  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
-    packOpDest = tensor::PackOp::createDestinationTensor(
-        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
-        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
-        packOp.getOuterDimsPerm());
-  } else {
-    DominanceInfo dom(fillOp);
-    if (!dom.properlyDominates(packOpDest, fillOp))
-      return failure();
-  }
-
-  Value fillDest = packOpDest;
-  return clone(rewriter, fillOp, packOpDest.getType(),
-               {fillOp.value(), fillDest});
-}
-
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
 struct BubbleUpPackOpThroughGenericOpPattern
     : public OpRewritePattern<tensor::PackOp> {
@@ -510,25 +470,6 @@ struct BubbleUpPackOpThroughGenericOpPattern
   ControlPropagationFn controlFn;
 };
 
-/// Wrapper pattern that applies foldFillPackIntoFillOp method.
-struct FoldFillPackIntoFillOpPattern : public OpRewritePattern<tensor::PackOp> {
-public:
-  FoldFillPackIntoFillOpPattern(MLIRContext *context, ControlPropagationFn fun)
-      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
-
-  LogicalResult matchAndRewrite(tensor::PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp, controlFn);
-    if (failed(fillOp))
-      return failure();
-    rewriter.replaceOp(packOp, fillOp.value().result());
-    return success();
-  }
-
-private:
-  ControlPropagationFn controlFn;
-};
-
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -750,7 +691,6 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
   patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
-                  BubbleUpPackThroughPadOp, FoldFillPackIntoFillOpPattern,
+                  BubbleUpPackThroughPadOp,
                   PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 297b5c4e332c811..7793e435582746c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -353,6 +353,50 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
 
 // -----
 
+func.func @fill_pack() -> tensor<24x32x16x16xf32> {
+  %dest = tensor.empty() : tensor<384x512xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<24x32x16x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
+  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
+  return %pack : tensor<24x32x16x16xf32>
+}
+// CHECK-LABEL: func.func @fill_pack
+// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK:         return %[[FILL]]
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 ceildiv 16)>
+func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  %1 = affine.apply #map()[%dim]
+  %2 = affine.apply #map()[%dim_0]
+  %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
+  %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
+  return %pack : tensor<?x?x16x16xf32>
+}
+// CHECK-DAG:   #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK:       func.func @dynamic_fill_pack
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK:         %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
+// CHECK:         %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
+// CHECK:         %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK:         %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
+// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK:         return %[[FILL]]
+
+// -----
+
 // CHECK: func @fold_self_copy
 func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 3f98bf06433d2ec..4c59c97aecc2519 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -839,47 +839,3 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME:     into %[[UNPACK_NEW_DEST]]
 // CHECK:        return %[[UNPACK]] : tensor<16x540x960xi32>
-
-// -----
-
-func.func @fill_pack() -> tensor<24x32x16x16xf32> {
-  %dest = tensor.empty() : tensor<384x512xf32>
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = tensor.empty() : tensor<24x32x16x16xf32>
-  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
-  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
-  return %pack : tensor<24x32x16x16xf32>
-}
-// CHECK-LABEL: func.func @fill_pack
-// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
-// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
-// CHECK:         return %[[FILL]]
-
-// -----
-
-#map = affine_map<()[s0] -> (s0 ceildiv 16)>
-func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
-  %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
-  %1 = affine.apply #map()[%dim]
-  %2 = affine.apply #map()[%dim_0]
-  %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
-  %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
-  return %pack : tensor<?x?x16x16xf32>
-}
-// CHECK-DAG:   #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-// CHECK:       func.func @dynamic_fill_pack
-// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-// CHECK:         %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
-// CHECK:         %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
-// CHECK:         %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK:         %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
-// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
-// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
-// CHECK:         return %[[FILL]]

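With this move, the fold now fires under ordinary linalg.fill canonicalization rather than as part of data-layout propagation. A minimal sketch of how a downstream pass could pick it up (the surrounding function and pass context are hypothetical, not part of this PR):

  #include "mlir/Dialect/Linalg/IR/Linalg.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Collect the linalg.fill canonicalization patterns (which now include
  // the fill+pack fold) and apply them greedily over a module.
  static mlir::LogicalResult runFillCanonicalization(mlir::ModuleOp module) {
    mlir::MLIRContext *ctx = module.getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
    return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
  }
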
@chelini chelini left a comment
Contributor

I agree, LGTM.

@qedawkins qedawkins left a comment
Contributor

LGTM, thanks!

@MaheshRavishankar
Contributor Author

What's the PR error? I am new to how the LLVM GitHub setup works.

@qedawkins
Contributor

Might be a CI flake? I see similar-looking failures in other recently submitted PRs, although I am not familiar with the new GitHub setup either.

@MaheshRavishankar MaheshRavishankar merged commit fd8349c into llvm:main Sep 11, 2023
AntonRydahl pushed a commit to AntonRydahl/llvm-project that referenced this pull request Sep 11, 2023
[mlir][Linalg] Move `linalg.fill` -> `linalg.pack` pattern into `fill` canonicalization patterns. (llvm#66002)

This pattern fits better with the other canonicalization patterns that
exist for `linalg.fill`.
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
[mlir][Linalg] Move `linalg.fill` -> `linalg.pack` pattern into `fill` canonicalization patterns. (llvm#66002)

This pattern fits better with the other canonicalization patterns that
exist for `linalg.fill`.
Labels
mlir:core (MLIR Core Infrastructure), mlir:linalg, mlir