7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
9
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
10
- // operations. It does this using a 16-bit tile mask that has a bit for each
11
- // 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
10
+ // operations. It roughly implements a linear scan register allocator, similar
11
+ // to the one outlined in [1], but with simplifications and assumptions made for
12
+ // our use case. Note that this is a greedy allocator (so it may not always find
13
+ // an optimal allocation of tiles).
14
+ //
15
+ // The allocator operates at the CF dialect level. It is the responsibility of
16
+ // users to ensure the IR has been lowered to CF before invoking the tile
17
+ // allocator.
12
18
//
13
19
// The 128-bit tiles overlap with other element tiles as follows (see section
14
- // B2.3.2 of SME spec [1 ]):
20
+ // B2.3.2 of SME spec [2]):
15
21
//
16
22
// Tile Overlaps
17
23
// ---------------------------------------------------------------------------
32
38
// ZA6.D ZA6.Q, ZA14.Q
33
39
// ZA7.D ZA7.Q, ZA15.Q
34
40
//
35
- // [1] https://developer.arm.com/documentation/ddi0616/aa
41
+ // [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
42
+ // Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
43
+ // https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
44
+ // [2] https://developer.arm.com/documentation/ddi0616/aa
36
45
//
37
46
// ===----------------------------------------------------------------------===//
38
47
@@ -214,8 +223,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
214
223
}
215
224
}
216
225
217
- // / Splits conditional branches (see `splitCondBranches`), then inserts tile
218
- // / copies at `cf.br` operations.
226
+ // / Inserts tile copies at `cf.br` operations.
219
227
// /
220
228
// / BEFORE:
221
229
// / ```mlir
@@ -228,7 +236,6 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
228
236
// / ```
229
237
void insertCopiesAtBranches (IRRewriter &rewriter,
230
238
FunctionOpInterface function) {
231
- splitCondBranches (rewriter, function);
232
239
for (Block &block : function.getBlocks ()) {
233
240
Operation *terminator = block.getTerminator ();
234
241
if (!isa<cf::BranchOp>(terminator))
@@ -244,6 +251,20 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
244
251
}
245
252
}
246
253
254
+ // / Prepares the IR for tile allocation. It does this by first 'splitting'
255
+ // / conditional branches (see `splitCondBranches`), then inserting tile copies
256
+ // / at branch operations. The conditional branches are split to prevent the
257
+ // / copies needed for them from overlapping between the true and false paths of the
258
+ // / branch (see `tile-allocation-copies.mlir` and
259
+ // / `tile-allocation-liveness.mlir` for examples). The copies break up live
260
+ // / ranges and ensure when moving out of SSA the semantics of the program are
261
+ // / preserved.
262
+ void preprocessForTileAllocation (IRRewriter &rewriter,
263
+ FunctionOpInterface function) {
264
+ splitCondBranches (rewriter, function);
265
+ insertCopiesAtBranches (rewriter, function);
266
+ }
267
+
247
268
// / A live range for a (collection of) tile values. A live range is built up of
248
269
// / intervals [start, end) which represent parts of the program where the value
249
270
// / needs to be live (i.e. in an SME virtual tile).
@@ -295,6 +316,9 @@ struct LiveRange {
295
316
};
296
317
297
318
// / Number operations within a function to allow computing live ranges.
319
+ // / Operations are numbered consecutively within blocks, and the blocks are
320
+ // / topologically sorted (using forward edges). This function is only correct if
321
+ // / all ArmSME ops have been converted to CF (which is asserted).
298
322
DenseMap<Operation *, unsigned >
299
323
generateOperationNumbering (FunctionOpInterface function) {
300
324
unsigned index = 0 ;
@@ -304,7 +328,6 @@ generateOperationNumbering(FunctionOpInterface function) {
304
328
for (Block *block : blocks) {
305
329
index++; // We want block args to have their own number.
306
330
for (Operation &op : block->getOperations ()) {
307
- // This is only correct if all ArmSME have been converted to CF.
308
331
#ifndef NDEBUG
309
332
op.walk ([&](ArmSMETileOpInterface nestedOp) {
310
333
assert (&op == nestedOp.getOperation () &&
@@ -324,7 +347,9 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
324
347
Liveness &liveness, FunctionOpInterface function) {
325
348
DenseMap<Value, LiveRange> liveRanges;
326
349
// / Defines or updates a live range for an SME tile value. Live-ins may update
327
- // / an existing live range (rather than define a new one).
350
+ // / an existing live range (rather than define a new one). Note: If
351
+ // / `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
352
+ // / the block.
328
353
auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
329
354
LivenessBlockInfo const &livenessInfo,
330
355
bool liveAtBlockEntry = false ) {
@@ -335,10 +360,10 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
335
360
LiveRange &valueLiveRange = it->second ;
336
361
auto lastUseInBlock = livenessInfo.getEndOperation (value, firstUseOrDef);
337
362
// Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
338
- unsigned start =
363
+ unsigned startOpIdx =
339
364
operationToIndexMap.at (firstUseOrDef) + (liveAtBlockEntry ? -1 : 0 );
340
- unsigned end = operationToIndexMap.at (lastUseInBlock);
341
- valueLiveRange.insert (value, start, end );
365
+ unsigned endOpIdx = operationToIndexMap.at (lastUseInBlock);
366
+ valueLiveRange.insert (value, startOpIdx, endOpIdx );
342
367
};
343
368
344
369
for (Block &block : function.getBlocks ()) {
@@ -511,6 +536,20 @@ void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
511
536
}
512
537
}
513
538
539
+ // / Assigns a tile ID to an MLIR value.
540
+ void assignTileIdToValue (IRRewriter &rewriter, Value value,
541
+ IntegerAttr tileIdAttr) {
542
+ if (auto tileOp = value.getDefiningOp <ArmSMETileOpInterface>())
543
+ rewriter.modifyOpInPlace (tileOp, [&] { tileOp.setTileId (tileIdAttr); });
544
+ for (Operation *user : value.getUsers ()) {
545
+ if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
546
+ // Ensure ArmSME ops that don't produce a value still get a tile ID.
547
+ if (!hasTileResult (tileOp))
548
+ rewriter.modifyOpInPlace (tileOp, [&] { tileOp.setTileId (tileIdAttr); });
549
+ }
550
+ }
551
+ }
552
+
514
553
// / Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
515
554
LogicalResult assignTileIdsAndResolveTrivialConflicts (
516
555
IRRewriter &rewriter, FunctionOpInterface function,
@@ -523,63 +562,88 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
523
562
return true ;
524
563
return liveRange->values .contains (value);
525
564
};
526
- for (Value value : liveRange->values ) {
527
- for (Operation *user : value.getUsers ()) {
528
- if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
529
- // Ensure ArmSME ops that don't produce a value still get a tile ID.
530
- if (!hasTileResult (tileOp))
531
- rewriter.modifyOpInPlace (tileOp,
532
- [&] { tileOp.setTileId (tileIdAttr); });
533
- }
534
- }
565
+
566
+ // / Eliminates copies where the operand has the same tile ID.
567
+ auto foldRedundantCopies = [&](Value value) -> LogicalResult {
535
568
auto copyOp = value.getDefiningOp <CopyTileOp>();
536
- if (copyOp && isAllocatedToSameTile (copyOp.getTile ())) {
537
- // Fold redundant copies.
538
- rewriter.replaceAllUsesWith (copyOp, copyOp.getTile ());
539
- } else if (auto tileOp = value.getDefiningOp <ArmSMETileOpInterface>()) {
540
- rewriter.modifyOpInPlace (tileOp, [&] { tileOp.setTileId (tileIdAttr); });
541
- // Rectify operand tile IDs with result tile IDs.
542
- OpOperand *tileOperand = getTileOpOperand (tileOp);
543
- if (!tileOperand || isAllocatedToSameTile (tileOperand->get ()))
544
- continue ;
545
- auto operandTileOp =
546
- tileOperand->get ().getDefiningOp <ArmSMETileOpInterface>();
547
- if (!isTriviallyCloneableTileOp (operandTileOp)) {
548
- auto error =
549
- tileOp.emitOpError (" tile operand allocated to different SME "
550
- " virtial tile (move required)" );
551
- error.attachNote (tileOperand->get ().getLoc ())
552
- << " tile operand is: " << tileOperand->get ();
553
- return error;
569
+ if (!copyOp || !isAllocatedToSameTile (copyOp.getTile ()))
570
+ return failure ();
571
+ rewriter.replaceAllUsesWith (copyOp, copyOp.getTile ());
572
+ return success ();
573
+ };
574
+
575
+ // / Validates that each predecessor of a tile block argument has been assigned
576
+ // / the same tile ID.
577
+ auto validateBlockArguments = [&](Value value) {
578
+ auto blockArg = dyn_cast<BlockArgument>(value);
579
+ if (!blockArg) {
580
+ // Not a block argument (nothing to validate).
581
+ return success ();
582
+ }
583
+ bool tileMismatch = false ;
584
+ forEachPredecessorTileValue (blockArg, [&](Value predecessorTile) {
585
+ if (tileMismatch)
586
+ return ;
587
+ if (!isAllocatedToSameTile (predecessorTile)) {
588
+ blockArg.getOwner ()->getParentOp ()->emitOpError (
589
+ " block argument not allocated to the same SME virtial tile as "
590
+ " predecessors" );
591
+ tileMismatch = true ;
554
592
}
555
- // Cloning prevents a move/spill (though may require recomputation).
556
- rewriter.setInsertionPoint (tileOp);
557
- auto clonedOp = operandTileOp.clone ();
593
+ });
594
+ return success (/* isSuccess=*/ !tileMismatch);
595
+ };
596
+
597
+ // / Attempts to resolve (trivial) tile ID conflicts.
598
+ auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
599
+ auto tileOp = value.getDefiningOp <ArmSMETileOpInterface>();
600
+ OpOperand *tileOperand = getTileOpOperand (tileOp);
601
+ if (!tileOperand || isAllocatedToSameTile (tileOperand->get ())) {
602
+ // Operand already allocated to the correct tile.
603
+ // No conflict to resolve.
604
+ return success ();
605
+ }
606
+ auto operandTileOp =
607
+ tileOperand->get ().getDefiningOp <ArmSMETileOpInterface>();
608
+ if (!isTriviallyCloneableTileOp (operandTileOp)) {
609
+ auto error =
610
+ tileOp.emitOpError (" tile operand allocated to different SME "
611
+ " virtial tile (move required)" );
612
+ error.attachNote (tileOperand->get ().getLoc ())
613
+ << " tile operand is: " << tileOperand->get ();
614
+ return error;
615
+ }
616
+ // Cloning prevents a move/spill (though may require recomputation).
617
+ rewriter.setInsertionPoint (tileOp);
618
+ auto clonedOp = operandTileOp.clone ();
619
+ rewriter.modifyOpInPlace (clonedOp,
620
+ [&] { clonedOp.setTileId (tileOp.getTileId ()); });
621
+ rewriter.insert (clonedOp);
622
+ if (isa<CopyTileOp>(tileOp)) {
623
+ rewriter.replaceAllUsesWith (tileOp->getResult (0 ),
624
+ clonedOp->getResult (0 ));
625
+ } else {
558
626
rewriter.modifyOpInPlace (
559
- clonedOp, [&] { clonedOp.setTileId (tileOp.getTileId ()); });
560
- rewriter.insert (clonedOp);
561
- if (copyOp) {
562
- rewriter.replaceAllUsesWith (copyOp, clonedOp->getResult (0 ));
563
- } else {
564
- rewriter.modifyOpInPlace (
565
- tileOp, [&] { tileOperand->assign (clonedOp->getResult (0 )); });
566
- }
567
- } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
568
- // Validate block arguments.
569
- bool tileMismatch = false ;
570
- forEachPredecessorTileValue (blockArg, [&](Value predecessorTile) {
571
- if (tileMismatch)
572
- return ;
573
- if (!isAllocatedToSameTile (predecessorTile)) {
574
- blockArg.getOwner ()->getParentOp ()->emitOpError (
575
- " block argument not allocated to the same SME virtial tile as "
576
- " predecessors" );
577
- tileMismatch = true ;
578
- }
579
- });
580
- if (tileMismatch)
581
- return failure ();
627
+ tileOp, [&] { tileOperand->assign (clonedOp->getResult (0 )); });
582
628
}
629
+ return success ();
630
+ };
631
+
632
+ for (Value value : liveRange->values ) {
633
+ // 1. Assign the tile ID to the value.
634
+ assignTileIdToValue (rewriter, value, tileIdAttr);
635
+
636
+ // 2. Attempt to eliminate redundant tile copies.
637
+ if (succeeded (foldRedundantCopies (value)))
638
+ continue ;
639
+
640
+ // 3. Validate tile block arguments.
641
+ if (failed (validateBlockArguments (value)))
642
+ return failure ();
643
+
644
+ // 4. Attempt to resolve (trivial) tile ID conflicts.
645
+ if (failed (resolveTrivialTileConflicts (value)))
646
+ return failure ();
583
647
}
584
648
}
585
649
return success ();
@@ -619,9 +683,9 @@ struct TestTileAllocationPass
619
683
using TestTileAllocationBase::TestTileAllocationBase;
620
684
void runOnOperation () override {
621
685
FunctionOpInterface function = getOperation ();
622
- if (tileCopiesOnly ) {
686
+ if (preprocessOnly ) {
623
687
IRRewriter rewriter (function);
624
- return insertCopiesAtBranches (rewriter, function);
688
+ return preprocessForTileAllocation (rewriter, function);
625
689
}
626
690
if (failed (arm_sme::allocateSMETiles (function, dumpTileLiveRanges)))
627
691
signalPassFailure ();
@@ -634,8 +698,8 @@ LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
634
698
LiveRange::Allocator liveRangeAllocator;
635
699
IRRewriter rewriter (function.getContext ());
636
700
637
- // 1. Insert copy operations at branch operations .
638
- insertCopiesAtBranches (rewriter, function);
701
+ // 1. Preprocess the IR for tile allocation .
702
+ preprocessForTileAllocation (rewriter, function);
639
703
640
704
// 2. Gather live ranges for each ArmSME tile within the function.
641
705
Liveness liveness (function);
0 commit comments