Skip to content

Commit 90c1e7b

Browse files
committed
fuse locations for hoisted and merged constants
1 parent 48ca868 commit 90c1e7b

File tree

2 files changed

+91
-5
lines changed

2 files changed

+91
-5
lines changed

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,38 @@
1919

2020
using namespace mlir;
2121

22+
// Fuse `foldedLocation` into the Location of `retainedOp`.
23+
// This will result in `retainedOp` having a FusedLoc with a StringAttr tag
24+
// "OpFold" to help trace the source of the fusion. If `retainedOp` already had
25+
// a FusedLoc with the same tag, `foldedLocation` will simply be appended to it.
26+
// Usage:
27+
// - When an op is deduplicated, fuse the location of the op to be removed into
28+
// the op that is retained.
29+
// - When an op is hoisted to the front/back of a block, fuse the location of
30+
// the parent region of the block into the hoisted op.
31+
static void appendFoldedLocation(Operation *retainedOp,
32+
Location foldedLocation) {
33+
constexpr std::string_view tag = "OpFold";
34+
// Append into existing fused location if it has the same tag.
35+
if (auto existingFusedLoc =
36+
retainedOp->getLoc().dyn_cast<FusedLocWith<StringAttr>>()) {
37+
StringAttr existingMetadata = existingFusedLoc.getMetadata();
38+
if (existingMetadata.strref().equals(tag)) {
39+
SmallVector<Location> locations(existingFusedLoc.getLocations());
40+
locations.push_back(foldedLocation);
41+
Location newFusedLoc =
42+
FusedLoc::get(retainedOp->getContext(), locations, existingMetadata);
43+
retainedOp->setLoc(newFusedLoc);
44+
return;
45+
}
46+
}
47+
// Create a new fusedloc with retainedOp's loc and foldedLocation.
48+
Location newFusedLoc = FusedLoc::get(
49+
retainedOp->getContext(), {retainedOp->getLoc(), foldedLocation},
50+
StringAttr::get(retainedOp->getContext(), tag));
51+
retainedOp->setLoc(newFusedLoc);
52+
}
53+
2254
/// Given an operation, find the parent region that folded constants should be
2355
/// inserted into.
2456
static Region *
@@ -77,8 +109,10 @@ LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
77109
// Check to see if we should rehoist, i.e. if a non-constant operation was
78110
// inserted before this one.
79111
Block *opBlock = op->getBlock();
80-
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
112+
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
81113
op->moveBefore(&opBlock->front());
114+
appendFoldedLocation(op, opBlock->getParent()->getLoc());
115+
}
82116
return failure();
83117
}
84118

@@ -112,8 +146,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
112146
// If this is a constant we unique'd, we don't need to insert, but we can
113147
// check to see if we should rehoist it.
114148
if (isFolderOwnedConstant(op)) {
115-
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
149+
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
116150
op->moveBefore(&opBlock->front());
151+
appendFoldedLocation(op, opBlock->getParent()->getLoc());
152+
}
117153
return true;
118154
}
119155

@@ -141,6 +177,7 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
141177
// If there is an existing constant, replace `op`.
142178
if (folderConstOp) {
143179
notifyRemoval(op);
180+
appendFoldedLocation(folderConstOp, op->getLoc());
144181
rewriter.replaceOp(op, folderConstOp->getResults());
145182
return false;
146183
}
@@ -151,8 +188,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
151188
// anything. Otherwise, we move the constant to the insertion block.
152189
Block *insertBlock = &insertRegion->front();
153190
if (opBlock != insertBlock || (&insertBlock->front() != op &&
154-
!isFolderOwnedConstant(op->getPrevNode())))
191+
!isFolderOwnedConstant(op->getPrevNode()))) {
155192
op->moveBefore(&insertBlock->front());
193+
appendFoldedLocation(op, insertBlock->getParent()->getLoc());
194+
}
156195

157196
folderConstOp = op;
158197
referencedDialects[op].push_back(op->getDialect());
@@ -264,8 +303,10 @@ OperationFolder::processFoldResults(Operation *op,
264303
// with. This may not automatically happen if the operation being folded
265304
// was inserted before the constant within the insertion block.
266305
Block *opBlock = op->getBlock();
267-
if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
306+
if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) {
268307
constOp->moveBefore(&opBlock->front());
308+
appendFoldedLocation(constOp, opBlock->getParent()->getLoc());
309+
}
269310

270311
results.push_back(constOp->getResult(0));
271312
continue;
@@ -294,8 +335,10 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
294335
// Check if an existing mapping already exists.
295336
auto constKey = std::make_tuple(dialect, value, type);
296337
Operation *&constOp = uniquedConstants[constKey];
297-
if (constOp)
338+
if (constOp) {
339+
appendFoldedLocation(constOp, loc);
298340
return constOp;
341+
}
299342

300343
// If one doesn't exist, try to materialize one.
301344
if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
@@ -316,6 +359,7 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
316359
// materialized operation in favor of the existing one.
317360
if (auto *existingOp = uniquedConstants.lookup(newKey)) {
318361
notifyRemoval(constOp);
362+
appendFoldedLocation(existingOp, constOp->getLoc());
319363
rewriter.eraseOp(constOp);
320364
referencedDialects[existingOp].push_back(dialect);
321365
return constOp = existingOp;
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
2+
3+
// CHECK-LABEL: func @merge_constants
4+
func.func @merge_constants() -> (index, index, index) {
5+
// CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
6+
%0 = arith.constant 42 : index loc("merge_constants":0:0)
7+
%1 = arith.constant 42 : index loc("merge_constants":1:0)
8+
%2 = arith.constant 42 : index loc("merge_constants":2:0)
9+
return %0, %1, %2: index, index, index
10+
}
11+
12+
// CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
13+
// CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
14+
// CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
15+
16+
// CHECK: #[[FusedLoc]] = loc(fused<"OpFold">[
17+
// CHECK-SAME: #[[LocConst0]]
18+
// CHECK-SAME: #[[LocConst1]]
19+
// CHECK-SAME: #[[LocConst2]]
20+
21+
// -----
22+
23+
// CHECK-LABEL: func @hoist_constant
24+
func.func @hoist_constant(%arg0: memref<8xi32>) {
25+
// CHECK-NEXT: arith.constant 42 : i32 loc(#[[FusedWithFunction:.*]])
26+
affine.for %arg1 = 0 to 8 {
27+
%0 = arith.constant 42 : i32 loc("hoist_constant":0:0)
28+
%1 = arith.constant 42 : i32 loc("hoist_constant":1:0)
29+
memref.store %0, %arg0[%arg1] : memref<8xi32>
30+
memref.store %1, %arg0[%arg1] : memref<8xi32>
31+
}
32+
return
33+
} loc("hoist_constant":2:0)
34+
35+
// CHECK-DAG: #[[LocConst0:.*]] = loc("hoist_constant":0:0)
36+
// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant":1:0)
37+
// CHECK-DAG: #[[LocFunc:.*]] = loc("hoist_constant":2:0)
38+
39+
// CHECK: #[[FusedWithFunction]] = loc(fused<"OpFold">[
40+
// CHECK-SAME: #[[LocConst0]]
41+
// CHECK-SAME: #[[LocFunc]]
42+
// CHECK-SAME: #[[LocConst1]]

0 commit comments

Comments
 (0)