Skip to content

Commit 65a6874

Browse files
committed
[MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest
Replaces a linalg.add when one of its operands is produced by a contraction op whose only user is that add, whose destination is zero-filled and "identity-mapped", and which is dominated by the add's `other` operand: the contraction is rewritten to use `other` as its dest, and it takes the place of the add. Benefits include elision of an elementwise op, namely the linalg.add, and removal of a tensor.empty used as a destination, which is likely to require an allocation upon bufferization.
1 parent 8f50dbd commit 65a6874

File tree

8 files changed

+497
-0
lines changed

8 files changed

+497
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
7373
let assemblyFormat = "attr-dict";
7474
}
7575

76+
// Transform-dialect pattern-descriptor op with the mnemonic
// `apply_patterns.linalg.fold_add_into_dest`. It takes no operands or custom
// syntax beyond an attribute dictionary; the patterns it contributes are
// supplied through the PatternDescriptorOpInterface methods declared below.
def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
77+
"apply_patterns.linalg.fold_add_into_dest",
78+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
79+
let description = [{
80+
Collects patterns to replace linalg.add when destination passing suffices
81+
for achieving the sum.
82+
}];
83+
84+
let assemblyFormat = "attr-dict";
85+
}
86+
7687
//===----------------------------------------------------------------------===//
7788
// BufferizeToAllocationOp
7889
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
17471747
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
17481748
const ControlFusionFn &controlFn);
17491749

1750+
/// Pattern to replace `linalg.add` when destination passing on a contraction op
1751+
/// suffices for achieving the sum.
1752+
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
1753+
17501754
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
17511755
/// if the producer is a `linalg` operation with all parallel iterator types.
17521756
void populateFuseTensorPadWithProducerLinalgOpPatterns(

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ namespace linalg {
3838
// General utilities
3939
//===----------------------------------------------------------------------===//
4040

41+
// Returns true if `val` represents a zero-filled tensor, per its defining op.
42+
bool isZeroTensor(Value val);
43+
44+
// Returns true if the operation defines a zero-filled tensor.
45+
bool isZeroOp(Operation *);
46+
4147
/// Check if all indexing maps are projected permutations.
4248
bool allIndexingsAreProjectedPermutation(LinalgOp op);
4349

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
248248
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
249249
}
250250

251+
/// Pattern-descriptor hook: forwards to the Linalg helper that registers the
/// "fold linalg.add into a contraction's dest" patterns in `patterns`.
void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateFoldAddIntoDestPatterns(patterns);
}
255+
251256
//===----------------------------------------------------------------------===//
252257
// BufferizeToAllocationOp
253258
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
1313
ElementwiseToLinalg.cpp
1414
EliminateEmptyTensors.cpp
1515
EraseUnusedOperandsAndResults.cpp
16+
FoldAddIntoDest.cpp
1617
FusePadOpWithLinalgProducer.cpp
1718
Fusion.cpp
1819
Generalization.cpp
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
11+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12+
#include "mlir/IR/Dominance.h"
13+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
14+
15+
using namespace mlir;
16+
17+
/// Replace a linalg.add with one operand the single user of a contraction,
18+
/// which has a zero-filled, "identity-mapped" destination and is dominated by
19+
/// the `other` operand, by the contraction with `other` as its dest.
20+
struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
21+
using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
22+
23+
LogicalResult matchAndRewrite(linalg::AddOp addOp,
24+
PatternRewriter &rewriter) const override {
25+
Value dominatingOperand = nullptr;
26+
linalg::LinalgOp dominatedOp = nullptr;
27+
{
28+
auto firstOperand = addOp.getOperand(0);
29+
auto secondOperand = addOp.getOperand(1);
30+
31+
// Can only put one of addOp's operands in the dest/out arg of the other's
32+
// defining op based on suitable dominance.
33+
if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
34+
DominanceInfo domInfo(secondOp);
35+
if (domInfo.properlyDominates(firstOperand, secondOp)) {
36+
dominatingOperand = firstOperand;
37+
dominatedOp = secondOp;
38+
}
39+
}
40+
if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
41+
DominanceInfo domInfo(firstOp);
42+
if (domInfo.properlyDominates(secondOperand, firstOp)) {
43+
dominatingOperand = secondOperand;
44+
dominatedOp = firstOp;
45+
}
46+
}
47+
if (!dominatingOperand || !dominatedOp)
48+
return failure();
49+
// NB: As linalg.add's generalisation ignores the out argument in its
50+
// region there is no need to perform checks on addOp's out argument.
51+
}
52+
53+
// Dominated op must be a contraction for it to accumulate on its out arg.
54+
// E.g., AddOp is not a contraction and hence ignores its out arg's value.
55+
auto dominatedDestOp =
56+
dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
57+
if (dominatedOp->getNumResults() != 1 ||
58+
!linalg::isaContractionOpInterface(dominatedOp) ||
59+
(!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
60+
return rewriter.notifyMatchFailure(
61+
dominatedOp, "expected dominated op to be single-result "
62+
"destination-passing contraction");
63+
64+
// To change the contraction's result, `addOp` must be its only user.
65+
if (!dominatedOp->getResult(0).hasOneUse())
66+
return rewriter.notifyMatchFailure(
67+
dominatedOp,
68+
"expected linalg.add to be single user of contraction's result");
69+
70+
// As `dominatedOp` was already accumulating on its out argument, it is only
71+
// safe to no longer use its current out arg when it is the additive zero.
72+
auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
73+
if (!linalg::isZeroTensor(destOperand->get()))
74+
return rewriter.notifyMatchFailure(
75+
dominatedOp, "expected dominated op's dest to be additive zero");
76+
// TODO: If the other op is a contraction and has additive zero as dest, we
77+
// can swap the dests and achieve the proper sum, given suitable dominance.
78+
79+
// As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
80+
// Hence, we can only substitute `dominatingOperand` for the dest of the
81+
// contraction when dest's indexing_map corresponds to an identity map
82+
// w.r.t. just the dimensions of dest, i.e. is an ordered projection.
83+
SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
84+
int prevDimPos = -1;
85+
for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
86+
auto dim = dyn_cast<AffineDimExpr>(expr);
87+
if (!dim || prevDimPos >= (int)dim.getPosition())
88+
return rewriter.notifyMatchFailure(
89+
dominatedOp, "expected index_map for contraction's dest to be an "
90+
"ordered projection");
91+
prevDimPos = dim.getPosition();
92+
}
93+
94+
// Replace the additive-zero out argument of the dominated op by the
95+
// dominating summand. This makes the dominated op's result the sum of both
96+
// of addOp's arguments - therefore we replace addOp and it uses by it.
97+
rewriter.modifyOpInPlace(
98+
dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
99+
rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
100+
return success();
101+
}
102+
};
103+
104+
/// Adds the pattern that replaces a linalg.add whenever destination passing
/// on a contraction suffices for achieving the sum.
void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldAddIntoDest>(context);
}

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,5 +870,80 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
870870
return reassociation;
871871
}
872872

873+
// Returns true if the value is a constant float or integer.
874+
bool isValConstZero(Value val) {
875+
return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
876+
}
877+
878+
// Returns true if the attribute represent "all zeros".
879+
static bool isZeroAttr(Attribute attribute) {
880+
return TypeSwitch<Attribute, bool>(attribute)
881+
.Case<FloatAttr>([](auto attr) { return attr.getValueAsDouble() == 0.0; })
882+
.Case<IntegerAttr>([](auto attr) { return attr.getInt() == 0; })
883+
.Case<DenseElementsAttr>([](auto attr) {
884+
if (!attr.getElementType().isIntOrFloat())
885+
return false;
886+
if (!attr.isSplat())
887+
return false;
888+
auto splat = attr.template getSplatValue<Attribute>();
889+
return isZeroAttr(splat);
890+
})
891+
.Default([](auto attr) { return false; });
892+
}
893+
894+
// Recurses into isZeroOp for defining ops if not immediately obvious.
895+
// Looks past linalg generic's argument (which don't have defining ops).
896+
bool isZeroTensor(Value val) {
897+
if (!val)
898+
return false;
899+
if (isValConstZero(val))
900+
return true;
901+
902+
Operation *defOp = nullptr;
903+
904+
// Block arguments don't have a defining op, but they do have an op arg.
905+
if (auto arg = dyn_cast<BlockArgument>(val)) {
906+
// We need to find the argument to the linalg on the same order as this one.
907+
auto *linalgOp = arg.getParentRegion()->getParentOp();
908+
if (!isa<linalg::GenericOp>(linalgOp))
909+
return false;
910+
auto index = arg.getArgNumber();
911+
auto linalgArg = linalgOp->getOperand(index);
912+
defOp = linalgArg.getDefiningOp();
913+
} else {
914+
defOp = val.getDefiningOp();
915+
}
916+
return isZeroOp(defOp);
917+
}
918+
919+
// Recurses into isZeroTensor for operands and isZeroAttr for attributes.
920+
bool isZeroOp(Operation *defOp) {
921+
if (!defOp)
922+
return false;
923+
924+
return TypeSwitch<Operation *, bool>(defOp)
925+
.Case<arith::ConstantOp>([&](auto op) {
926+
// Dense attributes don't match APFloat.isZero().
927+
Attribute attr = op.getValue();
928+
return isZeroAttr(attr);
929+
})
930+
.Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
931+
if (op.getInputs().size() != 1)
932+
return false;
933+
return isZeroTensor(op.getInputs()[0]);
934+
})
935+
.Case<memref::CopyOp, memref::SubViewOp, tensor::CastOp,
936+
tensor::ExtractSliceOp>(
937+
[&](auto op) { return isZeroTensor(op.getSource()); })
938+
.Case<memref::GetGlobalOp>([&](auto op) {
939+
auto name = op.getName();
940+
auto module = defOp->getParentOfType<ModuleOp>();
941+
auto global = module.lookupSymbol<memref::GlobalOp>(name);
942+
auto attr = global.getInitialValueAttr();
943+
return isZeroAttr(attr);
944+
})
945+
.Default([&](Operation *op) { return false; });
946+
}
947+
873948
} // namespace linalg
874949
} // namespace mlir

0 commit comments

Comments
 (0)