45
45
#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
46
46
#include " mlir/Dialect/Func/IR/FuncOps.h"
47
47
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
48
+ #include " llvm/ADT/TypeSwitch.h"
48
49
49
50
#define DEBUG_TYPE " allocate-arm-sme-tiles"
50
51
@@ -151,47 +152,52 @@ static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
151
152
// / Collects transitive uses of a root value through control flow. This can
152
153
// / handle basic SCF constructs, along with control flow (br and cond_br).
153
154
// / Simple loops work at the SCF level, while more complex control flow can be
154
- // / delt with after lowering to CF. This can be used to implement basic tile
155
+ // / dealt with after lowering to CF. This is used to implement basic tile
155
156
// / allocation.
156
157
static void findDependantOps (Value rootValue,
157
158
SetVector<Operation *> &dependantOps) {
158
159
auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
159
160
for (auto [idx, value] : llvm::enumerate (inputValues)) {
160
- if (value != rootValue)
161
- continue ;
162
- findDependantOps (exitValues[idx], dependantOps);
161
+ if (value == rootValue)
162
+ findDependantOps (exitValues[idx], dependantOps);
163
163
}
164
164
};
165
165
for (Operation *user : rootValue.getUsers ()) {
166
166
if (dependantOps.contains (user))
167
167
continue ;
168
168
dependantOps.insert (user);
169
- if (auto branchOp = llvm::dyn_cast<cf::BranchOp>(user)) {
170
- // (CF) Follow branch.
171
- traverseCorrespondingValues (branchOp.getDestOperands (),
172
- branchOp.getDest ()->getArguments ());
173
- } else if (auto condBranchOp = llvm::dyn_cast<cf::CondBranchOp>(user)) {
174
- // (CF) Follow true branch.
175
- traverseCorrespondingValues (condBranchOp.getTrueOperands (),
176
- condBranchOp.getTrueDest ()->getArguments ());
177
- // (CF) Follow false branch.
178
- traverseCorrespondingValues (condBranchOp.getFalseOperands (),
179
- condBranchOp.getFalseDest ()->getArguments ());
180
- } else if (auto loop = llvm::dyn_cast<LoopLikeOpInterface>(user)) {
181
- // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
182
- traverseCorrespondingValues (loop.getInits (), loop.getRegionIterArgs ());
183
- } else if (user->hasTrait <OpTrait::ReturnLike>()) {
184
- // (SCF) Follow yields of (basic) control flow (e.g. for loops).
185
- auto parent = user->getParentOp ();
186
- // Don't traverse outside a function.
187
- if (llvm::isa<FunctionOpInterface>(parent))
188
- continue ;
189
- traverseCorrespondingValues (user->getOperands (), parent->getResults ());
190
- } else {
191
- // Otherwise, assume users of _any_ result are dependant.
192
- for (Value result : user->getResults ())
193
- findDependantOps (result, dependantOps);
194
- }
169
+ TypeSwitch<Operation *>(user)
170
+ .Case <cf::BranchOp>([&](auto branchOp) {
171
+ // (CF) Follow branch.
172
+ traverseCorrespondingValues (branchOp.getDestOperands (),
173
+ branchOp.getDest ()->getArguments ());
174
+ })
175
+ .Case <cf::CondBranchOp>([&](auto condBranchOp) {
176
+ // (CF) Follow true branch.
177
+ traverseCorrespondingValues (
178
+ condBranchOp.getTrueOperands (),
179
+ condBranchOp.getTrueDest ()->getArguments ());
180
+ // (CF) Follow false branch.
181
+ traverseCorrespondingValues (
182
+ condBranchOp.getFalseOperands (),
183
+ condBranchOp.getFalseDest ()->getArguments ());
184
+ })
185
+ .Case <LoopLikeOpInterface>([&](auto loopOp) {
186
+ // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
187
+ traverseCorrespondingValues (loopOp.getInits (),
188
+ loopOp.getRegionIterArgs ());
189
+ })
190
+ .Case <scf::YieldOp>([&](auto yieldOp) {
191
+ // (SCF) Follow yields of (basic) control flow (e.g. for loops).
192
+ auto parent = user->getParentOp ();
193
+ traverseCorrespondingValues (user->getOperands (),
194
+ parent->getResults ());
195
+ })
196
+ .Default ([&](auto ) {
197
+ // Otherwise, assume users of _any_ result are dependant.
198
+ for (Value result : user->getResults ())
199
+ findDependantOps (result, dependantOps);
200
+ });
195
201
}
196
202
}
197
203
@@ -208,38 +214,36 @@ struct AssignTileIDsPattern
208
214
return rewriter.notifyMatchFailure (tileOp, " op does not allocate a tile" );
209
215
210
216
auto func = tileOp->getParentOfType <FunctionOpInterface>();
211
- TileMask tilesInUse;
212
- if (auto tilesInUseAttr = func->getAttrOfType <IntegerAttr>(kTilesInUseAttr ))
217
+ TileMask tilesInUse = TileMask::kNone ;
218
+ if (auto tilesInUseAttr = llvm::dyn_cast_or_null<IntegerAttr>(
219
+ func->getDiscardableAttr (kTilesInUseAttr )))
213
220
tilesInUse = static_cast <TileMask>(tilesInUseAttr.getInt ());
214
- else
215
- tilesInUse = TileMask::kNone ;
216
221
217
222
auto tileId = allocateTileId (*tileType, tilesInUse);
218
223
if (failed (tileId))
219
224
return tileOp.emitError (" ran out of SME virtual tiles!" );
220
225
221
- func->setAttr (kTilesInUseAttr ,
222
- rewriter.getI32IntegerAttr ((unsigned )tilesInUse));
226
+ func->setDiscardableAttr (kTilesInUseAttr ,
227
+ rewriter.getI32IntegerAttr ((unsigned )tilesInUse));
223
228
224
229
// Find all the ops that (transitively) depend on this tile.
225
230
SetVector<Operation *> dependantOps;
226
231
findDependantOps (tileOp->getResult (0 ), dependantOps);
227
232
228
- // Set all operations to use the same tile ID.
233
+ // Set all operations dependent on `tileOp` to use the same tile ID.
229
234
// This is a naive tile allocation scheme, but works for common cases. For
230
235
// example, as this only allocates tile IDs to existing ops, it can't solve
231
- // cases like:
236
+ // cases like this (%tileA and %tileB come from different root operations) :
232
237
//
233
238
// %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
234
239
// scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
235
240
// } else {
236
241
// scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
237
242
// }
238
243
//
239
- // Where %tileA and %tileB come from different root operations. This case
240
- // would require allocating a new tile for the result of the scf.if, and
241
- // moving the contents of %tileA or %tileB to result tile (based on the
242
- // %some_cond).
244
+ // This case would require allocating a new tile for the result of the
245
+ // scf.if, and moving the contents of %tileA or %tileB to the result tile
246
+ // (based on %some_cond).
243
247
auto tileIDAttr = rewriter.getI32IntegerAttr (*tileId);
244
248
tileOp.setTileId (tileIDAttr);
245
249
for (auto *op : dependantOps) {
0 commit comments