Skip to content

Commit 31fbdab

Browse files
[mlir][transforms] Add topological sort analysis
This change add a helper function for computing a topological sorting of a list of ops. E.g. this can be useful in transforms where a subset of ops should be cloned without dominance errors. The analysis reuses the existing implementation in TopologicalSortUtils.cpp. Differential Revision: https://reviews.llvm.org/D131669
1 parent e86119b commit 31fbdab

File tree

6 files changed

+177
-36
lines changed

6 files changed

+177
-36
lines changed

mlir/include/mlir/Transforms/TopologicalSortUtils.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,23 @@ bool sortTopologically(
9090
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
9191

9292
/// Given a block, sort its operations in topological order, excluding its
93-
/// terminator if it has one.
93+
/// terminator if it has one. This sort is stable.
9494
bool sortTopologically(
9595
Block *block,
9696
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
9797

98+
/// Compute a topological ordering of the given ops. All ops must belong to the
99+
/// specified block.
100+
///
101+
/// This sort is not stable.
102+
///
103+
/// Note: If the specified ops contain incomplete/interrupted SSA use-def
104+
/// chains, the result may not actually be a topological sorting with respect to
105+
/// the entire program.
106+
bool computeTopologicalSorting(
107+
Block *block, MutableArrayRef<Operation *> ops,
108+
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
109+
98110
} // end namespace mlir
99111

100112
#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H

mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,19 @@
88

99
#include "mlir/Transforms/TopologicalSortUtils.h"
1010
#include "mlir/IR/OpDefinition.h"
11+
#include "llvm/ADT/SetVector.h"
1112

1213
using namespace mlir;
1314

14-
bool mlir::sortTopologically(
15-
Block *block, llvm::iterator_range<Block::iterator> ops,
16-
function_ref<bool(Value, Operation *)> isOperandReady) {
17-
if (ops.empty())
18-
return true;
19-
20-
// The set of operations that have not yet been scheduled.
21-
DenseSet<Operation *> unscheduledOps;
22-
// Mark all operations as unscheduled.
23-
for (Operation &op : ops)
24-
unscheduledOps.insert(&op);
25-
26-
Block::iterator nextScheduledOp = ops.begin();
27-
Block::iterator end = ops.end();
28-
15+
/// Return `true` if the given operation is ready to be scheduled.
16+
static bool isOpReady(Block *block, Operation *op,
17+
DenseSet<Operation *> &unscheduledOps,
18+
function_ref<bool(Value, Operation *)> isOperandReady) {
2919
// An operation is ready to be scheduled if all its operands are ready. An
3020
// operation is ready if:
3121
const auto isReady = [&](Value value, Operation *top) {
3222
// - the user-provided callback marks it as ready,
33-
if (isOperandReady && isOperandReady(value, top))
23+
if (isOperandReady && isOperandReady(value, op))
3424
return true;
3525
Operation *parent = value.getDefiningOp();
3626
// - it is a block argument,
@@ -41,12 +31,38 @@ bool mlir::sortTopologically(
4131
if (!ancestor)
4232
return true;
4333
// - it is defined in a nested region, or
44-
if (ancestor == top)
34+
if (ancestor == op)
4535
return true;
4636
// - its ancestor in the block is scheduled.
4737
return !unscheduledOps.contains(ancestor);
4838
};
4939

40+
// An operation is recursively ready to be scheduled of it and its nested
41+
// operations are ready.
42+
WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
43+
return llvm::all_of(nestedOp->getOperands(),
44+
[&](Value operand) { return isReady(operand, op); })
45+
? WalkResult::advance()
46+
: WalkResult::interrupt();
47+
});
48+
return !readyToSchedule.wasInterrupted();
49+
}
50+
51+
bool mlir::sortTopologically(
52+
Block *block, llvm::iterator_range<Block::iterator> ops,
53+
function_ref<bool(Value, Operation *)> isOperandReady) {
54+
if (ops.empty())
55+
return true;
56+
57+
// The set of operations that have not yet been scheduled.
58+
DenseSet<Operation *> unscheduledOps;
59+
// Mark all operations as unscheduled.
60+
for (Operation &op : ops)
61+
unscheduledOps.insert(&op);
62+
63+
Block::iterator nextScheduledOp = ops.begin();
64+
Block::iterator end = ops.end();
65+
5066
bool allOpsScheduled = true;
5167
while (!unscheduledOps.empty()) {
5268
bool scheduledAtLeastOnce = false;
@@ -56,16 +72,7 @@ bool mlir::sortTopologically(
5672
// set, and "schedule" it (move it before the `nextScheduledOp`).
5773
for (Operation &op :
5874
llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
59-
// An operation is recursively ready to be scheduled of it and its nested
60-
// operations are ready.
61-
WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
62-
return llvm::all_of(
63-
nestedOp->getOperands(),
64-
[&](Value operand) { return isReady(operand, &op); })
65-
? WalkResult::advance()
66-
: WalkResult::interrupt();
67-
});
68-
if (readyToSchedule.wasInterrupted())
75+
if (!isOpReady(block, &op, unscheduledOps, isOperandReady))
6976
continue;
7077

7178
// Schedule the operation by moving it to the start.
@@ -96,3 +103,48 @@ bool mlir::sortTopologically(
96103
isOperandReady);
97104
return sortTopologically(block, *block, isOperandReady);
98105
}
106+
107+
bool mlir::computeTopologicalSorting(
108+
Block *block, MutableArrayRef<Operation *> ops,
109+
function_ref<bool(Value, Operation *)> isOperandReady) {
110+
if (ops.empty())
111+
return true;
112+
113+
// The set of operations that have not yet been scheduled.
114+
DenseSet<Operation *> unscheduledOps;
115+
116+
// Mark all operations as unscheduled.
117+
for (Operation *op : ops) {
118+
assert(op->getBlock() == block && "op must belong to block");
119+
unscheduledOps.insert(op);
120+
}
121+
122+
unsigned nextScheduledOp = 0;
123+
124+
bool allOpsScheduled = true;
125+
while (!unscheduledOps.empty()) {
126+
bool scheduledAtLeastOnce = false;
127+
128+
// Loop over the ops that are not sorted yet, try to find the ones "ready",
129+
// i.e. the ones for which there aren't any operand produced by an op in the
130+
// set, and "schedule" it (swap it with the op at `nextScheduledOp`).
131+
for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
132+
if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady))
133+
continue;
134+
135+
// Schedule the operation by moving it to the start.
136+
unscheduledOps.erase(ops[i]);
137+
std::swap(ops[i], ops[nextScheduledOp]);
138+
scheduledAtLeastOnce = true;
139+
++nextScheduledOp;
140+
}
141+
142+
// If no operations were scheduled, just schedule the first op and continue.
143+
if (!scheduledAtLeastOnce) {
144+
allOpsScheduled = false;
145+
unscheduledOps.erase(ops[nextScheduledOp++]);
146+
}
147+
}
148+
149+
return allOpsScheduled;
150+
}

mlir/test/Transforms/test-toposort.mlir

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
11
// RUN: mlir-opt -topological-sort %s | FileCheck %s
2+
// RUN: mlir-opt -test-topological-sort-analysis %s | FileCheck %s -check-prefix=CHECK-ANALYSIS
23

34
// Test producer is after user.
45
// CHECK-LABEL: test.graph_region
5-
test.graph_region {
6+
// CHECK-ANALYSIS-LABEL: test.graph_region
7+
test.graph_region attributes{"root"} {
68
// CHECK-NEXT: test.foo
79
// CHECK-NEXT: test.baz
810
// CHECK-NEXT: test.bar
9-
%0 = "test.foo"() : () -> i32
10-
"test.bar"(%1, %0) : (i32, i32) -> ()
11-
%1 = "test.baz"() : () -> i32
11+
12+
// CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0
13+
// CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2
14+
// CHECK-ANALYSIS-NEXT: test.baz{{.*}} {pos = 1
15+
%0 = "test.foo"() {selected} : () -> i32
16+
"test.bar"(%1, %0) {selected} : (i32, i32) -> ()
17+
%1 = "test.baz"() {selected} : () -> i32
1218
}
1319

1420
// Test cycles.
1521
// CHECK-LABEL: test.graph_region
16-
test.graph_region {
22+
// CHECK-ANALYSIS-LABEL: test.graph_region
23+
test.graph_region attributes{"root"} {
1724
// CHECK-NEXT: test.d
1825
// CHECK-NEXT: test.a
1926
// CHECK-NEXT: test.c
2027
// CHECK-NEXT: test.b
21-
%2 = "test.c"(%1) : (i32) -> i32
28+
29+
// CHECK-ANALYSIS-NEXT: test.c{{.*}} {pos = 0
30+
// CHECK-ANALYSIS-NEXT: test.b{{.*}} : (
31+
// CHECK-ANALYSIS-NEXT: test.a{{.*}} {pos = 2
32+
// CHECK-ANALYSIS-NEXT: test.d{{.*}} {pos = 1
33+
%2 = "test.c"(%1) {selected} : (i32) -> i32
2234
%1 = "test.b"(%0, %2) : (i32, i32) -> i32
23-
%0 = "test.a"(%3) : (i32) -> i32
24-
%3 = "test.d"() : () -> i32
35+
%0 = "test.a"(%3) {selected} : (i32) -> i32
36+
%3 = "test.d"() {selected} : () -> i32
2537
}
2638

2739
// Test block arguments.

mlir/test/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTestTransforms
55
TestControlFlowSink.cpp
66
TestInlining.cpp
77
TestIntRangeInference.cpp
8+
TestTopologicalSort.cpp
89

910
EXCLUDE_FROM_LIBMLIR
1011

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/IR/Builders.h"
10+
#include "mlir/IR/BuiltinOps.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/TopologicalSortUtils.h"
13+
14+
using namespace mlir;
15+
16+
namespace {
17+
struct TestTopologicalSortAnalysisPass
18+
: public PassWrapper<TestTopologicalSortAnalysisPass,
19+
OperationPass<ModuleOp>> {
20+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass)
21+
22+
StringRef getArgument() const final {
23+
return "test-topological-sort-analysis";
24+
}
25+
StringRef getDescription() const final {
26+
return "Test topological sorting of ops";
27+
}
28+
29+
void runOnOperation() override {
30+
Operation *op = getOperation();
31+
OpBuilder builder(op->getContext());
32+
33+
op->walk([&](Operation *root) {
34+
if (!root->hasAttr("root"))
35+
return WalkResult::advance();
36+
37+
assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() &&
38+
"expected one block");
39+
Block *block = &root->getRegion(0).front();
40+
SmallVector<Operation *> selectedOps;
41+
block->walk([&](Operation *op) {
42+
if (op->hasAttr("selected"))
43+
selectedOps.push_back(op);
44+
});
45+
46+
computeTopologicalSorting(block, selectedOps);
47+
for (const auto &it : llvm::enumerate(selectedOps))
48+
it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
49+
50+
return WalkResult::advance();
51+
});
52+
}
53+
};
54+
} // namespace
55+
56+
namespace mlir {
57+
namespace test {
58+
void registerTestTopologicalSortAnalysisPass() {
59+
PassRegistration<TestTopologicalSortAnalysisPass>();
60+
}
61+
} // namespace test
62+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ void registerTestSCFUtilsPass();
111111
void registerTestSliceAnalysisPass();
112112
void registerTestTensorTransforms();
113113
void registerTestTilingInterface();
114+
void registerTestTopologicalSortAnalysisPass();
114115
void registerTestTransformDialectInterpreterPass();
115116
void registerTestVectorLowerings();
116117
void registerTestNvgpuLowerings();
@@ -207,6 +208,7 @@ void registerTestPasses() {
207208
mlir::test::registerTestSliceAnalysisPass();
208209
mlir::test::registerTestTensorTransforms();
209210
mlir::test::registerTestTilingInterface();
211+
mlir::test::registerTestTopologicalSortAnalysisPass();
210212
mlir::test::registerTestTransformDialectInterpreterPass();
211213
mlir::test::registerTestVectorLowerings();
212214
mlir::test::registerTestNvgpuLowerings();

0 commit comments

Comments
 (0)