Skip to content

[MLIR] Fuse parent region location when hoisting constants #75258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mlir/include/mlir/Transforms/FoldUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ class OperationFolder {
Dialect *dialect, Attribute value,
Type type, Location loc);

// Fuse `foldedLocation` into the Location of `retainedOp`. This will result
// in `retainedOp` having a FusedLoc with `fusedLocationTag` to help trace the
// source of the fusion. If `retainedOp` already had a FusedLoc with the same
// tag, `foldedLocation` will simply be appended to it.
// Fuse `foldedLocation` into `originalLocation`. This will result in a
// FusedLoc with `fusedLocationTag` to help trace the source of the fusion.
// If `originalLocation` already had a FusedLoc with the same tag,
// `foldedLocation` will simply be appended to it.
Location getFusedLocation(Location originalLocation, Location foldedLocation);
// Update the location of `retainedOp` by applying `getFusedLocation`.
void appendFoldedLocation(Operation *retainedOp, Location foldedLocation);

/// Tag for annotating fused locations as a result of merging constants.
Expand Down
40 changes: 24 additions & 16 deletions mlir/lib/Transforms/Utils/FoldUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
// anything. Otherwise, we move the constant to the insertion block.
Block *insertBlock = &insertRegion->front();
if (opBlock != insertBlock || (&insertBlock->front() != op &&
!isFolderOwnedConstant(op->getPrevNode())))
!isFolderOwnedConstant(op->getPrevNode()))) {
op->moveBefore(&insertBlock->front());
appendFoldedLocation(op, insertBlock->getParent()->getLoc());
}

folderConstOp = op;
referencedDialects[op].push_back(op->getDialect());
Expand Down Expand Up @@ -237,6 +239,7 @@ OperationFolder::processFoldResults(Operation *op,
auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
auto &entry = insertRegion->front();
rewriter.setInsertionPoint(&entry, entry.begin());
Location loc = getFusedLocation(op->getLoc(), insertRegion->getLoc());

// Get the constant map for the insertion region of this operation.
auto &uniquedConstants = foldScopes[insertRegion];
Expand All @@ -259,8 +262,8 @@ OperationFolder::processFoldResults(Operation *op,
// Check to see if there is a canonicalized version of this constant.
auto res = op->getResult(i);
Attribute attrRepl = foldResults[i].get<Attribute>();
if (auto *constOp = tryGetOrCreateConstant(
uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
if (auto *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
attrRepl, res.getType(), loc)) {
// Ensure that this constant dominates the operation we are replacing it
// with. This may not automatically happen if the operation being folded
// was inserted before the constant within the insertion block.
Expand Down Expand Up @@ -364,31 +367,36 @@ static Location FlattenFusedLocationRecursively(const Location loc) {
return loc;
}

void OperationFolder::appendFoldedLocation(Operation *retainedOp,
Location OperationFolder::getFusedLocation(Location originalLocation,
Location foldedLocation) {
// If they're already equal, no need to fuse.
if (originalLocation == foldedLocation)
return originalLocation;

// Append into existing fused location if it has the same tag.
if (auto existingFusedLoc =
dyn_cast<FusedLocWith<StringAttr>>(retainedOp->getLoc())) {
dyn_cast<FusedLocWith<StringAttr>>(originalLocation)) {
StringAttr existingMetadata = existingFusedLoc.getMetadata();
if (existingMetadata == fusedLocationTag) {
ArrayRef<Location> existingLocations = existingFusedLoc.getLocations();
SetVector<Location> locations(existingLocations.begin(),
existingLocations.end());
locations.insert(foldedLocation);
Location newFusedLoc = FusedLoc::get(
retainedOp->getContext(), locations.takeVector(), existingMetadata);
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
return;
Location newFusedLoc =
FusedLoc::get(originalLocation->getContext(), locations.takeVector(),
existingMetadata);
return FlattenFusedLocationRecursively(newFusedLoc);
}
}

// Create a new fusedloc with retainedOp's loc and foldedLocation.
// If they're already equal, no need to fuse.
if (retainedOp->getLoc() == foldedLocation)
return;

Location newFusedLoc =
FusedLoc::get(retainedOp->getContext(),
{retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
FusedLoc::get(originalLocation->getContext(),
{originalLocation, foldedLocation}, fusedLocationTag);
return FlattenFusedLocationRecursively(newFusedLoc);
}

void OperationFolder::appendFoldedLocation(Operation *retainedOp,
Location foldedLocation) {
retainedOp->setLoc(getFusedLocation(retainedOp->getLoc(), foldedLocation));
}
26 changes: 24 additions & 2 deletions mlir/test/Transforms/canonicalize-debuginfo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,31 @@ func.func @hoist_constant(%arg0: memref<8xi32>) {
memref.store %0, %arg0[%arg1] : memref<8xi32>
memref.store %1, %arg0[%arg1] : memref<8xi32>
}
// CHECK: return
return
}
// CHECK-NEXT: } loc(#[[LocFunc:.*]])
} loc("hoist_constant":2:0)

// CHECK-DAG: #[[LocConst0:.*]] = loc("hoist_constant":0:0)
// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant":1:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]]])
// CHECK-DAG: #[[LocFunc]] = loc("hoist_constant":2:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]], #[[LocConst1]]])

// -----

// CHECK-LABEL: func @hoist_constant_simple
func.func @hoist_constant_simple(%arg0: memref<8xi32>) -> i32 {
// CHECK-NEXT: arith.constant 88 : i32 loc(#[[FusedLoc:.*]])
%0 = arith.constant 42 : i32 loc("hoist_constant_simple":0:0)
%1 = arith.constant 0 : index
memref.store %0, %arg0[%1] : memref<8xi32>

%2 = arith.constant 88 : i32 loc("hoist_constant_simple":1:0)

return %2 : i32
// CHECK: } loc(#[[LocFunc:.*]])
} loc("hoist_constant_simple":2:0)

// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant_simple":1:0)
// CHECK-DAG: #[[LocFunc]] = loc("hoist_constant_simple":2:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]]])
32 changes: 28 additions & 4 deletions mlir/test/Transforms/constant-fold-debuginfo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ func.func @fold_and_merge() -> (i32, i32) {
%3 = arith.constant 6 : i32 loc("fold_and_merge":1:0)

return %2, %3: i32, i32
}
// CHECK: } loc(#[[LocFunc:.*]])
} loc("fold_and_merge":2:0)

// CHECK-DAG: #[[LocConst0:.*]] = loc("fold_and_merge":0:0)
// CHECK-DAG: #[[LocConst1:.*]] = loc("fold_and_merge":1:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
// CHECK-DAG: #[[LocFunc]] = loc("fold_and_merge":2:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])

// -----

Expand All @@ -27,8 +29,30 @@ func.func @materialize_different_dialect() -> (f32, f32) {
%2 = arith.constant 1.0 : f32 loc("materialize_different_dialect":1:0)

return %1, %2: f32, f32
}
// CHECK: } loc(#[[LocFunc:.*]])
} loc("materialize_different_dialect":2:0)

// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_different_dialect":0:0)
// CHECK-DAG: #[[LocConst1:.*]] = loc("materialize_different_dialect":1:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
// CHECK-DAG: #[[LocFunc]] = loc("materialize_different_dialect":2:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])

// -----

// CHECK-LABEL: func @materialize_in_front
func.func @materialize_in_front(%arg0: memref<8xi32>) {
// CHECK-NEXT: arith.constant 6 : i32 loc(#[[FusedLoc:.*]])
affine.for %arg1 = 0 to 8 {
%1 = arith.constant 1 : i32
%2 = arith.constant 5 : i32
%3 = arith.addi %1, %2 : i32 loc("materialize_in_front":0:0)
memref.store %3, %arg0[%arg1] : memref<8xi32>
}
// CHECK: return
return
// CHECK-NEXT: } loc(#[[LocFunc:.*]])
} loc("materialize_in_front":1:0)

// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_in_front":0:0)
// CHECK-DAG: #[[LocFunc]] = loc("materialize_in_front":1:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]]])