Skip to content

Pick matrix lifetime fix. #8424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
Expand Down Expand Up @@ -997,12 +999,15 @@ class LowerMatrixIntrinsics {
bool Changed = false;
SmallVector<CallInst *, 16> MaybeFusableInsts;
SmallVector<Instruction *, 16> MatrixInsts;
SmallVector<IntrinsicInst *, 16> LifetimeEnds;

// First, collect all instructions with shape information and candidates for
// fusion (currently only matrix multiplies).
ReversePostOrderTraversal<Function *> RPOT(&Func);
for (auto *BB : RPOT)
for (Instruction &I : *BB) {
if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
if (ShapeMap.find(&I) == ShapeMap.end())
continue;
if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
Expand All @@ -1017,7 +1022,7 @@ class LowerMatrixIntrinsics {

// Third, try to fuse candidates.
for (CallInst *CI : MaybeFusableInsts)
LowerMatrixMultiplyFused(CI, FusedInsts);
LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);

Changed = !FusedInsts.empty();

Expand Down Expand Up @@ -1854,8 +1859,10 @@ class LowerMatrixIntrinsics {
///
/// Call finalizeLowering on lowered instructions. Instructions that are
/// completely eliminated by fusion are added to \p FusedInsts.
void LowerMatrixMultiplyFused(CallInst *MatMul,
SmallPtrSetImpl<Instruction *> &FusedInsts) {
void
LowerMatrixMultiplyFused(CallInst *MatMul,
SmallPtrSetImpl<Instruction *> &FusedInsts,
SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
if (!FuseMatrix || !DT)
return;

Expand Down Expand Up @@ -1944,6 +1951,55 @@ class LowerMatrixIntrinsics {
for (Instruction *I : ToHoist)
I->moveBefore(MatMul);

// Deal with lifetime.end calls that might be between Load0/Load1 and the
// store. To avoid introducing loads to dead objects (i.e. after the
// lifetime has been terminated by @llvm.lifetime.end), either sink them
// after the store if in the same block, or remove the lifetime.end marker
// otherwise. This might pessimize further optimizations, by extending the
// lifetime of the object until the function returns, but should be
// conservatively correct.
MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
BasicBlock *StoreParent = Store->getParent();
bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
LoadOp1->getParent() == StoreParent;
for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
IntrinsicInst *End = LifetimeEnds[Idx];
auto Inc = make_scope_exit([&Idx]() { Idx++; });
// If the lifetime.end is guaranteed to be before the loads or after the
// store, it won't interfere with fusion.
if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
continue;
if (DT->dominates(Store, End))
continue;
// If all fusable ops are in the same block and the lifetime.end is in a
// different block, it won't interfere with fusion.
if (FusableOpsInSameBlock && End->getParent() != StoreParent)
continue;

// If the loads don't alias the lifetime.end, it won't interfere with
// fusion.
MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
if (!EndLoc.Ptr)
continue;
if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
continue;

// If both lifetime.end and the store are in the same block, extend the
// lifetime until after the store, so the new lifetime covers the loads
// we introduce later.
if (End->getParent() == StoreParent) {
End->moveAfter(Store);
continue;
}

// Otherwise remove the conflicting lifetime.end marker.
ToRemove.push_back(End);
std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
LifetimeEnds.pop_back();
Inc.release();
}

emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
return;
}
Expand Down
Loading