Skip to content

[mlir][ArmSME] Support filling liveness 'holes' in the tile allocator #98350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 18, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jul 10, 2024

Holes in a live range are points where the corresponding value does not need to be in a tile/register. If the tile allocator keeps track of these holes it can reuse tiles for more values (avoiding spills).

Take this simple example:

func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here we never use %tileA again!
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}

If you were to calculate the liveness of %tileA and %tileB. You'd see there is a hole in the liveness of %tileA in bb1:

      %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live

The tile allocator can make use of that hole and reuse the tile ID it assigned to %tileA for %tileB.

Holes in a live range are points where the corresponding value does not
need to be in a tile/register. If the tile allocator keeps track of
these holes it can reuse tiles for more values (avoiding spills).

Take this simple example:

```mlir
func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here we never use %tileA again!
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```

If you were to calculate the liveness of %tileA and %tileB. You'd see
there is a hole in the liveness of %tileA in bb1:

```
      %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live
```

The tile allocator can make use of that hole and reuse the tile ID it
assigned to %tileA for %tileB.
@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

Holes in a live range are points where the corresponding value does not need to be in a tile/register. If the tile allocator keeps track of these holes it can reuse tiles for more values (avoiding spills).

Take this simple example:

func.func @<!-- -->example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector&lt;[4]x[4]xf32&gt;
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here we never use %tileA again!
  "test.some_use"(%tileB) : (vector&lt;[4]x[4]xf32&gt;) -&gt; ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector&lt;[4]x[4]xf32&gt;) -&gt; ()
  cf.br ^bb3
^bb3:
  return
}

If you were to calculate the liveness of %tileA and %tileB. You'd see there is a hole in the liveness of %tileA in bb1:

      %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live

The tile allocator can make use of that hole and reuse the tile ID it assigned to %tileA for %tileB.


Full diff: https://github.com/llvm/llvm-project/pull/98350.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (+88-25)
  • (modified) mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir (+130)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 733e758b43907..6023871c5affe 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -153,10 +153,18 @@ class TileAllocator {
     return failure();
   }
 
+  /// Acquires a specific tile ID. Asserts the tile is initially free.
+  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) == TileMask::kNone &&
+           "cannot acquire allocated tile!");
+    tilesInUse |= tileMask;
+  }
+
   /// Releases a previously allocated tile ID.
   void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
     TileMask tileMask = getMasks(tileType)[tileId];
-    assert((tilesInUse & tileMask) != TileMask::kNone &&
+    assert((tilesInUse & tileMask) == tileMask &&
            "cannot release unallocated tile!");
     tilesInUse ^= tileMask;
   }
@@ -289,6 +297,11 @@ struct LiveRange {
         .valid();
   }
 
+  /// Returns true if this range overlaps with `point`.
+  bool overlaps(uint64_t point) const {
+    return ranges->lookup(point) == kValidLiveRange;
+  }
+
   /// Unions this live range with `otherRange`, aborts if the ranges overlap.
   void unionWith(LiveRange const &otherRange) {
     for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
@@ -488,69 +501,113 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
   return std::move(coalescedLiveRanges);
 }
 
-/// Choose a live range to spill (via some heuristics). This picks either an
-/// active live range from `activeRanges` or the new live range `newRange`.
+/// Choose a live range to spill (via some heuristics). This picks either a live
+/// range from `activeRanges`, `inactiveRanges`, or the new live range
+/// `newRange`. Note: All live ranges in `activeRanges` and `inactiveRanges` are
+/// assumed to overlap with `newRange`.
 LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
+                                      ArrayRef<LiveRange *> inactiveRanges,
                                       LiveRange *newRange) {
+  auto allOverlappingRanges =
+      llvm::concat<LiveRange>(llvm::make_pointee_range(activeRanges),
+                              llvm::make_pointee_range(inactiveRanges));
+
   // Heuristic: Spill trivially copyable operations (usually free).
-  auto isTrivialSpill = [&](LiveRange *allocatedRange) {
-    return isTileTypeGreaterOrEqual(allocatedRange->getTileType(),
+  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
+    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                     newRange->getTileType()) &&
-           allocatedRange->values.size() == 1 &&
+           allocatedRange.values.size() == 1 &&
            isTriviallyCloneableTileOp(
-               allocatedRange->values[0]
-                   .getDefiningOp<ArmSMETileOpInterface>());
+               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
   };
-  if (isTrivialSpill(newRange))
+  if (isTrivialSpill(*newRange))
     return newRange;
-  auto trivialSpill = llvm::find_if(activeRanges, isTrivialSpill);
-  if (trivialSpill != activeRanges.end())
-    return *trivialSpill;
+  auto trivialSpill = llvm::find_if(allOverlappingRanges, isTrivialSpill);
+  if (trivialSpill != allOverlappingRanges.end())
+    return &*trivialSpill;
 
   // Heuristic: Spill the range that ends last (with a compatible tile type).
-  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange *a, LiveRange *b) {
-    return !isTileTypeGreaterOrEqual(a->getTileType(), b->getTileType()) ||
-           a->end() < b->end();
+  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
+    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
+           a.end() < b.end();
   };
-  LiveRange *lastActiveLiveRange = *std::max_element(
-      activeRanges.begin(), activeRanges.end(), isSmallerTileTypeOrEndsEarlier);
-  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, newRange))
-    return lastActiveLiveRange;
+  LiveRange &lastActiveLiveRange = *std::max_element(
+      allOverlappingRanges.begin(), allOverlappingRanges.end(),
+      isSmallerTileTypeOrEndsEarlier);
+  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, *newRange))
+    return &lastActiveLiveRange;
   return newRange;
 }
 
 /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
-/// Note: This does not attempt to fill holes in active live ranges.
 void allocateTilesToLiveRanges(
     ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
   TileAllocator tileAllocator;
   SetVector<LiveRange *> activeRanges;
+  SetVector<LiveRange *> inactiveRanges;
   for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
-    // Release tile IDs from live ranges that have ended.
     activeRanges.remove_if([&](LiveRange *activeRange) {
+      // Check for live ranges that have expired.
       if (activeRange->end() <= nextRange->start()) {
         tileAllocator.releaseTileId(activeRange->getTileType(),
                                     *activeRange->tileId);
         return true;
       }
+      // Check for live ranges that have become inactive.
+      if (!activeRange->overlaps(nextRange->start())) {
+        tileAllocator.releaseTileId(activeRange->getTileType(),
+                                    *activeRange->tileId);
+        inactiveRanges.insert(activeRange);
+        return true;
+      }
+      return false;
+    });
+    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
+      // Check for live ranges that have expired.
+      if (inactiveRange->end() <= nextRange->start()) {
+        return true;
+      }
+      // Check for live ranges that have become active.
+      if (inactiveRange->overlaps(nextRange->start())) {
+        tileAllocator.acquireTileId(inactiveRange->getTileType(),
+                                    *inactiveRange->tileId);
+        activeRanges.insert(inactiveRange);
+        return true;
+      }
       return false;
     });
 
+    // Collect inactive live ranges that overlap with the current new live
+    // range. We need to acquire the tile IDs of overlapping inactive ranges to
+    // prevent two (overlapping) live ranges from getting the same tile ID.
+    SmallVector<LiveRange *> overlappingInactiveRanges;
+    for (LiveRange *inactiveRange : inactiveRanges) {
+      if (inactiveRange->overlaps(*nextRange)) {
+        tileAllocator.acquireTileId(inactiveRange->getTileType(),
+                                    *inactiveRange->tileId);
+        overlappingInactiveRanges.push_back(inactiveRange);
+      }
+    }
+
     // Allocate a tile ID to `nextRange`.
     auto rangeTileType = nextRange->getTileType();
     auto tileId = tileAllocator.allocateTileId(rangeTileType);
     if (succeeded(tileId)) {
       nextRange->tileId = *tileId;
     } else {
-      LiveRange *rangeToSpill =
-          chooseSpillUsingHeuristics(activeRanges.getArrayRef(), nextRange);
+      LiveRange *rangeToSpill = chooseSpillUsingHeuristics(
+          activeRanges.getArrayRef(), overlappingInactiveRanges, nextRange);
       if (rangeToSpill != nextRange) {
-        // Spill an active live range (so release its tile ID first).
+        // Spill an (in)active live range (so release its tile ID first).
         tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                     *rangeToSpill->tileId);
-        activeRanges.remove(rangeToSpill);
         // This will always succeed after a spill (of an active live range).
         nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
+        // Remove the live range from the active/inactive sets.
+        if (!activeRanges.remove(rangeToSpill)) {
+          bool removed = inactiveRanges.remove(rangeToSpill);
+          assert(removed && "expected a range to be removed!");
+        }
       }
       rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
     }
@@ -558,6 +615,12 @@ void allocateTilesToLiveRanges(
     // Insert the live range into the active ranges.
     if (nextRange->tileId < kInMemoryTileIdBase)
       activeRanges.insert(nextRange);
+
+    // Release tiles reserved for inactive live ranges.
+    for (LiveRange *range : overlappingInactiveRanges) {
+      if (*range->tileId < kInMemoryTileIdBase)
+        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
+    }
   }
 }
 
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 9c22b29ac22e7..59afa654778e5 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -430,3 +430,133 @@ func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
   // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
   return
 }
+
+// -----
+
+//  CHECK-LIVE-RANGE-LABEL: @fill_holes_in_tile_liveness
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//        CHECK-LIVE-RANGE: S  arm_sme.get_tile
+//        CHECK-LIVE-RANGE: E  cf.cond_br
+//        CHECK-LIVE-RANGE: ^bb1:
+//        CHECK-LIVE-RANGE:  S arm_sme.get_tile
+//        CHECK-LIVE-RANGE:  | test.dummy
+//        CHECK-LIVE-RANGE:  E test.some_use
+//        CHECK-LIVE-RANGE:    cf.br
+//        CHECK-LIVE-RANGE: ^bb2:
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: E  test.some_use
+//        CHECK-LIVE-RANGE:    cf.br
+
+// Here there's a 'hole' in the liveness of %tileA (in bb1) where another value
+// can reuse the tile ID (0) assigned to %tileA.
+
+// CHECK-LABEL: @fill_holes_in_tile_liveness
+func.func @fill_holes_in_tile_liveness(%cond: i1) {
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3
+^bb2:
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3
+^bb3:
+  return
+}
+
+// -----
+
+//  CHECK-LIVE-RANGE-LABEL: @holes_in_tile_liveness_inactive_overlaps
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//        CHECK-LIVE-RANGE: S  arm_sme.get_tile
+//        CHECK-LIVE-RANGE: E  cf.cond_br
+//        CHECK-LIVE-RANGE: ^bb1:
+//        CHECK-LIVE-RANGE:  S arm_sme.get_tile
+//        CHECK-LIVE-RANGE:  | test.dummy
+//        CHECK-LIVE-RANGE:  | test.some_use
+//        CHECK-LIVE-RANGE:  | arm_sme.copy_tile
+//        CHECK-LIVE-RANGE:  E cf.br
+//        CHECK-LIVE-RANGE: ^bb2:
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |S arm_sme.get_tile
+//        CHECK-LIVE-RANGE: E| test.some_use
+//        CHECK-LIVE-RANGE:  | arm_sme.copy_tile
+//        CHECK-LIVE-RANGE:  E cf.br
+//        CHECK-LIVE-RANGE: ^bb3:
+//        CHECK-LIVE-RANGE:  E test.some_use
+//        CHECK-LIVE-RANGE:    func.return
+
+// This tests an edge case in inactive live ranges. The first live range is
+// inactive at the start of ^bb1. If the tile allocator did not check if the
+// second live range overlapped the first it would wrongly re-use tile ID 0
+// (as the first live range is inactive so tile ID 0 is free). This would mean
+// in ^bb2 two overlapping live ranges would have the same tile ID (bad!).
+
+// CHECK-LABEL: @holes_in_tile_liveness_inactive_overlaps
+func.func @holes_in_tile_liveness_inactive_overlaps(%cond: i1) {
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3(%tileB: vector<[4]x[4]xf32>)
+^bb2:
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3(%tileC: vector<[4]x[4]xf32>)
+^bb3(%tile: vector<[4]x[4]xf32>):
+  "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// This is the same as the previous example, but changes the tile types to
+// vector<[16]x[16]xi8>. This means in bb1 the allocator will need to spill the
+// first live range (which is inactive).
+
+// Note: The live ranges are the same as the previous example (so are not checked).
+
+// CHECK-LABEL: @spill_inactive_live_range
+func.func @spill_inactive_live_range(%cond: i1) {
+  // CHECK: arm_sme.get_tile {tile_id = 16 : i32}
+  %tileA = arm_sme.get_tile : vector<[16]x[16]xi8>
+  cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
+  cf.br ^bb3(%tileB: vector<[16]x[16]xi8>)
+^bb2:
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileC = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.some_use"(%tileA) : (vector<[16]x[16]xi8>) -> ()
+  cf.br ^bb3(%tileC: vector<[16]x[16]xi8>)
+^bb3(%tile: vector<[16]x[16]xi8>):
+  "test.some_use"(%tile) : (vector<[16]x[16]xi8>) -> ()
+  return
+}

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!

This is pretty complex stuff, so I've left a few questions. IIUC, you model "holes" as "inactive" live ranges? These are simply used as an additional set of ranges that is used as candidates for spilling? And there's no prioritising of "inactive" over "active"?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall makes sense, but this needs some more documentation - we should be able to follow this without referring to:

I've already requested some small clarification inline, but am also asking for allocateTilesToLiveRanges to be documented more throughly. In particular, IIUC, this is the overall structure:

for (LiveRange *nextRange : liveRangesSortedByStartPoint)

  // 1. Update the list of_active_  live ranges relative to `nextRange->start()`
  // 2. Update the list of _inactive_ live ranges relative to `nextRange->start()`
  
  // 3. Get the list of overlapping inactive live ranges
  // The list of inactive live ranges contains a list of _candidates_ to "borrow" tileID from. However, when updating `inactiveRanges` in step 2, we only looked at the starting point of the current range (i.e. `nextRange->start()`). We need to refine the list by excluding "overlapping" ranges. The tileIDs from overlapping ranges are re-acquired to flag them as "in use".
  
  // 4. Either acquire an available tileID or `newRange` or obtain one through spilling.
  
  // 5. Following step 3., re-release tileIDs corresponding to inactive live ranges

If this correct? If not, please refine. In any case, please make sure that key step in this method are clearly documented.

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for all the comments and for addressing my feedback. Great work Ben! 🙏🏻

@MacDue MacDue merged commit eed72d4 into llvm:main Jul 18, 2024
5 of 6 checks passed
@MacDue MacDue deleted the alloc_holes branch July 18, 2024 19:13
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…#98350)

Summary:
Holes in a live range are points where the corresponding value does not
need to be in a tile/register. If the tile allocator keeps track of
these holes it can reuse tiles for more values (avoiding spills).

Take this simple example:

```mlir
func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here we never use %tileA again!
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```

If you were to calculate the liveness of %tileA and %tileB. You'd see
there is a hole in the liveness of %tileA in bb1:

```
      %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live
```

The tile allocator can make use of that hole and reuse the tile ID it
assigned to %tileA for %tileB.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251426
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants