Skip to content

Commit 9ea5fb9

Browse files
committed
address nit comments
1 parent a269b08 commit 9ea5fb9

File tree

4 files changed

+26
-26
lines changed

4 files changed

+26
-26
lines changed

mlir/include/mlir/Analysis/TopologicalSortUtils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,12 @@ bool computeTopologicalSorting(
104104
MutableArrayRef<Operation *> ops,
105105
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
106106

107-
/// Get a list of blocks that is sorted according to dominance. This sort is
107+
/// Gets a list of blocks that is sorted according to dominance. This sort is
108108
/// stable.
109109
SetVector<Block *> getBlocksSortedByDominance(Region &region);
110110

111-
/// Sorts all operation in `toSort` topologically while also region semantics.
112-
/// Does not support multi-sets.
111+
/// Sorts all operations in `toSort` topologically while also considering region
112+
/// semantics. Does not support multi-sets.
113113
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
114114

115115
} // end namespace mlir

mlir/lib/Analysis/TopologicalSortUtils.cpp

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,12 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
168168
return blocks;
169169
}
170170

171-
/// Computes the common ancestor region of all operations in `ops`. Remembers
172-
/// all the traversed regions in `traversedRegions`.
173-
static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
174-
DenseSet<Region *> &traversedRegions) {
171+
/// Computes the closest common ancestor region of all operations in `ops`.
172+
/// Remembers all the traversed regions in `traversedRegions`.
173+
static Region *findCommonAncestorRegion(const SetVector<Operation *> &ops,
174+
DenseSet<Region *> &traversedRegions) {
175175
// Map to count the number of times a region was encountered.
176-
llvm::DenseMap<Region *, size_t> regionCounts;
176+
DenseMap<Region *, size_t> regionCounts;
177177
size_t expectedCount = ops.size();
178178

179179
// Walk the region tree for each operation towards the root and add to the
@@ -182,10 +182,8 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
182182
for (Operation *op : ops) {
183183
Region *current = op->getParentRegion();
184184
while (current) {
185-
// Insert or get the count.
186-
auto it = regionCounts.try_emplace(current, 0).first;
187-
size_t count = ++it->getSecond();
188-
if (count == expectedCount) {
185+
// Insert or update the count and compare it.
186+
if (++regionCounts[current] == expectedCount) {
189187
res = current;
190188
break;
191189
}
@@ -197,11 +195,11 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
197195
return res;
198196
}
199197

200-
/// Topologically traverses `region` and insers all encountered operations in
198+
/// Topologically traverses `region` and inserts all encountered operations in
201199
/// `toSort` into the result. Recursively traverses regions when they are
202200
/// present in `relevantRegions`.
203201
static void topoSortRegion(Region &region,
204-
const DenseSet<Region *> &relevantRegions,
202+
const DenseSet<Region *> &ancestorRegions,
205203
const SetVector<Operation *> &toSort,
206204
SetVector<Operation *> &result) {
207205
SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
@@ -211,9 +209,9 @@ static void topoSortRegion(Region &region,
211209
result.insert(&op);
212210
for (Region &subRegion : op.getRegions()) {
213211
// Skip regions that do not contain operations from `toSort`.
214-
if (!relevantRegions.contains(&region))
212+
if (!ancestorRegions.contains(&region))
215213
continue;
216-
topoSortRegion(subRegion, relevantRegions, toSort, result);
214+
topoSortRegion(subRegion, ancestorRegions, toSort, result);
217215
}
218216
}
219217
}
@@ -224,19 +222,15 @@ mlir::topologicalSort(const SetVector<Operation *> &toSort) {
224222
if (toSort.size() <= 1)
225223
return toSort;
226224

227-
assert(llvm::all_of(toSort,
228-
[&](Operation *op) { return toSort.count(op) == 1; }) &&
229-
"expected only unique set entries");
230-
231225
// First, find the root region to start the recursive traversal through the
232226
// IR.
233-
DenseSet<Region *> relevantRegions;
234-
Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
227+
DenseSet<Region *> ancestorRegions;
228+
Region *rootRegion = findCommonAncestorRegion(toSort, ancestorRegions);
235229
assert(rootRegion && "expected all ops to have a common ancestor");
236230

237231
// Sort all element in `toSort` by recursively traversing the IR.
238232
SetVector<Operation *> result;
239-
topoSortRegion(*rootRegion, relevantRegions, toSort, result);
233+
topoSortRegion(*rootRegion, ancestorRegions, toSort, result);
240234
assert(result.size() == toSort.size() &&
241235
"expected all operations to be present in the result");
242236
return result;

mlir/lib/Transforms/SROA.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,19 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
108108

109109
// An operation that has blocking uses must be promoted. If it is not
110110
// promotable, destructuring must fail.
111-
if (!promotable)
111+
if (!promotable) {
112+
// user->emitError() << "not promotable";
112113
return {};
114+
}
113115

114116
SmallVector<OpOperand *> newBlockingUses;
115117
// If the operation decides it cannot deal with removing the blocking uses,
116118
// destructuring must fail.
117-
if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
119+
if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
120+
dataLayout)) {
121+
// promotable->emitError() << "not removable";
118122
return {};
123+
}
119124

120125
// Then, register any new blocking uses for coming operations.
121126
for (OpOperand *blockingUse : newBlockingUses) {

mlir/test/lib/Analysis/TestSlice.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ struct TestTopologicalSortPass
2525

2626
StringRef getArgument() const final { return "test-print-topological-sort"; }
2727
StringRef getDescription() const final {
28-
return "Print operations in topological order";
28+
return "Sorts operations topologically and attaches attributes with their "
29+
"corresponding index in the ordering to them";
2930
}
3031
void runOnOperation() override {
3132
SetVector<Operation *> toSort;

0 commit comments

Comments
 (0)