Skip to content

Commit 1373a8e

Browse files
committed
More stuff
1 parent 3136b73 commit 1373a8e

File tree

5 files changed

+148
-73
lines changed

5 files changed

+148
-73
lines changed

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ def TestTileAllocation
138138
Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
139139
"bool", /*default=*/"false",
140140
"Dump the live ranges of SME tiles (for debugging)">,
141-
Option<"tileCopiesOnly", "tile-copies-only", "bool", /*default=*/"false",
142-
"Only insert tile copies needed for tile allocation "
141+
Option<"preprocessOnly", "preprocess-only", "bool", /*default=*/"false",
142+
"Only preprocess IR so it is ready for tile allocation "
143143
"(but do not allocate any tiles)">
144144
];
145145
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,16 @@ VectorType getSMETileTypeForElement(Type elementType);
7474
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
7575
FunctionOpInterface function);
7676

77-
/// Returns true if `tileOp` can be cloned to resolve conflicts.
77+
/// Returns true if `tileOp` is trivially cloneable. A tile operation is
78+
/// trivially cloneable if:
79+
///
80+
/// 1. It has no operands (and only a single tile result)
81+
/// 2. It is 'Pure'
82+
///
83+
/// This ensures that the cloned operation will not share any dependencies with
84+
/// the original operation (which could also need to be considered), and that
85+
/// inserting the cloned operation at a different point in the program won't
86+
/// change the semantics of the program (as it has no side effects).
7887
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
7988

8089
/// Returns true if `tileOp` produces a tile result.

mlir/lib/Dialect/ArmSME/IR/Utils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
150150
}
151151

152152
OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
153+
if (!tileOp)
154+
return nullptr;
153155
auto isTileOperandType = [](OpOperand &operand) {
154156
return arm_sme::isValidSMETileVectorType(operand.get().getType());
155157
};

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Lines changed: 133 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
10-
// operations. It does this using a 16-bit tile mask that has a bit for each
11-
// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
10+
// operations. It roughly implements a linear scan register allocator, similar
11+
// to the one outlined in [1], but with simplifications and assumptions made for
12+
// our use case. Note that this is a greedy allocator (so it may not always find
13+
// the most optimal allocation of tiles).
14+
//
15+
// The allocator operates at the CF dialect level. It is the responsibility of
16+
// users to ensure the IR has been lowered to CF before invoking the tile
17+
// allocator.
1218
//
1319
// The 128-bit tiles overlap with other element tiles as follows (see section
14-
// B2.3.2 of SME spec [1]):
20+
// B2.3.2 of SME spec [2]):
1521
//
1622
// Tile Overlaps
1723
// ---------------------------------------------------------------------------
@@ -32,7 +38,10 @@
3238
// ZA6.D ZA6.Q, ZA14.Q
3339
// ZA7.D ZA7.Q, ZA15.Q
3440
//
35-
// [1] https://developer.arm.com/documentation/ddi0616/aa
41+
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42+
// Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43+
// https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44+
// [2] https://developer.arm.com/documentation/ddi0616/aa
3645
//
3746
//===----------------------------------------------------------------------===//
3847

@@ -214,8 +223,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
214223
}
215224
}
216225

217-
/// Splits conditional branches (see `splitCondBranches`), then inserts tile
218-
/// copies at `cf.br` operations.
226+
/// Inserts tile copies at `cf.br` operations.
219227
///
220228
/// BEFORE:
221229
/// ```mlir
@@ -228,7 +236,6 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
228236
/// ```
229237
void insertCopiesAtBranches(IRRewriter &rewriter,
230238
FunctionOpInterface function) {
231-
splitCondBranches(rewriter, function);
232239
for (Block &block : function.getBlocks()) {
233240
Operation *terminator = block.getTerminator();
234241
if (!isa<cf::BranchOp>(terminator))
@@ -244,6 +251,20 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
244251
}
245252
}
246253

254+
/// Prepares the IR for tile allocation. It does this by first 'splitting'
255+
/// conditional branches (see `splitCondBranches`), then inserting tile copies
256+
/// at branch operations. The conditional branches are split to prevent the
257+
/// copies needed for them overlapping between the true and false paths of the
258+
/// branch (see `tile-allocation-copies.mlir` and
259+
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
260+
/// ranges and ensure when moving out of SSA the semantics of the program are
261+
/// persevered.
262+
void preprocessForTileAllocation(IRRewriter &rewriter,
263+
FunctionOpInterface function) {
264+
splitCondBranches(rewriter, function);
265+
insertCopiesAtBranches(rewriter, function);
266+
}
267+
247268
/// A live range for a (collection of) tile values. A live range is built up of
248269
/// intervals [start, end) which represent parts of the program where the value
249270
/// needs to be live (i.e. in an SME virtual tile).
@@ -295,6 +316,9 @@ struct LiveRange {
295316
};
296317

297318
/// Number operations within a function to allow computing live ranges.
319+
/// Operations are numbered consecutively wihin blocks, and the blocks are
320+
/// topologically sorted (using forward edges). This function is only correct if
321+
/// all ArmSME have been converted to CF (which is asserted).
298322
DenseMap<Operation *, unsigned>
299323
generateOperationNumbering(FunctionOpInterface function) {
300324
unsigned index = 0;
@@ -304,7 +328,6 @@ generateOperationNumbering(FunctionOpInterface function) {
304328
for (Block *block : blocks) {
305329
index++; // We want block args to have their own number.
306330
for (Operation &op : block->getOperations()) {
307-
// This is only correct if all ArmSME have been converted to CF.
308331
#ifndef NDEBUG
309332
op.walk([&](ArmSMETileOpInterface nestedOp) {
310333
assert(&op == nestedOp.getOperation() &&
@@ -324,7 +347,9 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
324347
Liveness &liveness, FunctionOpInterface function) {
325348
DenseMap<Value, LiveRange> liveRanges;
326349
/// Defines or updates a live range for an SME tile value. Live-ins may update
327-
/// an existing live range (rather than define a new one).
350+
/// an existing live range (rather than define a new one). Note: If
351+
/// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
352+
/// the block.
328353
auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
329354
LivenessBlockInfo const &livenessInfo,
330355
bool liveAtBlockEntry = false) {
@@ -335,10 +360,10 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
335360
LiveRange &valueLiveRange = it->second;
336361
auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
337362
// Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
338-
unsigned start =
363+
unsigned startOpIdx =
339364
operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
340-
unsigned end = operationToIndexMap.at(lastUseInBlock);
341-
valueLiveRange.insert(value, start, end);
365+
unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
366+
valueLiveRange.insert(value, startOpIdx, endOpIdx);
342367
};
343368

344369
for (Block &block : function.getBlocks()) {
@@ -511,6 +536,20 @@ void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
511536
}
512537
}
513538

539+
/// Assigns a tile ID to an MLIR value.
540+
void assignTileIdToValue(IRRewriter &rewriter, Value value,
541+
IntegerAttr tileIdAttr) {
542+
if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
543+
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
544+
for (Operation *user : value.getUsers()) {
545+
if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
546+
// Ensure ArmSME ops that don't produce a value still get a tile ID.
547+
if (!hasTileResult(tileOp))
548+
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
549+
}
550+
}
551+
}
552+
514553
/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
515554
LogicalResult assignTileIdsAndResolveTrivialConflicts(
516555
IRRewriter &rewriter, FunctionOpInterface function,
@@ -523,63 +562,88 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
523562
return true;
524563
return liveRange->values.contains(value);
525564
};
526-
for (Value value : liveRange->values) {
527-
for (Operation *user : value.getUsers()) {
528-
if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
529-
// Ensure ArmSME ops that don't produce a value still get a tile ID.
530-
if (!hasTileResult(tileOp))
531-
rewriter.modifyOpInPlace(tileOp,
532-
[&] { tileOp.setTileId(tileIdAttr); });
533-
}
534-
}
565+
566+
/// Eliminates copies where the operand has the same tile ID.
567+
auto foldRedundantCopies = [&](Value value) -> LogicalResult {
535568
auto copyOp = value.getDefiningOp<CopyTileOp>();
536-
if (copyOp && isAllocatedToSameTile(copyOp.getTile())) {
537-
// Fold redundant copies.
538-
rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
539-
} else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
540-
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
541-
// Rectify operand tile IDs with result tile IDs.
542-
OpOperand *tileOperand = getTileOpOperand(tileOp);
543-
if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
544-
continue;
545-
auto operandTileOp =
546-
tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
547-
if (!isTriviallyCloneableTileOp(operandTileOp)) {
548-
auto error =
549-
tileOp.emitOpError("tile operand allocated to different SME "
550-
"virtial tile (move required)");
551-
error.attachNote(tileOperand->get().getLoc())
552-
<< "tile operand is: " << tileOperand->get();
553-
return error;
569+
if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
570+
return failure();
571+
rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
572+
return success();
573+
};
574+
575+
/// Validates each predecessor to a tile block argument has been assigned
576+
/// the same tile ID.
577+
auto validateBlockArguments = [&](Value value) {
578+
auto blockArg = dyn_cast<BlockArgument>(value);
579+
if (!blockArg) {
580+
// Not a block argument (nothing to validate).
581+
return success();
582+
}
583+
bool tileMismatch = false;
584+
forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
585+
if (tileMismatch)
586+
return;
587+
if (!isAllocatedToSameTile(predecessorTile)) {
588+
blockArg.getOwner()->getParentOp()->emitOpError(
589+
"block argument not allocated to the same SME virtial tile as "
590+
"predecessors");
591+
tileMismatch = true;
554592
}
555-
// Cloning prevents a move/spill (though may require recomputation).
556-
rewriter.setInsertionPoint(tileOp);
557-
auto clonedOp = operandTileOp.clone();
593+
});
594+
return success(/*isSuccess=*/!tileMismatch);
595+
};
596+
597+
/// Attempts to resolve (trivial) tile ID conflicts.
598+
auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
599+
auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
600+
OpOperand *tileOperand = getTileOpOperand(tileOp);
601+
if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
602+
// Operand already allocated to the correct tile.
603+
// No conflict to resolve.
604+
return success();
605+
}
606+
auto operandTileOp =
607+
tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
608+
if (!isTriviallyCloneableTileOp(operandTileOp)) {
609+
auto error =
610+
tileOp.emitOpError("tile operand allocated to different SME "
611+
"virtial tile (move required)");
612+
error.attachNote(tileOperand->get().getLoc())
613+
<< "tile operand is: " << tileOperand->get();
614+
return error;
615+
}
616+
// Cloning prevents a move/spill (though may require recomputation).
617+
rewriter.setInsertionPoint(tileOp);
618+
auto clonedOp = operandTileOp.clone();
619+
rewriter.modifyOpInPlace(clonedOp,
620+
[&] { clonedOp.setTileId(tileOp.getTileId()); });
621+
rewriter.insert(clonedOp);
622+
if (isa<CopyTileOp>(tileOp)) {
623+
rewriter.replaceAllUsesWith(tileOp->getResult(0),
624+
clonedOp->getResult(0));
625+
} else {
558626
rewriter.modifyOpInPlace(
559-
clonedOp, [&] { clonedOp.setTileId(tileOp.getTileId()); });
560-
rewriter.insert(clonedOp);
561-
if (copyOp) {
562-
rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
563-
} else {
564-
rewriter.modifyOpInPlace(
565-
tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
566-
}
567-
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
568-
// Validate block arguments.
569-
bool tileMismatch = false;
570-
forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
571-
if (tileMismatch)
572-
return;
573-
if (!isAllocatedToSameTile(predecessorTile)) {
574-
blockArg.getOwner()->getParentOp()->emitOpError(
575-
"block argument not allocated to the same SME virtial tile as "
576-
"predecessors");
577-
tileMismatch = true;
578-
}
579-
});
580-
if (tileMismatch)
581-
return failure();
627+
tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
582628
}
629+
return success();
630+
};
631+
632+
for (Value value : liveRange->values) {
633+
// 1. Assign the tile ID to the value.
634+
assignTileIdToValue(rewriter, value, tileIdAttr);
635+
636+
// 2. Attempt to eliminate redundant tile copies.
637+
if (succeeded(foldRedundantCopies(value)))
638+
continue;
639+
640+
// 3. Validate tile block arguments.
641+
if (failed(validateBlockArguments(value)))
642+
return failure();
643+
644+
// 4. Attempt to resolve (trivial) tile ID conflicts.
645+
if (failed(resolveTrivialTileConflicts(value)))
646+
return failure();
583647
}
584648
}
585649
return success();
@@ -619,9 +683,9 @@ struct TestTileAllocationPass
619683
using TestTileAllocationBase::TestTileAllocationBase;
620684
void runOnOperation() override {
621685
FunctionOpInterface function = getOperation();
622-
if (tileCopiesOnly) {
686+
if (preprocessOnly) {
623687
IRRewriter rewriter(function);
624-
return insertCopiesAtBranches(rewriter, function);
688+
return preprocessForTileAllocation(rewriter, function);
625689
}
626690
if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
627691
signalPassFailure();
@@ -634,8 +698,8 @@ LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
634698
LiveRange::Allocator liveRangeAllocator;
635699
IRRewriter rewriter(function.getContext());
636700

637-
// 1. Insert copy operations at branch operations.
638-
insertCopiesAtBranches(rewriter, function);
701+
// 1. Preprocess the IR for tile allocation.
702+
preprocessForTileAllocation(rewriter, function);
639703

640704
// 2. Gather live ranges for each ArmSME tile within the function.
641705
Liveness liveness(function);

mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-arm-sme-tile-allocation=tile-copies-only -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-arm-sme-tile-allocation=preprocess-only -split-input-file | FileCheck %s
22

33
// This file tests the inserting copies for the SME tile allocation. Copies are
44
// inserted at `cf.br` ops (the predecessors to block arguments). Conditional

0 commit comments

Comments
 (0)