Skip to content

Commit 8abe24b

Browse files
committed
fold parent location when hoisting
1 parent 8eff570 commit 8abe24b

File tree

4 files changed

+82
-26
lines changed

4 files changed

+82
-26
lines changed

mlir/include/mlir/Transforms/FoldUtils.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,12 @@ class OperationFolder {
9696
Dialect *dialect, Attribute value,
9797
Type type, Location loc);
9898

99-
// Fuse `foldedLocation` into the Location of `retainedOp`. This will result
100-
// in `retainedOp` having a FusedLoc with `fusedLocationTag` to help trace the
101-
// source of the fusion. If `retainedOp` already had a FusedLoc with the same
102-
// tag, `foldedLocation` will simply be appended to it.
99+
// Fuse `foldedLocation` into `originalLocation`. This will result in a
100+
// FusedLoc with `fusedLocationTag` to help trace the source of the fusion.
101+
// If `originalLocation` already had a FusedLoc with the same tag,
102+
// `foldedLocation` will simply be appended to it.
103+
Location getFusedLocation(Location originalLocation, Location foldedLocation);
104+
// Update the location of `retainedOp` by applying `getFusedLocation`.
103105
void appendFoldedLocation(Operation *retainedOp, Location foldedLocation);
104106

105107
/// Tag for annotating fused locations as a result of merging constants.

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
152152
// anything. Otherwise, we move the constant to the insertion block.
153153
Block *insertBlock = &insertRegion->front();
154154
if (opBlock != insertBlock || (&insertBlock->front() != op &&
155-
!isFolderOwnedConstant(op->getPrevNode())))
155+
!isFolderOwnedConstant(op->getPrevNode()))) {
156156
op->moveBefore(&insertBlock->front());
157+
appendFoldedLocation(op, insertBlock->getParent()->getLoc());
158+
}
157159

158160
folderConstOp = op;
159161
referencedDialects[op].push_back(op->getDialect());
@@ -237,6 +239,7 @@ OperationFolder::processFoldResults(Operation *op,
237239
auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
238240
auto &entry = insertRegion->front();
239241
rewriter.setInsertionPoint(&entry, entry.begin());
242+
Location loc = getFusedLocation(op->getLoc(), insertRegion->getLoc());
240243

241244
// Get the constant map for the insertion region of this operation.
242245
auto &uniquedConstants = foldScopes[insertRegion];
@@ -259,8 +262,8 @@ OperationFolder::processFoldResults(Operation *op,
259262
// Check to see if there is a canonicalized version of this constant.
260263
auto res = op->getResult(i);
261264
Attribute attrRepl = foldResults[i].get<Attribute>();
262-
if (auto *constOp = tryGetOrCreateConstant(
263-
uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
265+
if (auto *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
266+
attrRepl, res.getType(), loc)) {
264267
// Ensure that this constant dominates the operation we are replacing it
265268
// with. This may not automatically happen if the operation being folded
266269
// was inserted before the constant within the insertion block.
@@ -364,31 +367,36 @@ static Location FlattenFusedLocationRecursively(const Location loc) {
364367
return loc;
365368
}
366369

367-
void OperationFolder::appendFoldedLocation(Operation *retainedOp,
370+
Location OperationFolder::getFusedLocation(Location originalLocation,
368371
Location foldedLocation) {
372+
// If they're already equal, no need to fuse.
373+
if (originalLocation == foldedLocation)
374+
return originalLocation;
375+
369376
// Append into existing fused location if it has the same tag.
370377
if (auto existingFusedLoc =
371-
dyn_cast<FusedLocWith<StringAttr>>(retainedOp->getLoc())) {
378+
dyn_cast<FusedLocWith<StringAttr>>(originalLocation)) {
372379
StringAttr existingMetadata = existingFusedLoc.getMetadata();
373380
if (existingMetadata == fusedLocationTag) {
374381
ArrayRef<Location> existingLocations = existingFusedLoc.getLocations();
375382
SetVector<Location> locations(existingLocations.begin(),
376383
existingLocations.end());
377384
locations.insert(foldedLocation);
378-
Location newFusedLoc = FusedLoc::get(
379-
retainedOp->getContext(), locations.takeVector(), existingMetadata);
380-
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
381-
return;
385+
Location newFusedLoc =
386+
FusedLoc::get(originalLocation->getContext(), locations.takeVector(),
387+
existingMetadata);
388+
return FlattenFusedLocationRecursively(newFusedLoc);
382389
}
383390
}
384391

385392
// Create a new fusedloc with retainedOp's loc and foldedLocation.
386-
// If they're already equal, no need to fuse.
387-
if (retainedOp->getLoc() == foldedLocation)
388-
return;
389-
390393
Location newFusedLoc =
391-
FusedLoc::get(retainedOp->getContext(),
392-
{retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
393-
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
394+
FusedLoc::get(originalLocation->getContext(),
395+
{originalLocation, foldedLocation}, fusedLocationTag);
396+
return FlattenFusedLocationRecursively(newFusedLoc);
397+
}
398+
399+
void OperationFolder::appendFoldedLocation(Operation *retainedOp,
400+
Location foldedLocation) {
401+
retainedOp->setLoc(getFusedLocation(retainedOp->getLoc(), foldedLocation));
394402
}

mlir/test/Transforms/canonicalize-debuginfo.mlir

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,31 @@ func.func @hoist_constant(%arg0: memref<8xi32>) {
3333
memref.store %0, %arg0[%arg1] : memref<8xi32>
3434
memref.store %1, %arg0[%arg1] : memref<8xi32>
3535
}
36+
// CHECK: return
3637
return
37-
}
38+
// CHECK-NEXT: } loc(#[[LocFunc:.*]])
39+
} loc("hoist_constant":2:0)
3840

3941
// CHECK-DAG: #[[LocConst0:.*]] = loc("hoist_constant":0:0)
4042
// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant":1:0)
41-
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]]])
43+
// CHECK-DAG: #[[LocFunc]] = loc("hoist_constant":2:0)
44+
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]], #[[LocConst1]]])
45+
46+
// -----
47+
48+
// CHECK-LABEL: func @hoist_constant_simple
49+
func.func @hoist_constant_simple(%arg0: memref<8xi32>) -> i32 {
50+
// CHECK-NEXT: arith.constant 88 : i32 loc(#[[FusedLoc:.*]])
51+
%0 = arith.constant 42 : i32 loc("hoist_constant_simple":0:0)
52+
%1 = arith.constant 0 : index
53+
memref.store %0, %arg0[%1] : memref<8xi32>
54+
55+
%2 = arith.constant 88 : i32 loc("hoist_constant_simple":1:0)
56+
57+
return %2 : i32
58+
// CHECK: } loc(#[[LocFunc:.*]])
59+
} loc("hoist_constant_simple":2:0)
60+
61+
// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant_simple":1:0)
62+
// CHECK-DAG: #[[LocFunc]] = loc("hoist_constant_simple":2:0)
63+
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]]])

mlir/test/Transforms/constant-fold-debuginfo.mlir

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ func.func @fold_and_merge() -> (i32, i32) {
1111
%3 = arith.constant 6 : i32 loc("fold_and_merge":1:0)
1212

1313
return %2, %3: i32, i32
14-
}
14+
// CHECK: } loc(#[[LocFunc:.*]])
15+
} loc("fold_and_merge":2:0)
1516

1617
// CHECK-DAG: #[[LocConst0:.*]] = loc("fold_and_merge":0:0)
1718
// CHECK-DAG: #[[LocConst1:.*]] = loc("fold_and_merge":1:0)
18-
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
19+
// CHECK-DAG: #[[LocFunc]] = loc("fold_and_merge":2:0)
20+
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])
1921

2022
// -----
2123

@@ -27,8 +29,30 @@ func.func @materialize_different_dialect() -> (f32, f32) {
2729
%2 = arith.constant 1.0 : f32 loc("materialize_different_dialect":1:0)
2830

2931
return %1, %2: f32, f32
30-
}
32+
// CHECK: } loc(#[[LocFunc:.*]])
33+
} loc("materialize_different_dialect":2:0)
3134

3235
// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_different_dialect":0:0)
3336
// CHECK-DAG: #[[LocConst1:.*]] = loc("materialize_different_dialect":1:0)
34-
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
37+
// CHECK-DAG: #[[LocFunc]] = loc("materialize_different_dialect":2:0)
38+
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])
39+
40+
// -----
41+
42+
// CHECK-LABEL: func @materialize_in_front
43+
func.func @materialize_in_front(%arg0: memref<8xi32>) {
44+
// CHECK-NEXT: arith.constant 6 : i32 loc(#[[FusedLoc:.*]])
45+
affine.for %arg1 = 0 to 8 {
46+
%1 = arith.constant 1 : i32
47+
%2 = arith.constant 5 : i32
48+
%3 = arith.addi %1, %2 : i32 loc("materialize_in_front":0:0)
49+
memref.store %3, %arg0[%arg1] : memref<8xi32>
50+
}
51+
// CHECK: return
52+
return
53+
// CHECK-NEXT: } loc(#[[LocFunc:.*]])
54+
} loc("materialize_in_front":1:0)
55+
56+
// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_in_front":0:0)
57+
// CHECK-DAG: #[[LocFunc]] = loc("materialize_in_front":1:0)
58+
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]]])

0 commit comments

Comments
 (0)