[MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest #110514
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Replaces a linalg.add, when one of its operands is the single user of a contraction that has a zero-filled, "identity-mapped" destination and is dominated by the add's `other` operand, by that contraction with `other` as its dest.

Benefits include elision of an elementwise op, namely the linalg.add, and removal of a tensor.empty destination, which is likely to require an allocation upon bufferization.

Patch is 23.79 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/110514.diff

8 Files Affected:
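Schematically, the rewrite has the following effect (a minimal before/after sketch adapted from the tests added by this PR; the value names and the tensor<2048x2048xf32> type are illustrative, and %cst is assumed to be an f32 zero constant):

// Before: %acc dominates the matmul, the matmul's dest is zero-filled, and
// the matmul's result has the linalg.add as its single user.
%empty = tensor.empty() : tensor<2048x2048xf32>
%zero = linalg.fill ins(%cst : f32) outs(%empty : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
%mm = linalg.matmul ins(%a, %b : tensor<2048x2048xf32>, tensor<2048x2048xf32>) outs(%zero : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
%res = linalg.add ins(%mm, %acc : tensor<2048x2048xf32>, tensor<2048x2048xf32>) outs(%empty : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>

// After: the contraction accumulates directly into %acc, so the linalg.add
// and the tensor.empty destination of the add are no longer needed.
%res = linalg.matmul ins(%a, %b : tensor<2048x2048xf32>, tensor<2048x2048xf32>) outs(%acc : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>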
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..a997502c34299c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.linalg.fold_add_into_dest",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collects patterns to replace linalg.add when destination passing suffices
+ for achieving the sum.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 48e657cca96e39..cc12ed7cfa6b54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
const ControlFusionFn &controlFn);
+/// Pattern to replace `linalg.add` when destination passing on a contraction op
+/// suffices for achieving the sum.
+void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1e4f3004dec7e7..1d2759d2a91db1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -38,6 +38,12 @@ namespace linalg {
// General utilities
//===----------------------------------------------------------------------===//
+// Returns true if `val` represents a zero-filled tensor, per its defining op.
+bool isZeroTensor(Value val);
+
+// Returns true if the operation defines a zero-filled tensor.
+bool isZeroOp(Operation *);
+
/// Check if all indexing maps are projected permutations.
bool allIndexingsAreProjectedPermutation(LinalgOp op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 46c8510f4ed514..3b7b367d3cf2d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
+void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ linalg::populateFoldAddIntoDestPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 47af392def94ac..b3cd5537aad9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ElementwiseToLinalg.cpp
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
+ FoldAddIntoDest.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
new file mode 100644
index 00000000000000..d8c4e338fddbbc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -0,0 +1,108 @@
+//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+/// Replace a linalg.add with one operand the single user of a contraction,
+/// which has a zero-filled, "identity-mapped" destination and is dominated by
+/// the `other` operand, by the contraction with `other` as its dest.
+struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
+ using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::AddOp addOp,
+ PatternRewriter &rewriter) const override {
+ Value dominatingOperand = nullptr;
+ linalg::LinalgOp dominatedOp = nullptr;
+ {
+ auto firstOperand = addOp.getOperand(0);
+ auto secondOperand = addOp.getOperand(1);
+
+ // Can only put one of addOp's operands in the dest/out arg of the other's
+ // defining op based on suitable dominance.
+ if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
+ DominanceInfo domInfo(secondOp);
+ if (domInfo.properlyDominates(firstOperand, secondOp)) {
+ dominatingOperand = firstOperand;
+ dominatedOp = secondOp;
+ }
+ }
+ if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
+ DominanceInfo domInfo(firstOp);
+ if (domInfo.properlyDominates(secondOperand, firstOp)) {
+ dominatingOperand = secondOperand;
+ dominatedOp = firstOp;
+ }
+ }
+ if (!dominatingOperand || !dominatedOp)
+ return failure();
+ // NB: As linalg.add's generalisation ignores the out argument in its
+ // region there is no need to perform checks on addOp's out argument.
+ }
+
+ // Dominated op must be a contraction for it to accumulate on its out arg.
+ // E.g., AddOp is not a contraction and hence ignores its out arg's value.
+ auto dominatedDestOp =
+ dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
+ if (dominatedOp->getNumResults() != 1 ||
+ !linalg::isaContractionOpInterface(dominatedOp) ||
+ (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
+ return rewriter.notifyMatchFailure(
+ dominatedOp, "expected dominated op to be single-result "
+ "destination-passing contraction");
+
+ // To change the contraction's result, `addOp` must be its only user.
+ if (!dominatedOp->getResult(0).hasOneUse())
+ return rewriter.notifyMatchFailure(
+ dominatedOp,
+ "expected linalg.add to be single user of contraction's result");
+
+ // As `dominatedOp` was already accumulating on its out argument, it is only
+ // safe to no longer use its current out arg when it is the additive zero.
+ auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
+ if (!linalg::isZeroTensor(destOperand->get()))
+ return rewriter.notifyMatchFailure(
+ dominatedOp, "expected dominated op's dest to be additive zero");
+ // TODO: If the other op is a contraction and has additive zero as dest, we
+ // can swap the dests and achieve the proper sum, given suitable dominance.
+
+ // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
+ // Hence, we can only substitute `dominatingOperand` for the dest of the
+ // contraction when dest's indexing_map corresponds to an identity map
+ // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
+ SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
+ int prevDimPos = -1;
+ for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
+ auto dim = dyn_cast<AffineDimExpr>(expr);
+ if (!dim || prevDimPos >= (int)dim.getPosition())
+ return rewriter.notifyMatchFailure(
+ dominatedOp, "expected index_map for contraction's dest to be an "
+ "ordered projection");
+ prevDimPos = dim.getPosition();
+ }
+
+ // Replace the additive-zero out argument of the dominated op by the
+ // dominating summand. This makes the dominated op's result the sum of both
+  // of addOp's arguments - therefore we replace addOp and its uses with it.
+ rewriter.modifyOpInPlace(
+ dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
+ rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
+ return success();
+ }
+};
+
+void linalg::populateFoldAddIntoDestPatterns(
+ RewritePatternSet &patterns) {
+ // Replace linalg.add when destination passing suffices for achieving the sum.
+ patterns.add<FoldAddIntoDest>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 38e427af1c4846..a6a9ca5fd66330 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -870,5 +870,80 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
return reassociation;
}
+// Returns true if the value is a constant zero float or integer.
+bool isValConstZero(Value val) {
+ return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
+}
+
+// Returns true if the attribute represents "all zeros".
+static bool isZeroAttr(Attribute attribute) {
+ return TypeSwitch<Attribute, bool>(attribute)
+ .Case<FloatAttr>([](auto attr) { return attr.getValueAsDouble() == 0.0; })
+ .Case<IntegerAttr>([](auto attr) { return attr.getInt() == 0; })
+ .Case<DenseElementsAttr>([](auto attr) {
+ if (!attr.getElementType().isIntOrFloat())
+ return false;
+ if (!attr.isSplat())
+ return false;
+ auto splat = attr.template getSplatValue<Attribute>();
+ return isZeroAttr(splat);
+ })
+ .Default([](auto attr) { return false; });
+}
+
+// Recurses into isZeroOp for defining ops if not immediately obvious.
+// Looks past linalg generic's arguments (which don't have defining ops).
+bool isZeroTensor(Value val) {
+ if (!val)
+ return false;
+ if (isValConstZero(val))
+ return true;
+
+ Operation *defOp = nullptr;
+
+ // Block arguments don't have a defining op, but they do have an op arg.
+ if (auto arg = dyn_cast<BlockArgument>(val)) {
+    // We need to find the operand of the linalg op at the same index as this
+    // block argument.
+ auto *linalgOp = arg.getParentRegion()->getParentOp();
+ if (!isa<linalg::GenericOp>(linalgOp))
+ return false;
+ auto index = arg.getArgNumber();
+ auto linalgArg = linalgOp->getOperand(index);
+ defOp = linalgArg.getDefiningOp();
+ } else {
+ defOp = val.getDefiningOp();
+ }
+ return isZeroOp(defOp);
+}
+
+// Recurses into isZeroTensor for operands and isZeroAttr for attributes.
+bool isZeroOp(Operation *defOp) {
+ if (!defOp)
+ return false;
+
+ return TypeSwitch<Operation *, bool>(defOp)
+ .Case<arith::ConstantOp>([&](auto op) {
+ // Dense attributes don't match APFloat.isZero().
+ Attribute attr = op.getValue();
+ return isZeroAttr(attr);
+ })
+ .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+ if (op.getInputs().size() != 1)
+ return false;
+ return isZeroTensor(op.getInputs()[0]);
+ })
+ .Case<memref::CopyOp, memref::SubViewOp, tensor::CastOp,
+ tensor::ExtractSliceOp>(
+ [&](auto op) { return isZeroTensor(op.getSource()); })
+ .Case<memref::GetGlobalOp>([&](auto op) {
+ auto name = op.getName();
+ auto module = defOp->getParentOfType<ModuleOp>();
+ auto global = module.lookupSymbol<memref::GlobalOp>(name);
+ auto attr = global.getInitialValueAttr();
+ return isZeroAttr(attr);
+ })
+ .Default([&](Operation *op) { return false; });
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
new file mode 100644
index 00000000000000..4dbe253fd3221b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -0,0 +1,288 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_add_to_fold(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_add_to_fold
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+
+// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul_transpose_b
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_orig_dest_not_additive_zero
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type
+ %5 = lin...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
Original downstream PR is here: libxsmm/tpp-mlir#966. Credit to @adam-smnk for suggesting the transform and thanks for the thorough review on the original PR! Credit for the original isZero utility functions goes to @rengolin and @chelini.
The helpers that you added in Linalg/Utils/Utils.h are not really specific to Linalg and should be moved to somewhere more general. Given the location of `isZeroIndex`, I suggest StaticValueUtils.h. Also, it feels like some of the logic that you added there should already exist 🤔
I also suggest bailing out early when encountering memrefs. Unless that's something that this is meant to support?
Thanks for the comments, @banach-space! I have simplified the utility functions quite a bit. As I dropped the memref support and restricted the zero check to just what is needed for this pattern, I have put the (now single) zero check in the pattern's file. If you could have another look at the now simpler code, that would be much appreciated. As for where utility functions like this should live, I am interested as well in whether the functionality already exists somewhere. I have not been able to find it so far (which ties into the question of what the right location for such functions would be). 😢
// -----

!type = tensor<2048x2048xf32>
func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type { |
This is very similar to @fold_add_on_two_matmuls. I think that it would be beneficial to group all similar tests together so that it's clear what edge cases have been exercised. To me, this would make sense: rename @expect_no_fold_as_orig_dest_not_additive_zero to @expect_no_fold_add_on_matmuls_orig_dest_not_additive_zero and then move it near @fold_add_on_two_matmuls. Similar comment for other tests.
This is a non-blocker, just a kind request :) It makes maintaining tests much easier - I'm learning the hard way while going through Vector dialect tests 😅 Having said all that, these tests are already very nicely formatted, thanks!
I have made the function names more specific and have now tried to group them as you suggest.
Note that previously I had them grouped by "interesting cases that ought to be working" and a group of "a test for each early exit condition". I guess it's a matter of personal preference which grouping is more helpful.
> Note that previously I had them grouped by "interesting cases that ought to be working" and a group of "a test for each early exit condition". I guess it's a matter of personal preference which grouping is more helpful.

Good point, I'm clearly biased towards my favourite approach :) Your approach also works, but then adding a big comment separating "positive" from "negative" tests would help. It wasn't obvious to me that that's how you organised them.
I really like how you simplified that, thanks!
I couldn't find anything myself. Let's just be mindful of this and make sure we don't re-implement this anywhere - that wouldn't be the first time 😂
LGTM, thanks for addressing my comments! My remaining suggestions are nits/nice-to-haves.
Please wait for one more +1 before landing. @adam-smnk, perhaps you could take another look?
auto firstOperand = addOp.getOperand(0);
auto secondOperand = addOp.getOperand(1);
[nit] I always advocate for long descriptive names, but in this case short `rhs` and `lhs` might be even more descriptive :)

Could you spell out `auto`?
Agreed - and done.
PatternRewriter &rewriter) const override {
  Value dominatingOperand = nullptr;
  linalg::LinalgOp dominatedOp = nullptr;
  {
Why the extra braces?
I wanted to limit the scope of the variables `lhs` and `rhs`, as after this block all that matters are the values of `dominatingOperand` and `dominatedOp`. It is a way to hold myself/the code after this block to the contract that the order of the operands shouldn't matter.

Have added a comment on the block to explain.
Looks good, thanks 👍
Thank you for the comments, @adam-smnk and @banach-space - it really helped clean up the PR! I just now rebased, squashed and ran a ... [truncated]