Skip to content

Commit 34a6598

Browse files
authored
[MLIR] Erase location of folded constants (#75415)
Follow up to the discussion from #75258, and serves as an alternate solution for #74670. Set the location to Unknown for deduplicated / moved / materialized constants by OperationFolder. This makes sure that the folded constants don't end up with an arbitrary location of one of the original ops that became it, and that hoisted ops don't confuse the stepping order.
1 parent a4e1541 commit 34a6598

File tree

7 files changed

+120
-18
lines changed

7 files changed

+120
-18
lines changed

mlir/include/mlir/Transforms/FoldUtils.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ class Value;
3333
class OperationFolder {
3434
public:
3535
OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
36-
: interfaces(ctx), rewriter(ctx, listener) {}
36+
: erasedFoldedLocation(UnknownLoc::get(ctx)), interfaces(ctx),
37+
rewriter(ctx, listener) {}
3738

3839
/// Tries to perform folding on the given `op`, including unifying
3940
/// deduplicated constants. If successful, replaces `op`'s uses with
@@ -65,7 +66,7 @@ class OperationFolder {
6566
/// be created in a parent block. On success this returns the constant
6667
/// operation, nullptr otherwise.
6768
Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value,
68-
Type type, Location loc);
69+
Type type);
6970

7071
private:
7172
/// This map keeps track of uniqued constants by dialect, attribute, and type.
@@ -95,6 +96,9 @@ class OperationFolder {
9596
Dialect *dialect, Attribute value,
9697
Type type, Location loc);
9798

99+
/// The location to overwrite with for folder-owned constants.
100+
UnknownLoc erasedFoldedLocation;
101+
98102
/// A mapping between an insertion region and the constants that have been
99103
/// created within it.
100104
DenseMap<Region *, ConstantMap> foldScopes;

mlir/lib/Transforms/SCCP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver,
5353
Dialect *dialect = latticeValue.getConstantDialect();
5454
Value constant = folder.getOrCreateConstant(
5555
builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(),
56-
value.getType(), value.getLoc());
56+
value.getType());
5757
if (!constant)
5858
return failure();
5959

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
7777
// Check to see if we should rehoist, i.e. if a non-constant operation was
7878
// inserted before this one.
7979
Block *opBlock = op->getBlock();
80-
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
80+
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
8181
op->moveBefore(&opBlock->front());
82+
op->setLoc(erasedFoldedLocation);
83+
}
8284
return failure();
8385
}
8486

@@ -112,8 +114,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
112114
// If this is a constant we unique'd, we don't need to insert, but we can
113115
// check to see if we should rehoist it.
114116
if (isFolderOwnedConstant(op)) {
115-
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
117+
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
116118
op->moveBefore(&opBlock->front());
119+
op->setLoc(erasedFoldedLocation);
120+
}
117121
return true;
118122
}
119123

@@ -142,6 +146,7 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
142146
if (folderConstOp) {
143147
notifyRemoval(op);
144148
rewriter.replaceOp(op, folderConstOp->getResults());
149+
folderConstOp->setLoc(erasedFoldedLocation);
145150
return false;
146151
}
147152

@@ -151,8 +156,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
151156
// anything. Otherwise, we move the constant to the insertion block.
152157
Block *insertBlock = &insertRegion->front();
153158
if (opBlock != insertBlock || (&insertBlock->front() != op &&
154-
!isFolderOwnedConstant(op->getPrevNode())))
159+
!isFolderOwnedConstant(op->getPrevNode()))) {
155160
op->moveBefore(&insertBlock->front());
161+
op->setLoc(erasedFoldedLocation);
162+
}
156163

157164
folderConstOp = op;
158165
referencedDialects[op].push_back(op->getDialect());
@@ -193,17 +200,17 @@ void OperationFolder::clear() {
193200
/// Get or create a constant using the given builder. On success this returns
194201
/// the constant operation, nullptr otherwise.
195202
Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
196-
Attribute value, Type type,
197-
Location loc) {
203+
Attribute value, Type type) {
198204
// Find an insertion point for the constant.
199205
auto *insertRegion = getInsertionRegion(interfaces, block);
200206
auto &entry = insertRegion->front();
201207
rewriter.setInsertionPoint(&entry, entry.begin());
202208

203209
// Get the constant map for the insertion region of this operation.
210+
// Use erased location since the op is being built at the front of block.
204211
auto &uniquedConstants = foldScopes[insertRegion];
205-
Operation *constOp =
206-
tryGetOrCreateConstant(uniquedConstants, dialect, value, type, loc);
212+
Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value,
213+
type, erasedFoldedLocation);
207214
return constOp ? constOp->getResult(0) : Value();
208215
}
209216

@@ -254,8 +261,9 @@ OperationFolder::processFoldResults(Operation *op,
254261
// Check to see if there is a canonicalized version of this constant.
255262
auto res = op->getResult(i);
256263
Attribute attrRepl = foldResults[i].get<Attribute>();
257-
if (auto *constOp = tryGetOrCreateConstant(
258-
uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
264+
if (auto *constOp =
265+
tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
266+
res.getType(), erasedFoldedLocation)) {
259267
// Ensure that this constant dominates the operation we are replacing it
260268
// with. This may not automatically happen if the operation being folded
261269
// was inserted before the constant within the insertion block.
@@ -290,8 +298,11 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
290298
// Check if an existing mapping already exists.
291299
auto constKey = std::make_tuple(dialect, value, type);
292300
Operation *&constOp = uniquedConstants[constKey];
293-
if (constOp)
301+
if (constOp) {
302+
if (loc != constOp->getLoc())
303+
constOp->setLoc(erasedFoldedLocation);
294304
return constOp;
305+
}
295306

296307
// If one doesn't exist, try to materialize one.
297308
if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
@@ -314,6 +325,8 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
314325
notifyRemoval(constOp);
315326
rewriter.eraseOp(constOp);
316327
referencedDialects[existingOp].push_back(dialect);
328+
if (loc != existingOp->getLoc())
329+
existingOp->setLoc(erasedFoldedLocation);
317330
return constOp = existingOp;
318331
}
319332

mlir/test/Dialect/Transform/test-pattern-application.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ module {
179179
// CHECK: return %[[c5]]
180180
func.func @canonicalization(%t: tensor<5xf32>) -> index {
181181
%c0 = arith.constant 0 : index
182-
// expected-remark @below {{op was replaced}}
183182
%dim = tensor.dim %t, %c0 : tensor<5xf32>
184183
return %dim : index
185184
}
@@ -191,7 +190,6 @@ transform.sequence failures(propagate) {
191190
transform.apply_patterns to %1 {
192191
transform.apply_patterns.canonicalization
193192
} : !transform.any_op
194-
transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
195193
}
196194

197195
// -----
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
2+
3+
// CHECK-LABEL: func @merge_constants
4+
func.func @merge_constants() -> (index, index, index, index) {
5+
// CHECK-NEXT: arith.constant 42 : index loc(#[[UnknownLoc:.*]])
6+
%0 = arith.constant 42 : index loc("merge_constants":0:0)
7+
%1 = arith.constant 42 : index loc("merge_constants":1:0)
8+
%2 = arith.constant 42 : index loc("merge_constants":2:0)
9+
%3 = arith.constant 42 : index loc("merge_constants":2:0)
10+
return %0, %1, %2, %3 : index, index, index, index
11+
}
12+
// CHECK: #[[UnknownLoc]] = loc(unknown)
13+
14+
// -----
15+
16+
// CHECK-LABEL: func @simple_hoist
17+
func.func @simple_hoist(%arg0: memref<8xi32>) -> i32 {
18+
// CHECK: arith.constant 88 : i32 loc(#[[UnknownLoc:.*]])
19+
// CHECK: arith.constant 42 : i32 loc(#[[ConstLoc0:.*]])
20+
// CHECK: arith.constant 0 : index loc(#[[ConstLoc1:.*]])
21+
%0 = arith.constant 42 : i32 loc("simple_hoist":0:0)
22+
%1 = arith.constant 0 : index loc("simple_hoist":1:0)
23+
memref.store %0, %arg0[%1] : memref<8xi32>
24+
25+
%2 = arith.constant 88 : i32 loc("simple_hoist":2:0)
26+
27+
return %2 : i32
28+
}
29+
// CHECK-DAG: #[[UnknownLoc]] = loc(unknown)
30+
// CHECK-DAG: #[[ConstLoc0]] = loc("simple_hoist":0:0)
31+
// CHECK-DAG: #[[ConstLoc1]] = loc("simple_hoist":1:0)
32+
33+
// -----
34+
35+
// CHECK-LABEL: func @hoist_and_merge
36+
func.func @hoist_and_merge(%arg0: memref<8xi32>) {
37+
// CHECK-NEXT: arith.constant 42 : i32 loc(#[[UnknownLoc:.*]])
38+
affine.for %arg1 = 0 to 8 {
39+
%0 = arith.constant 42 : i32 loc("hoist_and_merge":0:0)
40+
%1 = arith.constant 42 : i32 loc("hoist_and_merge":1:0)
41+
memref.store %0, %arg0[%arg1] : memref<8xi32>
42+
memref.store %1, %arg0[%arg1] : memref<8xi32>
43+
}
44+
return
45+
} loc("hoist_and_merge":2:0)
46+
// CHECK: #[[UnknownLoc]] = loc(unknown)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt %s -split-input-file -test-constant-fold -mlir-print-debuginfo | FileCheck %s
2+
3+
// CHECK-LABEL: func @fold_and_merge
4+
func.func @fold_and_merge() -> (i32, i32) {
5+
// CHECK-NEXT: [[C:%.+]] = arith.constant 6 : i32 loc(#[[UnknownLoc:.*]])
6+
%0 = arith.constant 1 : i32 loc("fold_and_merge":0:0)
7+
%1 = arith.constant 5 : i32 loc("fold_and_merge":1:0)
8+
%2 = arith.addi %0, %1 : i32 loc("fold_and_merge":2:0)
9+
10+
%3 = arith.constant 6 : i32 loc("fold_and_merge":3:0)
11+
12+
return %2, %3: i32, i32
13+
}
14+
// CHECK: #[[UnknownLoc]] = loc(unknown)
15+
16+
// -----
17+
18+
// CHECK-LABEL: func @materialize_different_dialect
19+
func.func @materialize_different_dialect() -> (f32, f32) {
20+
// CHECK: arith.constant 1.{{0*}}e+00 : f32 loc(#[[UnknownLoc:.*]])
21+
%0 = arith.constant -1.0 : f32 loc("materialize_different_dialect":0:0)
22+
%1 = math.absf %0 : f32 loc("materialize_different_dialect":1:0)
23+
%2 = arith.constant 1.0 : f32 loc("materialize_different_dialect":2:0)
24+
25+
return %1, %2: f32, f32
26+
}
27+
// CHECK: #[[UnknownLoc]] = loc(unknown)
28+
29+
// -----
30+
31+
// CHECK-LABEL: func @materialize_in_front
32+
func.func @materialize_in_front(%arg0: memref<8xi32>) {
33+
// CHECK-NEXT: arith.constant 6 : i32 loc(#[[UnknownLoc:.*]])
34+
affine.for %arg1 = 0 to 8 {
35+
%1 = arith.constant 1 : i32 loc("materialize_in_front":0:0)
36+
%2 = arith.constant 5 : i32 loc("materialize_in_front":1:0)
37+
%3 = arith.addi %1, %2 : i32 loc("materialize_in_front":2:0)
38+
memref.store %3, %arg0[%arg1] : memref<8xi32>
39+
}
40+
return
41+
} loc("materialize_in_front":3:0)
42+
// CHECK: #[[UnknownLoc]] = loc(unknown)

mlir/test/lib/Transforms/TestIntRangeInference.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
4040
maybeDefiningOp ? maybeDefiningOp->getDialect()
4141
: value.getParentRegion()->getParentOp()->getDialect();
4242
Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
43-
Value constant =
44-
folder.getOrCreateConstant(b.getInsertionBlock(), valueDialect, constAttr,
45-
value.getType(), value.getLoc());
43+
Value constant = folder.getOrCreateConstant(
44+
b.getInsertionBlock(), valueDialect, constAttr, value.getType());
4645
if (!constant)
4746
return failure();
4847

0 commit comments

Comments
 (0)