Skip to content

Commit a269b08

Browse files
committed
move all topo sorts into the utils file
1 parent 277ed0a commit a269b08

File tree

8 files changed

+82
-81
lines changed

8 files changed

+82
-81
lines changed

mlir/include/mlir/Analysis/SliceAnalysis.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,6 @@ SetVector<Operation *>
223223
getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
224224
const ForwardSliceOptions &forwardSliceOptions = {});
225225

226-
/// Multi-root DAG topological sort.
227-
/// Performs a topological sort of the Operation in the `toSort` SetVector.
228-
/// Returns a topologically sorted SetVector.
229-
/// Does not support multi-sets.
230-
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
231-
232226
/// Utility to match a generic reduction given a list of iteration-carried
233227
/// arguments, `iterCarriedArgs` and the position of the potential reduction
234228
/// argument within the list, `redPos`. If a reduction is matched, returns the

mlir/include/mlir/Analysis/TopologicalSortUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ bool computeTopologicalSorting(
108108
/// stable.
109109
SetVector<Block *> getBlocksSortedByDominance(Region &region);
110110

111+
/// Sorts all operation in `toSort` topologically while also region semantics.
112+
/// Does not support multi-sets.
113+
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
114+
111115
} // end namespace mlir
112116

113117
#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -165,80 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
165165
return topologicalSort(slice);
166166
}
167167

168-
/// Computes the common ancestor region of all operations in `ops`. Remembers
169-
/// all the traversed regions in `traversedRegions`.
170-
static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
171-
DenseSet<Region *> &traversedRegions) {
172-
// Map to count the number of times a region was encountered.
173-
llvm::DenseMap<Region *, size_t> regionCounts;
174-
size_t expectedCount = ops.size();
175-
176-
// Walk the region tree for each operation towards the root and add to the
177-
// region count.
178-
Region *res = nullptr;
179-
for (Operation *op : ops) {
180-
Region *current = op->getParentRegion();
181-
while (current) {
182-
// Insert or get the count.
183-
auto it = regionCounts.try_emplace(current, 0).first;
184-
size_t count = ++it->getSecond();
185-
if (count == expectedCount) {
186-
res = current;
187-
break;
188-
}
189-
current = current->getParentRegion();
190-
}
191-
}
192-
auto firstRange = llvm::make_first_range(regionCounts);
193-
traversedRegions.insert(firstRange.begin(), firstRange.end());
194-
return res;
195-
}
196-
197-
/// Topologically traverses `region` and insers all encountered operations in
198-
/// `toSort` into the result. Recursively traverses regions when they are
199-
/// present in `relevantRegions`.
200-
static void topoSortRegion(Region &region,
201-
const DenseSet<Region *> &relevantRegions,
202-
const SetVector<Operation *> &toSort,
203-
SetVector<Operation *> &result) {
204-
SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
205-
for (Block *block : sortedBlocks) {
206-
for (Operation &op : *block) {
207-
if (toSort.contains(&op))
208-
result.insert(&op);
209-
for (Region &subRegion : op.getRegions()) {
210-
// Skip regions that do not contain operations from `toSort`.
211-
if (!relevantRegions.contains(&region))
212-
continue;
213-
topoSortRegion(subRegion, relevantRegions, toSort, result);
214-
}
215-
}
216-
}
217-
}
218-
219-
SetVector<Operation *>
220-
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
221-
if (toSort.size() <= 1)
222-
return toSort;
223-
224-
assert(llvm::all_of(toSort,
225-
[&](Operation *op) { return toSort.count(op) == 1; }) &&
226-
"expected only unique set entries");
227-
228-
// First, find the root region to start the recursive traversal through the
229-
// IR.
230-
DenseSet<Region *> relevantRegions;
231-
Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
232-
assert(rootRegion && "expected all ops to have a common ancestor");
233-
234-
// Sort all element in `toSort` by recursively traversing the IR.
235-
SetVector<Operation *> result;
236-
topoSortRegion(*rootRegion, relevantRegions, toSort, result);
237-
assert(result.size() == toSort.size() &&
238-
"expected all operations to be present in the result");
239-
return result;
240-
}
241-
242168
/// Returns true if `value` (transitively) depends on iteration-carried values
243169
/// of the given `ancestorOp`.
244170
static bool dependsOnCarriedVals(Value value,

mlir/lib/Analysis/TopologicalSortUtils.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,77 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
167167

168168
return blocks;
169169
}
170+
171+
/// Computes the common ancestor region of all operations in `ops`. Remembers
172+
/// all the traversed regions in `traversedRegions`.
173+
static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
174+
DenseSet<Region *> &traversedRegions) {
175+
// Map to count the number of times a region was encountered.
176+
llvm::DenseMap<Region *, size_t> regionCounts;
177+
size_t expectedCount = ops.size();
178+
179+
// Walk the region tree for each operation towards the root and add to the
180+
// region count.
181+
Region *res = nullptr;
182+
for (Operation *op : ops) {
183+
Region *current = op->getParentRegion();
184+
while (current) {
185+
// Insert or get the count.
186+
auto it = regionCounts.try_emplace(current, 0).first;
187+
size_t count = ++it->getSecond();
188+
if (count == expectedCount) {
189+
res = current;
190+
break;
191+
}
192+
current = current->getParentRegion();
193+
}
194+
}
195+
auto firstRange = llvm::make_first_range(regionCounts);
196+
traversedRegions.insert(firstRange.begin(), firstRange.end());
197+
return res;
198+
}
199+
200+
/// Topologically traverses `region` and insers all encountered operations in
201+
/// `toSort` into the result. Recursively traverses regions when they are
202+
/// present in `relevantRegions`.
203+
static void topoSortRegion(Region &region,
204+
const DenseSet<Region *> &relevantRegions,
205+
const SetVector<Operation *> &toSort,
206+
SetVector<Operation *> &result) {
207+
SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
208+
for (Block *block : sortedBlocks) {
209+
for (Operation &op : *block) {
210+
if (toSort.contains(&op))
211+
result.insert(&op);
212+
for (Region &subRegion : op.getRegions()) {
213+
// Skip regions that do not contain operations from `toSort`.
214+
if (!relevantRegions.contains(&region))
215+
continue;
216+
topoSortRegion(subRegion, relevantRegions, toSort, result);
217+
}
218+
}
219+
}
220+
}
221+
222+
SetVector<Operation *>
223+
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
224+
if (toSort.size() <= 1)
225+
return toSort;
226+
227+
assert(llvm::all_of(toSort,
228+
[&](Operation *op) { return toSort.count(op) == 1; }) &&
229+
"expected only unique set entries");
230+
231+
// First, find the root region to start the recursive traversal through the
232+
// IR.
233+
DenseSet<Region *> relevantRegions;
234+
Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
235+
assert(rootRegion && "expected all ops to have a common ancestor");
236+
237+
// Sort all element in `toSort` by recursively traversing the IR.
238+
SetVector<Operation *> result;
239+
topoSortRegion(*rootRegion, relevantRegions, toSort, result);
240+
assert(result.size() == toSort.size() &&
241+
"expected all operations to be present in the result");
242+
return result;
243+
}

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <type_traits>
1616

1717
#include "mlir/Analysis/SliceAnalysis.h"
18+
#include "mlir/Analysis/TopologicalSortUtils.h"
1819
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1920
#include "mlir/Dialect/Arith/IR/Arith.h"
2021
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
1414
#include "mlir/Analysis/SliceAnalysis.h"
15+
#include "mlir/Analysis/TopologicalSortUtils.h"
1516
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1617
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1718
#include "mlir/Dialect/Affine/Analysis/Utils.h"

mlir/lib/Transforms/SROA.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Transforms/SROA.h"
1010
#include "mlir/Analysis/DataLayoutAnalysis.h"
1111
#include "mlir/Analysis/SliceAnalysis.h"
12+
#include "mlir/Analysis/TopologicalSortUtils.h"
1213
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1314
#include "mlir/Transforms/Passes.h"
1415

mlir/test/lib/Analysis/TestSlice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Analysis/SliceAnalysis.h"
9+
#include "mlir/Analysis/TopologicalSortUtils.h"
1010
#include "mlir/IR/BuiltinTypes.h"
1111
#include "mlir/IR/SymbolTable.h"
1212
#include "mlir/Pass/Pass.h"

0 commit comments

Comments
 (0)