Skip to content

Commit 5aa6038

Browse files
committed
[mlir] Make topologicalSort iterative and consider op regions
When doing topological sort we need to make sure an op is scheduled before any of the ops within its regions. Also change the algorithm to not be recursive in order to prevent potential stack overflow. Differential Revision: https://reviews.llvm.org/D113423
1 parent e068c84 commit 5aa6038

File tree

5 files changed

+92
-12
lines changed

5 files changed

+92
-12
lines changed

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,26 @@ struct DFSState {
168168
};
169169
} // namespace
170170

171-
static void DFSPostorder(Operation *current, DFSState *state) {
172-
for (Value result : current->getResults()) {
173-
for (Operation *op : result.getUsers())
174-
DFSPostorder(op, state);
175-
}
176-
bool inserted;
177-
using IterTy = decltype(state->seen.begin());
178-
IterTy iter;
179-
std::tie(iter, inserted) = state->seen.insert(current);
180-
if (inserted) {
181-
if (state->toSort.count(current) > 0) {
182-
state->topologicalCounts.push_back(current);
171+
static void DFSPostorder(Operation *root, DFSState *state) {
172+
SmallVector<Operation *> queue(1, root);
173+
std::vector<Operation *> ops;
174+
while (!queue.empty()) {
175+
Operation *current = queue.pop_back_val();
176+
ops.push_back(current);
177+
for (Value result : current->getResults()) {
178+
for (Operation *op : result.getUsers())
179+
queue.push_back(op);
180+
}
181+
for (Region &region : current->getRegions()) {
182+
for (Operation &op : region.getOps())
183+
queue.push_back(&op);
183184
}
184185
}
186+
187+
for (Operation *op : llvm::reverse(ops)) {
188+
if (state->seen.insert(op).second && state->toSort.count(op) > 0)
189+
state->topologicalCounts.push_back(op);
190+
}
185191
}
186192

187193
SetVector<Operation *>
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt %s -test-print-topological-sort 2>&1 | FileCheck %s
2+
3+
// CHECK-LABEL: Testing : region
4+
// CHECK: arith.addi {{.*}} : index
5+
// CHECK-NEXT: scf.for
6+
// CHECK: } {__test_sort_original_idx__ = 2 : i64}
7+
// CHECK-NEXT: arith.addi {{.*}} : i32
8+
// CHECK-NEXT: arith.subi {{.*}} : i32
9+
func @region(
10+
%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
11+
%arg4 : i32, %arg5 : i32, %arg6 : i32,
12+
%buffer : memref<i32>) {
13+
%0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
14+
%idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
15+
scf.for %arg7 = %idx to %arg2 step %arg3 {
16+
%2 = arith.addi %0, %arg5 : i32
17+
%3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
18+
memref.store %3, %buffer[] : memref<i32>
19+
} {__test_sort_original_idx__ = 2}
20+
return
21+
}

mlir/test/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_library(MLIRTestAnalysis
88
TestMemRefDependenceCheck.cpp
99
TestMemRefStrideCalculation.cpp
1010
TestNumberOfExecutions.cpp
11+
TestSlice.cpp
1112

1213

1314
EXCLUDE_FROM_LIBMLIR

mlir/test/lib/Analysis/TestSlice.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//===------------- TestSlice.cpp - Test slice related analisis ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Analysis/SliceAnalysis.h"
10+
#include "mlir/Pass/Pass.h"
11+
12+
using namespace mlir;
13+
14+
static const StringLiteral kOrderMarker = "__test_sort_original_idx__";
15+
16+
namespace {
17+
18+
struct TestTopologicalSortPass
19+
: public PassWrapper<TestTopologicalSortPass, FunctionPass> {
20+
StringRef getArgument() const final { return "test-print-topological-sort"; }
21+
StringRef getDescription() const final {
22+
return "Print operations in topological order";
23+
}
24+
void runOnFunction() override {
25+
std::map<int, Operation *> ops;
26+
getFunction().walk([&ops](Operation *op) {
27+
if (auto originalOrderAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker))
28+
ops[originalOrderAttr.getInt()] = op;
29+
});
30+
SetVector<Operation *> sortedOp;
31+
for (auto op : ops)
32+
sortedOp.insert(op.second);
33+
sortedOp = topologicalSort(sortedOp);
34+
llvm::errs() << "Testing : " << getFunction().getName() << "\n";
35+
for (Operation *op : sortedOp) {
36+
op->print(llvm::errs());
37+
llvm::errs() << "\n";
38+
}
39+
}
40+
};
41+
42+
} // end anonymous namespace
43+
44+
namespace mlir {
45+
namespace test {
46+
void registerTestSliceAnalysisPass() {
47+
PassRegistration<TestTopologicalSortPass>();
48+
}
49+
} // namespace test
50+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ void registerTestPDLByteCodePass();
106106
void registerTestPreparationPassWithAllowedMemrefResults();
107107
void registerTestRecursiveTypesPass();
108108
void registerTestSCFUtilsPass();
109+
void registerTestSliceAnalysisPass();
109110
void registerTestVectorConversions();
110111
} // namespace test
111112
} // namespace mlir
@@ -195,6 +196,7 @@ void registerTestPasses() {
195196
mlir::test::registerTestPDLByteCodePass();
196197
mlir::test::registerTestRecursiveTypesPass();
197198
mlir::test::registerTestSCFUtilsPass();
199+
mlir::test::registerTestSliceAnalysisPass();
198200
mlir::test::registerTestVectorConversions();
199201
}
200202
#endif

0 commit comments

Comments
 (0)