Skip to content

Commit 616b6c2

Browse files
committed
first impl of the new topo sort
1 parent e919df5 commit 616b6c2

File tree

5 files changed

+211
-157
lines changed

5 files changed

+211
-157
lines changed

mlir/include/mlir/Analysis/SliceAnalysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
226226
/// Multi-root DAG topological sort.
227227
/// Performs a topological sort of the Operation in the `toSort` SetVector.
228228
/// Returns a topologically sorted SetVector.
229+
/// Does not support multi-sets.
229230
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
230231

231232
/// Utility to match a generic reduction given a list of iteration-carried

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 82 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Analysis/SliceAnalysis.h"
14+
#include "mlir/IR/Block.h"
1415
#include "mlir/IR/BuiltinOps.h"
1516
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/RegionGraphTraits.h"
1618
#include "mlir/Interfaces/SideEffectInterfaces.h"
1719
#include "mlir/Support/LLVM.h"
20+
#include "llvm/ADT/PostOrderIterator.h"
1821
#include "llvm/ADT/SetVector.h"
1922
#include "llvm/ADT/SmallPtrSet.h"
2023

@@ -164,60 +167,95 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
164167
return topologicalSort(slice);
165168
}
166169

167-
namespace {
168-
/// DFS post-order implementation that maintains a global count to work across
169-
/// multiple invocations, to help implement topological sort on multi-root DAGs.
170-
/// We traverse all operations but only record the ones that appear in
171-
/// `toSort` for the final result.
172-
struct DFSState {
173-
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
174-
const SetVector<Operation *> &toSort;
175-
SmallVector<Operation *, 16> topologicalCounts;
176-
DenseSet<Operation *> seen;
177-
};
178-
} // namespace
179-
180-
static void dfsPostorder(Operation *root, DFSState *state) {
181-
SmallVector<Operation *> queue(1, root);
182-
std::vector<Operation *> ops;
183-
while (!queue.empty()) {
184-
Operation *current = queue.pop_back_val();
185-
ops.push_back(current);
186-
for (Operation *op : current->getUsers())
187-
queue.push_back(op);
188-
for (Region &region : current->getRegions()) {
189-
for (Operation &op : region.getOps())
190-
queue.push_back(&op);
170+
/// TODO: deduplicate
171+
static SetVector<Block *> getTopologicallySortedBlocks(Region &region) {
172+
// For each block that has not been visited yet (i.e. that has no
173+
// predecessors), add it to the list as well as its successors.
174+
SetVector<Block *> blocks;
175+
for (Block &b : region) {
176+
if (blocks.count(&b) == 0) {
177+
llvm::ReversePostOrderTraversal<Block *> traversal(&b);
178+
blocks.insert(traversal.begin(), traversal.end());
191179
}
192180
}
181+
assert(blocks.size() == region.getBlocks().size() &&
182+
"some blocks are not sorted");
193183

194-
for (Operation *op : llvm::reverse(ops)) {
195-
if (state->seen.insert(op).second && state->toSort.count(op) > 0)
196-
state->topologicalCounts.push_back(op);
184+
return blocks;
185+
}
186+
187+
/// Computes the common ancestor region of all operations in `ops`. Remembers
188+
/// all the traversed regions in `traversedRegions`.
189+
static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
190+
DenseSet<Region *> &traversedRegions) {
191+
// Map to count the number of times a region was encountered.
192+
llvm::DenseMap<Region *, size_t> regionCounts;
193+
size_t expectedCount = ops.size();
194+
195+
// Walk the region tree for each operation towards the root and add to the
196+
// region count.
197+
Region *res = nullptr;
198+
for (Operation *op : ops) {
199+
Region *current = op->getParentRegion();
200+
while (current) {
201+
// Insert or get the count.
202+
auto it = regionCounts.try_emplace(current, 0).first;
203+
size_t count = ++it->getSecond();
204+
if (count == expectedCount) {
205+
res = current;
206+
break;
207+
}
208+
current = current->getParentRegion();
209+
}
210+
}
211+
auto firstRange = llvm::make_first_range(regionCounts);
212+
traversedRegions.insert(firstRange.begin(), firstRange.end());
213+
return res;
214+
}
215+
216+
/// Topologically traverses `region` and insers all encountered operations in
217+
/// `toSort` into the result. Recursively traverses regions when they are
218+
/// present in `relevantRegions`.
219+
static void topoSortRegion(Region &region,
220+
const DenseSet<Region *> &relevantRegions,
221+
const SetVector<Operation *> &toSort,
222+
SetVector<Operation *> &result) {
223+
SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region);
224+
for (Block *block : sortedBlocks) {
225+
for (Operation &op : *block) {
226+
if (toSort.contains(&op))
227+
result.insert(&op);
228+
for (Region &subRegion : op.getRegions()) {
229+
// Skip regions that do not contain operations from `toSort`.
230+
if (!relevantRegions.contains(&region))
231+
continue;
232+
topoSortRegion(subRegion, relevantRegions, toSort, result);
233+
}
234+
}
197235
}
198236
}
199237

200238
SetVector<Operation *>
201239
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
202-
if (toSort.empty()) {
240+
if (toSort.size() <= 1)
203241
return toSort;
204-
}
205242

206-
// Run from each root with global count and `seen` set.
207-
DFSState state(toSort);
208-
for (auto *s : toSort) {
209-
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
210-
dfsPostorder(s, &state);
211-
}
212-
213-
// Reorder and return.
214-
SetVector<Operation *> res;
215-
for (auto it = state.topologicalCounts.rbegin(),
216-
eit = state.topologicalCounts.rend();
217-
it != eit; ++it) {
218-
res.insert(*it);
219-
}
220-
return res;
243+
assert(llvm::all_of(toSort,
244+
[&](Operation *op) { return toSort.count(op) == 1; }) &&
245+
"expected only unique set entries");
246+
247+
// First, find the root region to start the recursive traversal through the
248+
// IR.
249+
DenseSet<Region *> relevantRegions;
250+
Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
251+
assert(rootRegion && "expected all ops to have a common ancestor");
252+
253+
// Sort all elements in `toSort` by recursively traversing the IR.
254+
SetVector<Operation *> result;
255+
topoSortRegion(*rootRegion, relevantRegions, toSort, result);
256+
assert(result.size() == toSort.size() &&
257+
"expected all operations to be present in the result");
258+
return result;
221259
}
222260

223261
/// Returns true if `value` (transitively) depends on iteration-carried values
Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,38 @@
1-
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" 2>&1 | FileCheck %s
1+
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" --split-input-file | FileCheck %s
22

3-
// CHECK-LABEL: Testing : region
4-
// CHECK: arith.addi {{.*}} : index
5-
// CHECK-NEXT: scf.for
6-
// CHECK: } {__test_sort_original_idx__ = 2 : i64}
7-
// CHECK-NEXT: arith.addi {{.*}} : i32
8-
// CHECK-NEXT: arith.subi {{.*}} : i32
9-
func.func @region(
10-
%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
11-
%arg4 : i32, %arg5 : i32, %arg6 : i32,
12-
%buffer : memref<i32>) {
13-
%0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
14-
%idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
15-
scf.for %arg7 = %idx to %arg2 step %arg3 {
16-
%2 = arith.addi %0, %arg5 : i32
17-
%3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
18-
memref.store %3, %buffer[] : memref<i32>
19-
} {__test_sort_original_idx__ = 2}
3+
// CHECK-LABEL: single_element
4+
func.func @single_element() {
5+
// CHECK: test_sort_index = 0
6+
return {test_to_sort}
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: @simple_region
12+
func.func @simple_region(%cond: i1) {
13+
// CHECK: test_sort_index = 0
14+
%0 = arith.constant {test_to_sort} 42 : i32
15+
scf.if %cond {
16+
%1 = arith.addi %0, %0 : i32
17+
// CHECK: test_sort_index = 2
18+
%2 = arith.subi %0, %1 {test_to_sort} : i32
19+
// CHECK: test_sort_index = 1
20+
} {test_to_sort}
21+
return
22+
}
23+
24+
// -----
25+
26+
// CHECK-LABEL: @multi_region
27+
func.func @multi_region(%cond: i1) {
28+
scf.if %cond {
29+
// CHECK: test_sort_index = 0
30+
%0 = arith.constant {test_to_sort} 42 : i32
31+
}
32+
33+
scf.if %cond {
34+
// CHECK: test_sort_index = 1
35+
%0 = arith.constant {test_to_sort} 24 : i32
36+
}
2037
return
2138
}

0 commit comments

Comments
 (0)