[MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest #110514

Merged: 1 commit merged into llvm:main on Oct 3, 2024

Conversation

@rolfmorel (Contributor) commented Sep 30, 2024

Replaces a linalg.add, one of whose operands is the single user of a contraction that has a zero-filled, "identity-mapped" destination and is dominated by the add's other operand, with that contraction, rewritten to use the other operand as its dest.

Benefits include the elision of an elementwise op, namely the linalg.add, and the removal of a tensor.empty destination, which would likely require an allocation upon bufferization.
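
For illustration, a minimal before-and-after sketch of the rewrite (the %a, %b, %c value names and the 8x8 shapes are illustrative; concrete cases are exercised in the tests below):

// Before: the matmul accumulates onto a zero-filled dest, and the sum with %c
// is computed by a separate linalg.add whose dest is a tensor.empty.
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<8x8xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
%mm = linalg.matmul ins(%a, %b : tensor<8x8xf32>, tensor<8x8xf32>) outs(%fill : tensor<8x8xf32>) -> tensor<8x8xf32>
%res = linalg.add ins(%mm, %c : tensor<8x8xf32>, tensor<8x8xf32>) outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>

// After: %c, which dominates the matmul, becomes the matmul's dest, so the
// contraction accumulates onto it; the linalg.add and its tensor.empty dest are elided.
%res = linalg.matmul ins(%a, %b : tensor<8x8xf32>, tensor<8x8xf32>) outs(%c : tensor<8x8xf32>) -> tensor<8x8xf32>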

@llvmbot (Member) commented Sep 30, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Replaces a linalg.add, one of whose operands is the single user of a contraction that has a zero-filled, "identity-mapped" destination and is dominated by the add's other operand, with that contraction, rewritten to use the other operand as its dest.

Benefits include the elision of an elementwise op, namely the linalg.add, and the removal of a tensor.empty destination, which would likely require an allocation upon bufferization.


Patch is 23.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110514.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+11)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4)
  • (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+6)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp (+108)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+75)
  • (added) mlir/test/Dialect/Linalg/fold-add-into-dest.mlir (+288)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..a997502c34299c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.fold_add_into_dest",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns to replace linalg.add when destination passing suffices
+    for achieving the sum.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48e657cca96e39..cc12ed7cfa6b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
 void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                           const ControlFusionFn &controlFn);
 
+/// Pattern to replace `linalg.add` when destination passing on a contraction op
+/// suffices for achieving the sum.
+void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1e4f3004dec7e7..1d2759d2a91db1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -38,6 +38,12 @@ namespace linalg {
 // General utilities
 //===----------------------------------------------------------------------===//
 
+// Returns true if `val` represents a zero-filled tensor, per its defining op.
+bool isZeroTensor(Value val);
+
+// Returns true if the operation defines a zero-filled tensor.
+bool isZeroOp(Operation *);
+
 /// Check if all indexing maps are projected permutations.
 bool allIndexingsAreProjectedPermutation(LinalgOp op);
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 46c8510f4ed514..3b7b367d3cf2d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
+void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateFoldAddIntoDestPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 47af392def94ac..b3cd5537aad9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   ElementwiseToLinalg.cpp
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
+  FoldAddIntoDest.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
new file mode 100644
index 00000000000000..d8c4e338fddbbc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -0,0 +1,108 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+/// Replace a linalg.add with one operand the single user of a contraction,
+/// which has a zero-filled, "identity-mapped" destination and is dominated by
+/// the `other` operand, by the contraction with `other` as its dest.
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+  using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::AddOp addOp,
+                                PatternRewriter &rewriter) const override {
+    Value dominatingOperand = nullptr;
+    linalg::LinalgOp dominatedOp = nullptr;
+    {
+      auto firstOperand = addOp.getOperand(0);
+      auto secondOperand = addOp.getOperand(1);
+
+      // Can only put one of addOp's operands in the dest/out arg of the other's
+      // defining op based on suitable dominance.
+      if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
+        DominanceInfo domInfo(secondOp);
+        if (domInfo.properlyDominates(firstOperand, secondOp)) {
+          dominatingOperand = firstOperand;
+          dominatedOp = secondOp;
+        }
+      }
+      if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
+        DominanceInfo domInfo(firstOp);
+        if (domInfo.properlyDominates(secondOperand, firstOp)) {
+          dominatingOperand = secondOperand;
+          dominatedOp = firstOp;
+        }
+      }
+      if (!dominatingOperand || !dominatedOp)
+        return failure();
+      // NB: As linalg.add's generalisation ignores the out argument in its
+      //     region there is no need to perform checks on addOp's out argument.
+    }
+
+    // Dominated op must be a contraction for it to accumulate on its out arg.
+    // E.g., AddOp is not a contraction and hence ignores its out arg's value.
+    auto dominatedDestOp =
+        dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
+    if (dominatedOp->getNumResults() != 1 ||
+        !linalg::isaContractionOpInterface(dominatedOp) ||
+        (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op to be single-result "
+                       "destination-passing contraction");
+
+    // To change the contraction's result, `addOp` must be its only user.
+    if (!dominatedOp->getResult(0).hasOneUse())
+      return rewriter.notifyMatchFailure(
+          dominatedOp,
+          "expected linalg.add to be single user of contraction's result");
+
+    // As `dominatedOp` was already accumulating on its out argument, it is only
+    // safe to no longer use its current out arg when it is the additive zero.
+    auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
+    if (!linalg::isZeroTensor(destOperand->get()))
+      return rewriter.notifyMatchFailure(
+          dominatedOp, "expected dominated op's dest to be additive zero");
+    // TODO: If the other op is a contraction and has additive zero as dest, we
+    // can swap the dests and achieve the proper sum, given suitable dominance.
+
+    // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+    // Hence, we can only substitute `dominatingOperand` for the dest of the
+    // contraction when dest's indexing_map corresponds to an identity map
+    // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
+    SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
+    int prevDimPos = -1;
+    for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
+      auto dim = dyn_cast<AffineDimExpr>(expr);
+      if (!dim || prevDimPos >= (int)dim.getPosition())
+        return rewriter.notifyMatchFailure(
+            dominatedOp, "expected index_map for contraction's dest to be an "
+                         "ordered projection");
+      prevDimPos = dim.getPosition();
+    }
+
+    // Replace the additive-zero out argument of the dominated op by the
+    // dominating summand. This makes the dominated op's result the sum of both
+    // of addOp's arguments - therefore we replace addOp and its uses with it.
+    rewriter.modifyOpInPlace(
+        dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
+    rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
+    return success();
+  }
+};
+
+void linalg::populateFoldAddIntoDestPatterns(
+    RewritePatternSet &patterns) {
+  // Replace linalg.add when destination passing suffices for achieving the sum.
+  patterns.add<FoldAddIntoDest>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 38e427af1c4846..a6a9ca5fd66330 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -870,5 +870,80 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
   return reassociation;
 }
 
+// Returns true if the value is a constant zero float or integer.
+bool isValConstZero(Value val) {
+  return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
+}
+
+// Returns true if the attribute represents "all zeros".
+static bool isZeroAttr(Attribute attribute) {
+  return TypeSwitch<Attribute, bool>(attribute)
+      .Case<FloatAttr>([](auto attr) { return attr.getValueAsDouble() == 0.0; })
+      .Case<IntegerAttr>([](auto attr) { return attr.getInt() == 0; })
+      .Case<DenseElementsAttr>([](auto attr) {
+        if (!attr.getElementType().isIntOrFloat())
+          return false;
+        if (!attr.isSplat())
+          return false;
+        auto splat = attr.template getSplatValue<Attribute>();
+        return isZeroAttr(splat);
+      })
+      .Default([](auto attr) { return false; });
+}
+
+// Recurses into isZeroOp for defining ops if not immediately obvious.
+// Looks past linalg.generic's block arguments (which don't have defining ops).
+bool isZeroTensor(Value val) {
+  if (!val)
+    return false;
+  if (isValConstZero(val))
+    return true;
+
+  Operation *defOp = nullptr;
+
+  // Block arguments don't have a defining op, but they do have an op arg.
+  if (auto arg = dyn_cast<BlockArgument>(val)) {
+    // We need to find the linalg op's operand at the same position as this argument.
+    auto *linalgOp = arg.getParentRegion()->getParentOp();
+    if (!isa<linalg::GenericOp>(linalgOp))
+      return false;
+    auto index = arg.getArgNumber();
+    auto linalgArg = linalgOp->getOperand(index);
+    defOp = linalgArg.getDefiningOp();
+  } else {
+    defOp = val.getDefiningOp();
+  }
+  return isZeroOp(defOp);
+}
+
+// Recurses into isZeroTensor for operands and isZeroAttr for attributes.
+bool isZeroOp(Operation *defOp) {
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<arith::ConstantOp>([&](auto op) {
+        // Dense attributes don't match APFloat.isZero().
+        Attribute attr = op.getValue();
+        return isZeroAttr(attr);
+      })
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        if (op.getInputs().size() != 1)
+          return false;
+        return isZeroTensor(op.getInputs()[0]);
+      })
+      .Case<memref::CopyOp, memref::SubViewOp, tensor::CastOp,
+            tensor::ExtractSliceOp>(
+          [&](auto op) { return isZeroTensor(op.getSource()); })
+      .Case<memref::GetGlobalOp>([&](auto op) {
+        auto name = op.getName();
+        auto module = defOp->getParentOfType<ModuleOp>();
+        auto global = module.lookupSymbol<memref::GlobalOp>(name);
+        auto attr = global.getInitialValueAttr();
+        return isZeroAttr(attr);
+      })
+      .Default([&](Operation *op) { return false; });
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
new file mode 100644
index 00000000000000..4dbe253fd3221b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -0,0 +1,288 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+
+// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul_transpose_b
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_orig_dest_not_additive_zero
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+  %5 = lin...
[truncated]

github-actions bot commented Sep 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@rolfmorel force-pushed the fold-add-into-dest branch 2 times, most recently from 65a6874 to 6ac2db5 on September 30, 2024 at 14:59
@rolfmorel (Contributor, Author) commented Oct 1, 2024

Original downstream PR is here: libxsmm/tpp-mlir#966

Credit to @adam-smnk for suggesting the transform and thanks for the thorough review on the original PR!

Credit for the original isZero utility functions goes to @rengolin and @chelini.

@banach-space (Contributor) left a comment:

The helpers that you added in Linalg/Utils/Utils.h are not really specific to Linalg and should be moved somewhere more general. Given the location of isZeroIndex, I suggest StaticValueUtils.h. Also, it feels like some of the logic you added there should already exist 🤔

I also suggest bailing out early when encountering memrefs. Unless that's something that this is meant to support?

@rolfmorel (Contributor, Author):

Thanks for the comments, @banach-space!

I have simplified the utility functions quite a bit. As I dropped the memref support and restricted the zero check to just what is needed for this pattern, I have put the (now single) zero check in the pattern's file.

If you could have another look at the now simpler code, that would be much appreciated.

As for where utility functions like this should live, I am also curious whether the functionality already exists somewhere. I have not been able to find it so far (which ties into the question of what the right location for such functions would be). 😢
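
For reference, a sketch of the kinds of values the isZeroTensor/isZeroOp check in the patch above recognizes as zero-filled tensors (the shapes are illustrative, and the check was narrowed further after review):

// Recognized: a splat-zero constant.
%splat = arith.constant dense<0.000000e+00> : tensor<4x4xf32>

// Recognized: a linalg.fill writing the additive zero into its dest.
%c0 = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<4x4xf32>
%zeros = linalg.fill ins(%c0 : f32) outs(%empty : tensor<4x4xf32>) -> tensor<4x4xf32>

// Not recognized: a non-zero splat.
%ones = arith.constant dense<1.000000e+00> : tensor<4x4xf32>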

// -----

!type = tensor<2048x2048xf32>
func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {

@banach-space (Contributor):

This is very similar to @fold_add_on_two_matmuls. I think that it would be beneficial to group all similar tests together so that it's clear what edge cases have been exercised. To me, this would make sense:

  • @expect_no_fold_as_orig_dest_not_additive_zero -> @expect_no_fold_add_on_matmuls_orig_dest_not_additive_zero

and then move it near @fold_add_on_two_matmuls. A similar comment applies to the other tests.

This is a non-blocker, just a kind request :) It makes maintaining tests much easier - I'm learning the hard way while going through Vector dialect tests 😅 Having said all that, these tests are already very nicely formatted, thanks!

@rolfmorel (Contributor, Author):

I have made the function names more specific and have now tried to group them as you suggest.

Note that previously I had them grouped into "interesting cases that ought to work" and "a test for each early-exit condition". I guess it's a matter of personal preference which grouping is more helpful.

@banach-space (Contributor):

> Note that previously I had them grouped into "interesting cases that ought to work" and "a test for each early-exit condition". I guess it's a matter of personal preference which grouping is more helpful.

Good point, I'm clearly biased towards my favourite approach :) Your approach also works, but then adding a big comment separating "positive" from "negative" tests would help. It wasn't obvious to me that that's how you organised them.

@banach-space (Contributor):

> If you could have another look at the now simpler code, that would be much appreciated.

I really like how you simplified that, thanks!

> As for where utility functions like this should live, I am also curious whether the functionality already exists somewhere. I have not been able to find it so far (which ties into the question of what the right location for such functions would be). 😢

I couldn't find anything myself. Let's just be mindful of this and make sure we don't re-implement it anywhere - that wouldn't be the first time 😂

@banach-space (Contributor) left a comment:

LGTM, thanks for addressing my comments! My remaining suggestions are nits/nice-to-haves.

Please wait for one more +1 before landing. @adam-smnk , perhaps you could take another look?

Comment on lines 63 to 64
auto firstOperand = addOp.getOperand(0);
auto secondOperand = addOp.getOperand(1);
@banach-space (Contributor):

[nit] I always advocate for long descriptive names, but in this case short rhs and lhs might be even more descriptive :)

Could you spell out auto?

@rolfmorel (Contributor, Author):

Agreed - and done.

PatternRewriter &rewriter) const override {
Value dominatingOperand = nullptr;
linalg::LinalgOp dominatedOp = nullptr;
{
@banach-space (Contributor):

Why the extra braces?

@rolfmorel (Contributor, Author):

I wanted to limit the scope of the variables lhs and rhs, as after this block all that matters are the values of dominatingOperand and dominatedOp. It is a way of holding myself/the code after this block to the contract that the order of the operands shouldn't matter.

I have added a comment on the block to explain.

@adam-smnk (Contributor) left a comment:

Looks good, thanks 👍

Commit: [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest

Replaces a linalg.add, one of whose operands is the single user of a contraction that has a zero-filled, "identity-mapped" destination and is dominated by the add's `other` operand, with that contraction, rewritten to use `other` as its dest.

Benefits include the elision of an elementwise op, namely the linalg.add, and the removal of a tensor.empty destination, which would likely require an allocation upon bufferization.
@rolfmorel (Contributor, Author):

Thank you for the comments, @adam-smnk and @banach-space - it really helped clean up the PR!

I just now rebased, squashed and ran a check-all. The PR is ready to merge, I say. If either of you could help with merging, that would be great!

@adam-smnk merged commit 94cf80d into llvm:main on Oct 3, 2024
8 checks passed