|
40 | 40 | //
|
41 | 41 | //===----------------------------------------------------------------------===//
|
42 | 42 |
|
43 |
| -#include "mlir/Analysis/SliceAnalysis.h" |
44 | 43 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
|
45 | 44 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
|
| 45 | +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
46 | 46 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
47 |
| -#include "mlir/Transforms/DialectConversion.h" |
48 | 47 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
49 | 48 |
|
50 | 49 | #define DEBUG_TYPE "allocate-arm-sme-tiles"
|
@@ -149,6 +148,53 @@ static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
|
149 | 148 | return failure();
|
150 | 149 | }
|
151 | 150 |
|
| 151 | +/// Collects transitive uses of a root value through control flow. This can |
| 152 | +/// handle basic SCF constructs, along with control flow (br and cond_br). |
| 153 | +/// Simple loops work at the SCF level, while more complex control flow can be |
| 154 | +/// dealt with after lowering to CF. This can be used to implement basic tile |
| 155 | +/// allocation. Every user reached is inserted into `dependantOps`; a user already in the set is skipped, so it is not traversed twice. |
| 156 | +static void findDependantOps(Value rootValue, |
| 157 | + SetVector<Operation *> &dependantOps) { |
| 158 | + auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) { |
| 159 | + for (auto [idx, value] : llvm::enumerate(inputValues)) { |
| 160 | + if (value != rootValue) |
| 161 | + continue; |
| 162 | + findDependantOps(exitValues[idx], dependantOps); |
| 163 | + } |
| 164 | + }; |
| 165 | + for (Operation *user : rootValue.getUsers()) { |
| 166 | + if (dependantOps.contains(user)) |
| 167 | + continue; |
| 168 | + dependantOps.insert(user); |
| 169 | + if (auto branchOp = llvm::dyn_cast<cf::BranchOp>(user)) { |
| 170 | + // (CF) Follow branch: each destination operand equal to the root value is continued from the successor block argument at the same index. |
| 171 | + traverseCorrespondingValues(branchOp.getDestOperands(), |
| 172 | + user->getSuccessor(0)->getArguments()); |
| 173 | + } else if (auto condBranchOp = llvm::dyn_cast<cf::CondBranchOp>(user)) { |
| 174 | + // (CF) Follow true branch (successor 0), pairing true operands with its block arguments. |
| 175 | + traverseCorrespondingValues(condBranchOp.getTrueOperands(), |
| 176 | + user->getSuccessor(0)->getArguments()); |
| 177 | + // (CF) Follow false branch (successor 1), pairing false operands with its block arguments. |
| 178 | + traverseCorrespondingValues(condBranchOp.getFalseOperands(), |
| 179 | + user->getSuccessor(1)->getArguments()); |
| 180 | + } else if (auto loop = llvm::dyn_cast<LoopLikeOpInterface>(user)) { |
| 181 | + // (SCF) Follow iter_args of (basic) loops (e.g. for loops): an init equal to the root value is continued from the region iter_arg at the same index. |
| 182 | + traverseCorrespondingValues(loop.getInits(), loop.getRegionIterArgs()); |
| 183 | + } else if (user->hasTrait<OpTrait::ReturnLike>()) { |
| 184 | + // (SCF) Follow yields of (basic) control flow (e.g. for loops): a yielded operand equal to the root value is continued from the parent op's result at the same index. |
| 185 | + auto parent = user->getParentOp(); |
| 186 | + // Don't traverse outside a function. |
| 187 | + if (llvm::isa<FunctionOpInterface>(parent)) |
| 188 | + continue; |
| 189 | + traverseCorrespondingValues(user->getOperands(), parent->getResults()); |
| 190 | + } else { |
| 191 | + // Otherwise, assume users of _any_ result are dependant. |
| 192 | + for (Value result : user->getResults()) |
| 193 | + findDependantOps(result, dependantOps); |
| 194 | + } |
| 195 | + } |
| 196 | +} |
| 197 | + |
152 | 198 | struct AssignTileIDsPattern
|
153 | 199 | : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
|
154 | 200 | using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
|
@@ -177,7 +223,7 @@ struct AssignTileIDsPattern
|
177 | 223 |
|
178 | 224 | // Find all the ops that (transitively) depend on this tile.
|
179 | 225 | SetVector<Operation *> dependantOps;
|
180 |
| - getForwardSlice(tileOp.getOperation(), &dependantOps); |
| 226 | + findDependantOps(tileOp->getResult(0), dependantOps); |
181 | 227 |
|
182 | 228 | // Set all operations to use the same tile ID.
|
183 | 229 | // This is a naive tile allocation scheme, but works for common cases. For
|
|
0 commit comments