Skip to content

Commit 6df97ad

Browse files
committed
Implement more semantically correct findDependantOps()
This function follows uses of a value through control flow. It understands basic SCF constructs and, more generally, works on control flow branches. (The previous slice analysis was very basic and did not understand any control flow.)
1 parent a3d0a64 commit 6df97ad

File tree

1 file changed

+49
-3
lines changed

1 file changed

+49
-3
lines changed

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@
4040
//
4141
//===----------------------------------------------------------------------===//
4242

43-
#include "mlir/Analysis/SliceAnalysis.h"
4443
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
4544
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
45+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4646
#include "mlir/Dialect/Func/IR/FuncOps.h"
47-
#include "mlir/Transforms/DialectConversion.h"
4847
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
4948

5049
#define DEBUG_TYPE "allocate-arm-sme-tiles"
@@ -149,6 +148,53 @@ static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
149148
return failure();
150149
}
151150

151+
/// Collects transitive uses of a root value through control flow. This can
152+
/// handle basic SCF constructs, along with control flow (br and cond_br).
153+
/// Simple loops work at the SCF level, while more complex control flow can be
154+
/// delt with after lowering to CF. This can be used to implement basic tile
155+
/// allocation.
156+
static void findDependantOps(Value rootValue,
157+
SetVector<Operation *> &dependantOps) {
158+
auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
159+
for (auto [idx, value] : llvm::enumerate(inputValues)) {
160+
if (value != rootValue)
161+
continue;
162+
findDependantOps(exitValues[idx], dependantOps);
163+
}
164+
};
165+
for (Operation *user : rootValue.getUsers()) {
166+
if (dependantOps.contains(user))
167+
continue;
168+
dependantOps.insert(user);
169+
if (auto branchOp = llvm::dyn_cast<cf::BranchOp>(user)) {
170+
// (CF) Follow branch.
171+
traverseCorrespondingValues(branchOp.getDestOperands(),
172+
user->getSuccessor(0)->getArguments());
173+
} else if (auto condBranchOp = llvm::dyn_cast<cf::CondBranchOp>(user)) {
174+
// (CF) Follow true branch.
175+
traverseCorrespondingValues(condBranchOp.getTrueOperands(),
176+
user->getSuccessor(0)->getArguments());
177+
// (CF) Follow false branch.
178+
traverseCorrespondingValues(condBranchOp.getFalseOperands(),
179+
user->getSuccessor(1)->getArguments());
180+
} else if (auto loop = llvm::dyn_cast<LoopLikeOpInterface>(user)) {
181+
// (SCF) Follow iter_args of (basic) loops (e.g. for loops).
182+
traverseCorrespondingValues(loop.getInits(), loop.getRegionIterArgs());
183+
} else if (user->hasTrait<OpTrait::ReturnLike>()) {
184+
// (SCF) Follow yields of (basic) control flow (e.g. for loops).
185+
auto parent = user->getParentOp();
186+
// Don't traverse outside a function.
187+
if (llvm::isa<FunctionOpInterface>(parent))
188+
continue;
189+
traverseCorrespondingValues(user->getOperands(), parent->getResults());
190+
} else {
191+
// Otherwise, assume users of _any_ result are dependant.
192+
for (Value result : user->getResults())
193+
findDependantOps(result, dependantOps);
194+
}
195+
}
196+
}
197+
152198
struct AssignTileIDsPattern
153199
: public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
154200
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
@@ -177,7 +223,7 @@ struct AssignTileIDsPattern
177223

178224
// Find all the ops that (transitively) depend on this tile.
179225
SetVector<Operation *> dependantOps;
180-
getForwardSlice(tileOp.getOperation(), &dependantOps);
226+
findDependantOps(tileOp->getResult(0), dependantOps);
181227

182228
// Set all operations to use the same tile ID.
183229
// This is a naive tile allocation scheme, but works for common cases. For

0 commit comments

Comments
 (0)