Skip to content

Commit eed72d4

Browse files
authored
[mlir][ArmSME] Support filling liveness 'holes' in the tile allocator (#98350)
Holes in a live range are points where the corresponding value does not need to be in a tile/register. If the tile allocator keeps track of these holes, it can reuse tiles for more values (avoiding spills).

Take this simple example:

```mlir
func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here we never use %tileA again!
  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```

If you were to calculate the liveness of %tileA and %tileB, you'd see there is a hole in the liveness of %tileA in bb1:

```
        %tileA   %tileB
^bb0:   Live
^bb1:            Live
^bb2:   Live
```

The tile allocator can make use of that hole and reuse the tile ID it assigned to %tileA for %tileB.
1 parent 5e8cd29 commit eed72d4

File tree

2 files changed

+283
-29
lines changed

2 files changed

+283
-29
lines changed

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Lines changed: 105 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,18 @@ class TileAllocator {
153153
return failure();
154154
}
155155

156+
/// Acquires a specific tile ID. Asserts the tile is initially free.
157+
void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
158+
TileMask tileMask = getMasks(tileType)[tileId];
159+
assert((tilesInUse & tileMask) == TileMask::kNone &&
160+
"cannot acquire allocated tile!");
161+
tilesInUse |= tileMask;
162+
}
163+
156164
/// Releases a previously allocated tile ID.
157165
void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
158166
TileMask tileMask = getMasks(tileType)[tileId];
159-
assert((tilesInUse & tileMask) != TileMask::kNone &&
167+
assert((tilesInUse & tileMask) == tileMask &&
160168
"cannot release unallocated tile!");
161169
tilesInUse ^= tileMask;
162170
}
@@ -289,6 +297,11 @@ struct LiveRange {
289297
.valid();
290298
}
291299

300+
/// Returns true if this range is active at `point` in the program.
301+
bool overlaps(uint64_t point) const {
302+
return ranges->lookup(point) == kValidLiveRange;
303+
}
304+
292305
/// Unions this live range with `otherRange`, aborts if the ranges overlap.
293306
void unionWith(LiveRange const &otherRange) {
294307
for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
@@ -488,76 +501,139 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
488501
return std::move(coalescedLiveRanges);
489502
}
490503

491-
/// Choose a live range to spill (via some heuristics). This picks either an
492-
/// active live range from `activeRanges` or the new live range `newRange`.
493-
LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
494-
LiveRange *newRange) {
504+
/// Choose a live range to spill (via some heuristics). This picks either a live
505+
/// range from `overlappingRanges`, or the new live range `newRange`.
506+
template <typename OverlappingRangesIterator>
507+
LiveRange *
508+
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
509+
LiveRange *newRange) {
495510
// Heuristic: Spill trivially copyable operations (usually free).
496-
auto isTrivialSpill = [&](LiveRange *allocatedRange) {
497-
return isTileTypeGreaterOrEqual(allocatedRange->getTileType(),
511+
auto isTrivialSpill = [&](LiveRange &allocatedRange) {
512+
return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
498513
newRange->getTileType()) &&
499-
allocatedRange->values.size() == 1 &&
514+
allocatedRange.values.size() == 1 &&
500515
isTriviallyCloneableTileOp(
501-
allocatedRange->values[0]
502-
.getDefiningOp<ArmSMETileOpInterface>());
516+
allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
503517
};
504-
if (isTrivialSpill(newRange))
518+
if (isTrivialSpill(*newRange))
505519
return newRange;
506-
auto trivialSpill = llvm::find_if(activeRanges, isTrivialSpill);
507-
if (trivialSpill != activeRanges.end())
508-
return *trivialSpill;
520+
auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
521+
if (trivialSpill != overlappingRanges.end())
522+
return &*trivialSpill;
509523

510524
// Heuristic: Spill the range that ends last (with a compatible tile type).
511-
auto isSmallerTileTypeOrEndsEarlier = [](LiveRange *a, LiveRange *b) {
512-
return !isTileTypeGreaterOrEqual(a->getTileType(), b->getTileType()) ||
513-
a->end() < b->end();
525+
auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
526+
return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
527+
a.end() < b.end();
514528
};
515-
LiveRange *lastActiveLiveRange = *std::max_element(
516-
activeRanges.begin(), activeRanges.end(), isSmallerTileTypeOrEndsEarlier);
517-
if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, newRange))
518-
return lastActiveLiveRange;
529+
LiveRange &latestEndingLiveRange =
530+
*std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
531+
isSmallerTileTypeOrEndsEarlier);
532+
if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
533+
return &latestEndingLiveRange;
519534
return newRange;
520535
}
521536

522537
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
523-
/// Note: This does not attempt to fill holes in active live ranges.
524538
void allocateTilesToLiveRanges(
525539
ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
526540
TileAllocator tileAllocator;
541+
// `activeRanges` = Live ranges that need to be in a tile at the
542+
// `currentPoint` in the program.
527543
SetVector<LiveRange *> activeRanges;
544+
// `inactiveRanges` = Live ranges that _do not_ need to be in a tile
545+
// at the `currentPoint` in the program but could become active again later.
546+
// An inactive section of a live range can be seen as a 'hole' in the live
547+
// range, where it is possible to reuse the live range's tile ID _before_ it
548+
// has ended. By identifying 'holes', the allocator can reuse tiles more
549+
// often, which helps avoid costly tile spills.
550+
SetVector<LiveRange *> inactiveRanges;
528551
for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
529-
// Release tile IDs from live ranges that have ended.
552+
auto currentPoint = nextRange->start();
553+
// 1. Update the `activeRanges` at `currentPoint`.
530554
activeRanges.remove_if([&](LiveRange *activeRange) {
531-
if (activeRange->end() <= nextRange->start()) {
555+
// Check for live ranges that have expired.
556+
if (activeRange->end() <= currentPoint) {
532557
tileAllocator.releaseTileId(activeRange->getTileType(),
533558
*activeRange->tileId);
534559
return true;
535560
}
561+
// Check for live ranges that have become inactive.
562+
if (!activeRange->overlaps(currentPoint)) {
563+
tileAllocator.releaseTileId(activeRange->getTileType(),
564+
*activeRange->tileId);
565+
inactiveRanges.insert(activeRange);
566+
return true;
567+
}
536568
return false;
537569
});
570+
// 2. Update the `inactiveRanges` at `currentPoint`.
571+
inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
572+
// Check for live ranges that have expired.
573+
if (inactiveRange->end() <= currentPoint) {
574+
return true;
575+
}
576+
// Check for live ranges that have become active.
577+
if (inactiveRange->overlaps(currentPoint)) {
578+
tileAllocator.acquireTileId(inactiveRange->getTileType(),
579+
*inactiveRange->tileId);
580+
activeRanges.insert(inactiveRange);
581+
return true;
582+
}
583+
return false;
584+
});
585+
586+
// 3. Collect inactive live ranges that overlap with the new live range.
587+
// Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
588+
// whereas this checks if there is an overlap at any future point too.
589+
SmallVector<LiveRange *> overlappingInactiveRanges;
590+
for (LiveRange *inactiveRange : inactiveRanges) {
591+
if (inactiveRange->overlaps(*nextRange)) {
592+
// We need to reserve the tile IDs of overlapping inactive ranges to
593+
// prevent two (overlapping) live ranges from getting the same tile ID.
594+
tileAllocator.acquireTileId(inactiveRange->getTileType(),
595+
*inactiveRange->tileId);
596+
overlappingInactiveRanges.push_back(inactiveRange);
597+
}
598+
}
538599

539-
// Allocate a tile ID to `nextRange`.
600+
// 4. Allocate a tile ID to `nextRange`.
540601
auto rangeTileType = nextRange->getTileType();
541602
auto tileId = tileAllocator.allocateTileId(rangeTileType);
542603
if (succeeded(tileId)) {
543604
nextRange->tileId = *tileId;
544605
} else {
606+
// Create an iterator over all overlapping live ranges.
607+
auto allOverlappingRanges = llvm::concat<LiveRange>(
608+
llvm::make_pointee_range(activeRanges.getArrayRef()),
609+
llvm::make_pointee_range(overlappingInactiveRanges));
610+
// Choose an overlapping live range to spill.
545611
LiveRange *rangeToSpill =
546-
chooseSpillUsingHeuristics(activeRanges.getArrayRef(), nextRange);
612+
chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
547613
if (rangeToSpill != nextRange) {
548-
// Spill an active live range (so release its tile ID first).
614+
// Spill an (in)active live range (so release its tile ID first).
549615
tileAllocator.releaseTileId(rangeToSpill->getTileType(),
550616
*rangeToSpill->tileId);
551-
activeRanges.remove(rangeToSpill);
552617
// This will always succeed after a spill (of an active or inactive live
// range), as the spilled range's tile ID was released just above.
553618
nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
619+
// Remove the live range from the active/inactive sets.
620+
if (!activeRanges.remove(rangeToSpill)) {
621+
bool removed = inactiveRanges.remove(rangeToSpill);
622+
assert(removed && "expected a range to be removed!");
623+
}
554624
}
555625
rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
556626
}
557627

558-
// Insert the live range into the active ranges.
628+
// 5. Insert the live range into the active ranges.
559629
if (nextRange->tileId < kInMemoryTileIdBase)
560630
activeRanges.insert(nextRange);
631+
632+
// 6. Release tiles reserved for inactive live ranges (in step 3).
633+
for (LiveRange *range : overlappingInactiveRanges) {
634+
if (*range->tileId < kInMemoryTileIdBase)
635+
tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
636+
}
561637
}
562638
}
563639

mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,181 @@ func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
430430
// Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
431431
return
432432
}
433+
434+
// -----
435+
436+
// CHECK-LIVE-RANGE-LABEL: @fill_holes_in_tile_liveness
437+
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
438+
// CHECK-LIVE-RANGE: ^bb0:
439+
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
440+
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
441+
// CHECK-LIVE-RANGE-NEXT: ^bb1:
442+
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
443+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
444+
// CHECK-LIVE-RANGE-NEXT: E test.some_use
445+
// CHECK-LIVE-RANGE-NEXT: cf.br
446+
// CHECK-LIVE-RANGE-NEXT: ^bb2:
447+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
448+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
449+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
450+
// CHECK-LIVE-RANGE-NEXT: E test.some_use
451+
// CHECK-LIVE-RANGE-NEXT: cf.br
452+
453+
// Here there's a 'hole' in the liveness of %tileA (in bb1) where another value
454+
// can reuse the tile ID assigned to %tileA. The liveness for %tileB is
455+
// entirely within the 'hole' in %tileA's live range, so %tileB should get the
456+
// same tile ID as %tileA.
457+
458+
// CHECK-LABEL: @fill_holes_in_tile_liveness
459+
func.func @fill_holes_in_tile_liveness(%cond: i1) {
460+
// CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A:.*]] : i32}
461+
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
462+
cf.cond_br %cond, ^bb2, ^bb1
463+
^bb1:
464+
// CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A]] : i32}
465+
%tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
466+
"test.dummy"(): () -> ()
467+
"test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
468+
cf.br ^bb3
469+
^bb2:
470+
"test.dummy"(): () -> ()
471+
"test.dummy"(): () -> ()
472+
"test.dummy"(): () -> ()
473+
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
474+
cf.br ^bb3
475+
^bb3:
476+
return
477+
}
478+
479+
// -----
480+
481+
// CHECK-LIVE-RANGE-LABEL: @holes_in_tile_liveness_inactive_overlaps
482+
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
483+
// CHECK-LIVE-RANGE: ^bb0:
484+
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
485+
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
486+
// CHECK-LIVE-RANGE-NEXT: ^bb1:
487+
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
488+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
489+
// CHECK-LIVE-RANGE-NEXT: | test.some_use
490+
// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
491+
// CHECK-LIVE-RANGE-NEXT: E cf.br
492+
// CHECK-LIVE-RANGE-NEXT: ^bb2:
493+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
494+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
495+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
496+
// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
497+
// CHECK-LIVE-RANGE-NEXT: E| test.some_use
498+
// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
499+
// CHECK-LIVE-RANGE-NEXT: E cf.br
500+
// CHECK-LIVE-RANGE-NEXT: ^bb3:
501+
// CHECK-LIVE-RANGE-NEXT: E test.some_use
502+
// CHECK-LIVE-RANGE-NEXT: func.return
503+
504+
// This tests an edge case in inactive live ranges. The first live range is
505+
// inactive at the start of ^bb1. If the tile allocator did not check if the
506+
// second live range overlapped the first it would wrongly re-use tile ID 0
507+
// (as the first live range is inactive so tile ID 0 is free). This would mean
508+
// in ^bb2 two overlapping live ranges would have the same tile ID (bad!).
509+
510+
// CHECK-LABEL: @holes_in_tile_liveness_inactive_overlaps
511+
func.func @holes_in_tile_liveness_inactive_overlaps(%cond: i1) {
512+
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
513+
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
514+
cf.cond_br %cond, ^bb2, ^bb1
515+
^bb1:
516+
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
517+
%tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
518+
"test.dummy"(): () -> ()
519+
"test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
520+
cf.br ^bb3(%tileB: vector<[4]x[4]xf32>)
521+
^bb2:
522+
"test.dummy"(): () -> ()
523+
"test.dummy"(): () -> ()
524+
"test.dummy"(): () -> ()
525+
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
526+
%tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
527+
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
528+
cf.br ^bb3(%tileC: vector<[4]x[4]xf32>)
529+
^bb3(%tile: vector<[4]x[4]xf32>):
530+
"test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
531+
return
532+
}
533+
534+
// -----
535+
536+
// This is the same as the previous example, but changes the tile types to
537+
// vector<[16]x[16]xi8>. This means in bb1 the allocator will need to spill the
538+
// first live range (which is inactive).
539+
540+
// Note: The live ranges are the same as the previous example (so are not checked).
541+
542+
// CHECK-LABEL: @spill_inactive_live_range
543+
func.func @spill_inactive_live_range(%cond: i1) {
544+
// CHECK: arm_sme.get_tile {tile_id = 16 : i32}
545+
%tileA = arm_sme.get_tile : vector<[16]x[16]xi8>
546+
cf.cond_br %cond, ^bb2, ^bb1
547+
^bb1:
548+
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
549+
%tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
550+
"test.dummy"(): () -> ()
551+
"test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
552+
cf.br ^bb3(%tileB: vector<[16]x[16]xi8>)
553+
^bb2:
554+
"test.dummy"(): () -> ()
555+
"test.dummy"(): () -> ()
556+
"test.dummy"(): () -> ()
557+
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
558+
%tileC = arm_sme.get_tile : vector<[16]x[16]xi8>
559+
"test.some_use"(%tileA) : (vector<[16]x[16]xi8>) -> ()
560+
cf.br ^bb3(%tileC: vector<[16]x[16]xi8>)
561+
^bb3(%tile: vector<[16]x[16]xi8>):
562+
"test.some_use"(%tile) : (vector<[16]x[16]xi8>) -> ()
563+
return
564+
}
565+
566+
// -----
567+
568+
// CHECK-LIVE-RANGE-LABEL: @reactivate_inactive_live_range
569+
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
570+
// CHECK-LIVE-RANGE: ^bb0:
571+
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
572+
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
573+
// CHECK-LIVE-RANGE-NEXT: ^bb1:
574+
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
575+
// CHECK-LIVE-RANGE-NEXT: | test.dummy
576+
// CHECK-LIVE-RANGE-NEXT: E test.some_use
577+
// CHECK-LIVE-RANGE-NEXT: cf.br
578+
// CHECK-LIVE-RANGE-NEXT: ^bb2:
579+
// CHECK-LIVE-RANGE-NEXT: | S arm_sme.get_tile
580+
// CHECK-LIVE-RANGE-NEXT: | | test.dummy
581+
// CHECK-LIVE-RANGE-NEXT: | | test.dummy
582+
// CHECK-LIVE-RANGE-NEXT: | E test.some_use
583+
// CHECK-LIVE-RANGE-NEXT: E test.some_use
584+
// CHECK-LIVE-RANGE-NEXT: cf.br
585+
586+
// Here the live range for %tileA becomes inactive in bb1 (so %tileB gets tile
587+
// ID 0 too). Then in bb2 the live range for %tileA is reactivated as it overlaps
588+
// with the start of %tileC's live range (which means %tileC gets tile ID 1).
589+
590+
func.func @reactivate_inactive_live_range(%cond: i1) {
591+
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
592+
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
593+
cf.cond_br %cond, ^bb2, ^bb1
594+
^bb1:
595+
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
596+
%tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
597+
"test.dummy"(): () -> ()
598+
"test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
599+
cf.br ^bb3
600+
^bb2:
601+
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
602+
%tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
603+
"test.dummy"(): () -> ()
604+
"test.dummy"(): () -> ()
605+
"test.some_use"(%tileC) : (vector<[4]x[4]xf32>) -> ()
606+
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
607+
cf.br ^bb3
608+
^bb3:
609+
return
610+
}

0 commit comments

Comments
 (0)