Skip to content

[mlir][ArmSME] Support filling liveness 'holes' in the tile allocator #98350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 105 additions & 29 deletions mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,18 @@ class TileAllocator {
return failure();
}

/// Marks the tile `tileId` of type `tileType` as in use. The tile must be
/// completely free beforehand; this is enforced with an assertion.
void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
  const TileMask requestedMask = getMasks(tileType)[tileId];
  // No bit covered by the requested tile may already be marked as in use.
  assert(TileMask::kNone == (tilesInUse & requestedMask) &&
         "cannot acquire allocated tile!");
  tilesInUse |= requestedMask;
}

/// Releases a previously allocated tile ID. Asserts that every bit of the
/// tile's mask is currently marked as in use (i.e. the whole tile is
/// allocated, not just part of it).
void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
  TileMask tileMask = getMasks(tileType)[tileId];
  assert((tilesInUse & tileMask) == tileMask &&
         "cannot release unallocated tile!");
  // All bits of `tileMask` are set (asserted above), so XOR clears exactly
  // those bits.
  tilesInUse ^= tileMask;
}
Expand Down Expand Up @@ -289,6 +297,11 @@ struct LiveRange {
.valid();
}

/// Returns true if this live range is active at the program point `point`,
/// i.e. looking up `point` in `ranges` yields `kValidLiveRange`.
bool overlaps(uint64_t point) const {
  const auto stateAtPoint = ranges->lookup(point);
  return stateAtPoint == kValidLiveRange;
}

/// Unions this live range with `otherRange`, aborts if the ranges overlap.
void unionWith(LiveRange const &otherRange) {
for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
Expand Down Expand Up @@ -488,76 +501,139 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
return std::move(coalescedLiveRanges);
}

/// Choose a live range to spill (via some heuristics). This picks either a live
/// range from `overlappingRanges`, or the new live range `newRange`.
///
/// `overlappingRanges` must be a range (providing `begin()`/`end()`) whose
/// elements are `LiveRange &`. The returned pointer refers either to one of
/// those elements or to `newRange`.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // Heuristic: Spill trivially copyable operations (usually free).
  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                    newRange->getTileType()) &&
           allocatedRange.values.size() == 1 &&
           isTriviallyCloneableTileOp(
               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };
  // Prefer spilling the new range itself if that is trivial.
  if (isTrivialSpill(*newRange))
    return newRange;
  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // Heuristic: Spill the range that ends last (with a compatible tile type).
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latestEndingLiveRange =
      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
                        isSmallerTileTypeOrEndsEarlier);
  // Only spill the overlapping range if it is not "smaller or ends earlier"
  // than the new range; otherwise the new range is the one to spill.
  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
    return &latestEndingLiveRange;
  return newRange;
}

/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
///
/// `liveRangesSortedByStartPoint` is walked in order; for each range the
/// active/inactive sets are brought up to date at the range's start point
/// before a tile ID is chosen for it. Ranges that cannot get a tile are
/// assigned an in-memory tile ID instead.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active or reserved
        // inactive live range), as the spilled range's tile was just released.
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          // `removed` is only read by the assert, so mark it to avoid an
          // unused-variable warning when asserts are compiled out.
          [[maybe_unused]] bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
        }
      }
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges (unless it was spilled
    // to an in-memory tile ID).
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3). A range
    // that was spilled in step 4 already had its tile released, so skip it.
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}

Expand Down
178 changes: 178 additions & 0 deletions mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,181 @@ func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
// Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
return
}

// -----

// CHECK-LIVE-RANGE-LABEL: @fill_holes_in_tile_liveness
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
// CHECK-LIVE-RANGE: ^bb0:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
// CHECK-LIVE-RANGE-NEXT: ^bb1:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb2:
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br

// Here there's a 'hole' in the liveness of %tileA (in bb1) where another value
// can reuse the tile ID assigned to %tileA. The liveness for %tileB is
// entirely within the 'hole' in %tileA's live range, so %tileB should get the
// same tile ID as %tileA.

// CHECK-LABEL: @fill_holes_in_tile_liveness
// %tileA is defined in the entry block but only used in ^bb2, whereas %tileB
// is defined and used entirely within ^bb1.
func.func @fill_holes_in_tile_liveness(%cond: i1) {
// Capture whichever tile ID %tileA is given so it can be matched again below.
// CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A:.*]] : i32}
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// %tileB is expected to get the same tile ID that was captured for %tileA.
// CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A]] : i32}
%tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3
^bb2:
// The only use of %tileA is at the end of this block.
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3
^bb3:
return
}

// -----

// CHECK-LIVE-RANGE-LABEL: @holes_in_tile_liveness_inactive_overlaps
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
// CHECK-LIVE-RANGE: ^bb0:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
// CHECK-LIVE-RANGE-NEXT: ^bb1:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.some_use
// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
// CHECK-LIVE-RANGE-NEXT: E cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb2:
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E| test.some_use
// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
// CHECK-LIVE-RANGE-NEXT: E cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb3:
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: func.return

// This tests an edge case in inactive live ranges. The first live range is
// inactive at the start of ^bb1. If the tile allocator did not check if the
// second live range overlapped the first it would wrongly re-use tile ID 0
// (as the first live range is inactive so tile ID 0 is free). This would mean
// in ^bb2 two overlapping live ranges would have the same tile ID (bad!).

// CHECK-LABEL: @holes_in_tile_liveness_inactive_overlaps
// %tileA is only used in ^bb2. %tileB (from ^bb1) and %tileC (from ^bb2) are
// both forwarded to the block argument of ^bb3.
func.func @holes_in_tile_liveness_inactive_overlaps(%cond: i1) {
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// %tileB must not be given %tileA's tile ID (0), so tile ID 1 is expected.
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
%tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3(%tileB: vector<[4]x[4]xf32>)
^bb2:
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
// %tileC is defined before the use of %tileA below and is expected to get
// the same tile ID as %tileB (1).
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
%tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3(%tileC: vector<[4]x[4]xf32>)
^bb3(%tile: vector<[4]x[4]xf32>):
"test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
return
}

// -----

// This is the same as the previous example, but changes the tile types to
// vector<[16]x[16]xi8>. This means in bb1 the allocator will need to spill the
// first live range (which is inactive).

// Note: The live ranges are the same as the previous example (so are not checked).

// CHECK-LABEL: @spill_inactive_live_range
// Same control flow as the previous test, but every tile is a
// vector<[16]x[16]xi8>.
func.func @spill_inactive_live_range(%cond: i1) {
// NOTE(review): tile_id = 16 presumably denotes an in-memory (spilled) tile
// ID, matching the "spill the first live range" description above -- confirm
// against kInMemoryTileIdBase in TileAllocation.cpp.
// CHECK: arm_sme.get_tile {tile_id = 16 : i32}
%tileA = arm_sme.get_tile : vector<[16]x[16]xi8>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
cf.br ^bb3(%tileB: vector<[16]x[16]xi8>)
^bb2:
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
// %tileC is expected to get the same tile ID as %tileB (0); both values are
// forwarded to the block argument of ^bb3.
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileC = arm_sme.get_tile : vector<[16]x[16]xi8>
"test.some_use"(%tileA) : (vector<[16]x[16]xi8>) -> ()
cf.br ^bb3(%tileC: vector<[16]x[16]xi8>)
^bb3(%tile: vector<[16]x[16]xi8>):
"test.some_use"(%tile) : (vector<[16]x[16]xi8>) -> ()
return
}

// -----

// CHECK-LIVE-RANGE-LABEL: @reactivate_inactive_live_range
// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
// CHECK-LIVE-RANGE: ^bb0:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: E cf.cond_br
// CHECK-LIVE-RANGE-NEXT: ^bb1:
// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | test.dummy
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br
// CHECK-LIVE-RANGE-NEXT: ^bb2:
// CHECK-LIVE-RANGE-NEXT: | S arm_sme.get_tile
// CHECK-LIVE-RANGE-NEXT: | | test.dummy
// CHECK-LIVE-RANGE-NEXT: | | test.dummy
// CHECK-LIVE-RANGE-NEXT: | E test.some_use
// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LIVE-RANGE-NEXT: cf.br

// Here the live range for %tileA becomes inactive in bb1 (so %tileB gets tile
// ID 0 too). Then in bb2 the live range for %tileA is reactivated as it overlaps
// with the start of %tileC's live range (which means %tileC gets tile ID 1).

// Anchor the checks below to this function, as the other tests in this file
// do, so they cannot match output from a neighbouring function.
// CHECK-LABEL: @reactivate_inactive_live_range
func.func @reactivate_inactive_live_range(%cond: i1) {
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.cond_br %cond, ^bb2, ^bb1
^bb1:
// %tileB is defined and used entirely within this block, where %tileA is
// not used, so it is expected to get tile ID 0 as well.
// CHECK: arm_sme.get_tile {tile_id = 0 : i32}
%tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
"test.dummy"(): () -> ()
"test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
cf.br ^bb3
^bb2:
// %tileA is used later in this block, so %tileC is expected to get a
// different tile ID (1).
// CHECK: arm_sme.get_tile {tile_id = 1 : i32}
%tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
"test.some_use"(%tileC) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
cf.br ^bb3
^bb3:
return
}
Loading