Skip to content

Commit 63f3b0f

Browse files
committed
Fixups
1 parent aa81f97 commit 63f3b0f

File tree

2 files changed

+48
-43
lines changed

2 files changed

+48
-43
lines changed

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ bool isValidSMETileElementType(Type type);
3535
/// otherwise.
3636
bool isValidSMETileVectorType(VectorType vType);
3737

38-
/// Returns the type of SME tile this vector type corresponds to or none.
38+
/// Returns the type of SME tile this vector type corresponds to, or none if the
39+
/// vector type does not fit within an SME tile.
3940
std::optional<ArmSMETileType> getSMETileType(VectorType);
4041

4142
/// Verifies the tile ID (if set) on this tile operation is valid.

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4646
#include "mlir/Dialect/Func/IR/FuncOps.h"
4747
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
48+
#include "llvm/ADT/TypeSwitch.h"
4849

4950
#define DEBUG_TYPE "allocate-arm-sme-tiles"
5051

@@ -151,47 +152,52 @@ static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
151152
/// Collects transitive uses of a root value through control flow. This can
152153
/// handle basic SCF constructs, along with control flow (br and cond_br).
153154
/// Simple loops work at the SCF level, while more complex control flow can be
154-
/// delt with after lowering to CF. This can be used to implement basic tile
155+
/// dealt with after lowering to CF. This is used to implement basic tile
155156
/// allocation.
156157
static void findDependantOps(Value rootValue,
157158
SetVector<Operation *> &dependantOps) {
158159
auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
159160
for (auto [idx, value] : llvm::enumerate(inputValues)) {
160-
if (value != rootValue)
161-
continue;
162-
findDependantOps(exitValues[idx], dependantOps);
161+
if (value == rootValue)
162+
findDependantOps(exitValues[idx], dependantOps);
163163
}
164164
};
165165
for (Operation *user : rootValue.getUsers()) {
166166
if (dependantOps.contains(user))
167167
continue;
168168
dependantOps.insert(user);
169-
if (auto branchOp = llvm::dyn_cast<cf::BranchOp>(user)) {
170-
// (CF) Follow branch.
171-
traverseCorrespondingValues(branchOp.getDestOperands(),
172-
branchOp.getDest()->getArguments());
173-
} else if (auto condBranchOp = llvm::dyn_cast<cf::CondBranchOp>(user)) {
174-
// (CF) Follow true branch.
175-
traverseCorrespondingValues(condBranchOp.getTrueOperands(),
176-
condBranchOp.getTrueDest()->getArguments());
177-
// (CF) Follow false branch.
178-
traverseCorrespondingValues(condBranchOp.getFalseOperands(),
179-
condBranchOp.getFalseDest()->getArguments());
180-
} else if (auto loop = llvm::dyn_cast<LoopLikeOpInterface>(user)) {
181-
// (SCF) Follow iter_args of (basic) loops (e.g. for loops).
182-
traverseCorrespondingValues(loop.getInits(), loop.getRegionIterArgs());
183-
} else if (user->hasTrait<OpTrait::ReturnLike>()) {
184-
// (SCF) Follow yields of (basic) control flow (e.g. for loops).
185-
auto parent = user->getParentOp();
186-
// Don't traverse outside a function.
187-
if (llvm::isa<FunctionOpInterface>(parent))
188-
continue;
189-
traverseCorrespondingValues(user->getOperands(), parent->getResults());
190-
} else {
191-
// Otherwise, assume users of _any_ result are dependant.
192-
for (Value result : user->getResults())
193-
findDependantOps(result, dependantOps);
194-
}
169+
TypeSwitch<Operation *>(user)
170+
.Case<cf::BranchOp>([&](auto branchOp) {
171+
// (CF) Follow branch.
172+
traverseCorrespondingValues(branchOp.getDestOperands(),
173+
branchOp.getDest()->getArguments());
174+
})
175+
.Case<cf::CondBranchOp>([&](auto condBranchOp) {
176+
// (CF) Follow true branch.
177+
traverseCorrespondingValues(
178+
condBranchOp.getTrueOperands(),
179+
condBranchOp.getTrueDest()->getArguments());
180+
// (CF) Follow false branch.
181+
traverseCorrespondingValues(
182+
condBranchOp.getFalseOperands(),
183+
condBranchOp.getFalseDest()->getArguments());
184+
})
185+
.Case<LoopLikeOpInterface>([&](auto loopOp) {
186+
// (SCF) Follow iter_args of (basic) loops (e.g. for loops).
187+
traverseCorrespondingValues(loopOp.getInits(),
188+
loopOp.getRegionIterArgs());
189+
})
190+
.Case<scf::YieldOp>([&](auto yieldOp) {
191+
// (SCF) Follow yields of (basic) control flow (e.g. for loops).
192+
auto parent = user->getParentOp();
193+
traverseCorrespondingValues(user->getOperands(),
194+
parent->getResults());
195+
})
196+
.Default([&](auto) {
197+
// Otherwise, assume users of _any_ result are dependant.
198+
for (Value result : user->getResults())
199+
findDependantOps(result, dependantOps);
200+
});
195201
}
196202
}
197203

@@ -208,38 +214,36 @@ struct AssignTileIDsPattern
208214
return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
209215

210216
auto func = tileOp->getParentOfType<FunctionOpInterface>();
211-
TileMask tilesInUse;
212-
if (auto tilesInUseAttr = func->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
217+
TileMask tilesInUse = TileMask::kNone;
218+
if (auto tilesInUseAttr = llvm::dyn_cast_or_null<IntegerAttr>(
219+
func->getDiscardableAttr(kTilesInUseAttr)))
213220
tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
214-
else
215-
tilesInUse = TileMask::kNone;
216221

217222
auto tileId = allocateTileId(*tileType, tilesInUse);
218223
if (failed(tileId))
219224
return tileOp.emitError("ran out of SME virtual tiles!");
220225

221-
func->setAttr(kTilesInUseAttr,
222-
rewriter.getI32IntegerAttr((unsigned)tilesInUse));
226+
func->setDiscardableAttr(kTilesInUseAttr,
227+
rewriter.getI32IntegerAttr((unsigned)tilesInUse));
223228

224229
// Find all the ops that (transitively) depend on this tile.
225230
SetVector<Operation *> dependantOps;
226231
findDependantOps(tileOp->getResult(0), dependantOps);
227232

228-
// Set all operations to use the same tile ID.
233+
// Set all operations dependent on `tileOp` to use the same tile ID.
229234
// This is a naive tile allocation scheme, but works for common cases. For
230235
// example, as this only allocates tile IDs to existing ops, it can't solve
231-
// cases like:
236+
// cases like this (%tileA and %tileB come from different root operations):
232237
//
233238
// %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
234239
// scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
235240
// } else {
236241
// scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
237242
// }
238243
//
239-
// Where %tileA and %tileB come from different root operations. This case
240-
// would require allocating a new tile for the result of the scf.if, and
241-
// moving the contents of %tileA or %tileB to result tile (based on the
242-
// %some_cond).
244+
// This case would require allocating a new tile for the result of the
245+
// scf.if, and moving the contents of %tileA or %tileB to the result tile (based
246+
// on the %some_cond).
243247
auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
244248
tileOp.setTileId(tileIDAttr);
245249
for (auto *op : dependantOps) {

0 commit comments

Comments
 (0)