Skip to content

[MLIR][Transforms] Correct block sorting utils name (NFC) #92558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
LogicalResult runRegionDCE(RewriterBase &rewriter,
MutableArrayRef<Region> regions);

/// Get a topologically sorted list of blocks of the given region.
SetVector<Block *> getTopologicallySortedBlocks(Region &region);
/// Get a list of blocks that is sorted according to dominance. This sort is
/// stable.
SetVector<Block *> getBlocksSortedByDominance(Region &region);

} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ static LogicalResult convertDataOp(acc::DataOp &op,
llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
ctx, "acc.end_data", builder.GetInsertBlock()->getParent());

SetVector<Block *> blocks = getTopologicallySortedBlocks(op.getRegion());
SetVector<Block *> blocks = getBlocksSortedByDominance(op.getRegion());
for (Block *bb : blocks) {
llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
if (bb->isEntryBlock()) {
Expand Down
54 changes: 26 additions & 28 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ static llvm::BasicBlock *convertOmpOpRegions(

// Convert blocks one by one in topological order to ensure
// defs are converted before uses.
SetVector<Block *> blocks = getTopologicallySortedBlocks(region);
SetVector<Block *> blocks = getBlocksSortedByDominance(region);
for (Block *bb : blocks) {
llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
// Retarget the branch of the entry block to the entry block of the
Expand Down Expand Up @@ -2146,40 +2146,38 @@ getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
llvm::SmallVector<size_t> indices(shape[0]);
std::iota(indices.begin(), indices.end(), 0);

llvm::sort(
indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
auto indexValues = indexAttr.getValues<int32_t>();
for (int i = 0;
i < shape[1];
++i) {
int aIndex = indexValues[a * shape[1] + i];
int bIndex = indexValues[b * shape[1] + i];
llvm::sort(indices.begin(), indices.end(),
[&](const size_t a, const size_t b) {
auto indexValues = indexAttr.getValues<int32_t>();
for (int i = 0; i < shape[1]; ++i) {
int aIndex = indexValues[a * shape[1] + i];
int bIndex = indexValues[b * shape[1] + i];

if (aIndex == bIndex)
continue;
if (aIndex == bIndex)
continue;

if (aIndex != -1 && bIndex == -1)
return false;
if (aIndex != -1 && bIndex == -1)
return false;

if (aIndex == -1 && bIndex != -1)
return true;
if (aIndex == -1 && bIndex != -1)
return true;

// A is earlier in the record type layout than B
if (aIndex < bIndex)
return first;
// A is earlier in the record type layout than B
if (aIndex < bIndex)
return first;

if (bIndex < aIndex)
return !first;
}
if (bIndex < aIndex)
return !first;
}

// Iterated the entire list and couldn't make a decision, all elements
// were likely the same. Return false, since the sort comparator should
// return false for equal elements.
return false;
});
// Iterated the entire list and couldn't make a decision, all
// elements were likely the same. Return false, since the sort
// comparator should return false for equal elements.
return false;
});

return llvm::cast<mlir::omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
return llvm::cast<mlir::omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
}

/// This function calculates the array/pointer offset for map data provided
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {

// Then, convert blocks one by one in topological order to ensure defs are
// converted before uses.
auto blocks = getTopologicallySortedBlocks(func.getBody());
auto blocks = getBlocksSortedByDominance(func.getBody());
for (Block *bb : blocks) {
CapturingIRBuilder builder(llvmContext);
if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/Mem2Reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
return it->second;

DenseMap<Block *, size_t> &blockIndices = it->second;
SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(*region);
SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(*region);
for (auto [index, block] : llvm::enumerate(topologicalOrder))
blockIndices[block] = index;
return blockIndices;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
mergedIdenticalBlocks);
}

SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
// For each block that has not been visited yet (i.e. that has no
// predecessors), add it to the list as well as its successors.
SetVector<Block *> blocks;
Expand Down
Loading