Skip to content

Commit 94cf80d

Browse files
authored
[MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (llvm#110514)
Replaces a linalg.add with one operand the single user of a contraction, which has a zero-filled, "identity-mapped" destination and is dominated by the `other` operand, by the contraction with `other` as its dest. Benefits include elision of an elementwise op, namely the linalg.add, and removing a tensor.empty as a destination which is likely to require an allocation upon bufferization.
1 parent 870bdc6 commit 94cf80d

File tree

6 files changed

+500
-0
lines changed

6 files changed

+500
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
7373
let assemblyFormat = "attr-dict";
7474
}
7575

76+
def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
77+
"apply_patterns.linalg.fold_add_into_dest",
78+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
79+
let description = [{
80+
Collects patterns to replace linalg.add when destination passing suffices
81+
for achieving the sum.
82+
}];
83+
84+
let assemblyFormat = "attr-dict";
85+
}
86+
7687
//===----------------------------------------------------------------------===//
7788
// BufferizeToAllocationOp
7889
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns(
17471747
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
17481748
const ControlFusionFn &controlFn);
17491749

1750+
/// Pattern to replace `linalg.add` when destination passing on a contraction op
1751+
/// suffices for achieving the sum.
1752+
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
1753+
17501754
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
17511755
/// if the producer is a `linalg` operation with all parallel iterator types.
17521756
void populateFuseTensorPadWithProducerLinalgOpPatterns(

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
248248
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
249249
}
250250

251+
void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
252+
RewritePatternSet &patterns) {
253+
linalg::populateFoldAddIntoDestPatterns(patterns);
254+
}
255+
251256
//===----------------------------------------------------------------------===//
252257
// BufferizeToAllocationOp
253258
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
1313
ElementwiseToLinalg.cpp
1414
EliminateEmptyTensors.cpp
1515
EraseUnusedOperandsAndResults.cpp
16+
FoldAddIntoDest.cpp
1617
FusePadOpWithLinalgProducer.cpp
1718
Fusion.cpp
1819
Generalization.cpp
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
11+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12+
#include "mlir/IR/Dominance.h"
13+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
14+
15+
using namespace mlir;
16+
17+
// Determine whether the value is defined to be zero.
18+
static bool isDefinedAsZero(Value val) {
19+
if (!val)
20+
return false;
21+
22+
// Check whether val is a constant scalar / vector splat / tensor splat float
23+
// or integer zero.
24+
if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
25+
return true;
26+
27+
return TypeSwitch<Operation *, bool>(val.getDefiningOp())
28+
.Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
29+
return op && op.getInputs().size() == 1 &&
30+
isDefinedAsZero(op.getInputs()[0]);
31+
})
32+
.Default([&](auto) { return false; });
33+
}
34+
35+
/// Replace a linalg.add with one operand the single user of a contraction,
36+
/// which has a zero-filled, "identity-mapped" destination and is dominated by
37+
/// the `other` operand, by the contraction with `other` as its dest.
38+
///
39+
/// As an example, the following pseudo-code will be rewritten
40+
/// %cst = arith.constant 0.000000e+00
41+
/// %empty = tensor.empty()
42+
/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
43+
/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
44+
/// %empty2 = tensor.empty()
45+
/// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
46+
/// %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
47+
/// %out = linalg.add ins(%C, %F) outs(%empty)
48+
/// to:
49+
/// %cst = arith.constant 0.000000e+00
50+
/// %empty = tensor.empty()
51+
/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
52+
/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
53+
/// %out = linalg.matmul ins(%D, %E) outs(%C)
54+
///
55+
struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
56+
using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
57+
58+
LogicalResult matchAndRewrite(linalg::AddOp addOp,
59+
PatternRewriter &rewriter) const override {
60+
// For now, pattern only applies on tensor types (memref support is TODO).
61+
if (!addOp.hasPureTensorSemantics())
62+
return failure();
63+
64+
Value dominatingOperand = nullptr;
65+
linalg::LinalgOp dominatedOp = nullptr;
66+
{ // We will forget about which operand was left or right after this block.
67+
Value lhs = addOp.getInputs()[0];
68+
Value rhs = addOp.getInputs()[1];
69+
70+
// Can only put one of addOp's operands in the dest/out arg of the other's
71+
// defining op based on suitable dominance.
72+
// TODO: Can be generalized to move ops around as long as that still
73+
// respects use-def chains and doesn't affect side-effects.
74+
if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
75+
DominanceInfo domInfo(rhsOp);
76+
if (domInfo.properlyDominates(lhs, rhsOp)) {
77+
dominatingOperand = lhs;
78+
dominatedOp = rhsOp;
79+
}
80+
}
81+
if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
82+
DominanceInfo domInfo(lhsOp);
83+
if (domInfo.properlyDominates(rhs, lhsOp)) {
84+
dominatingOperand = rhs;
85+
dominatedOp = lhsOp;
86+
}
87+
}
88+
if (!dominatingOperand || !dominatedOp)
89+
return failure();
90+
// NB: As linalg.add's generalisation ignores the out argument in its
91+
// region there is no need to perform checks on addOp's out argument.
92+
}
93+
94+
// When dominated op is a contraction we know it accumulates on its out arg.
95+
// E.g., AddOp is not a contraction and hence ignores its out arg's value.
96+
// TODO: Generalize check to also pass in case of other LinalgOps that
97+
// accumulate on their out arg but are not (binary) contraction ops.
98+
auto dominatedDestOp =
99+
dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
100+
if (dominatedOp->getNumResults() != 1 ||
101+
!linalg::isaContractionOpInterface(dominatedOp) ||
102+
(!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
103+
return rewriter.notifyMatchFailure(
104+
dominatedOp, "expected dominated op to be single-result "
105+
"destination-passing contraction");
106+
107+
// To change the contraction's result, `addOp` must be its only user.
108+
if (!dominatedOp->getResult(0).hasOneUse())
109+
return rewriter.notifyMatchFailure(
110+
dominatedOp,
111+
"expected linalg.add to be single user of contraction's result");
112+
113+
// As `dominatedOp` was already accumulating on its out argument, it is only
114+
// safe to no longer use its current out arg when it is the additive ident.
115+
auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
116+
if (!isDefinedAsZero(destOperand->get()))
117+
return rewriter.notifyMatchFailure(
118+
dominatedOp, "expected dominated op's dest to be additive zero");
119+
// TODO: If the other op is a contraction and has additive ident as dest, we
120+
// can swap the dests and achieve the proper sum, given suitable dominance.
121+
122+
// As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
123+
// Hence, we can only substitute `dominatingOperand` for the dest of the
124+
// contraction when dest's indexing_map corresponds to an identity map
125+
// w.r.t. just the dimensions of dest, i.e. is an ordered projection.
126+
SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
127+
int prevDimPos = -1;
128+
for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
129+
auto dim = dyn_cast<AffineDimExpr>(expr);
130+
if (!dim || prevDimPos > static_cast<int>(dim.getPosition()))
131+
return rewriter.notifyMatchFailure(
132+
dominatedOp, "expected index_map for contraction's dest to be an "
133+
"ordered projection");
134+
prevDimPos = dim.getPosition();
135+
}
136+
137+
// Replace the additive-ident, i.e. zero, out arg of the dominated op by the
138+
// dominating summand. This makes the dominated op's result the sum of both
139+
// of addOp's arguments - therefore we replace addOp and it uses by it.
140+
rewriter.modifyOpInPlace(
141+
dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
142+
rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
143+
return success();
144+
}
145+
};
146+
147+
void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
148+
// Replace linalg.add when destination passing suffices for achieving the sum.
149+
patterns.add<FoldAddIntoDest>(patterns.getContext());
150+
}

0 commit comments

Comments
 (0)