Skip to content

Commit ad60672

Browse files
authored
Merge pull request #8424 from fhahn/matrix-lifetime-fix
Pick matrix lifetime fix.
2 parents 7c093d5 + 234dfdf commit ad60672

File tree

2 files changed

+1220
-3
lines changed

2 files changed

+1220
-3
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
2121
#include "llvm/ADT/PostOrderIterator.h"
22+
#include "llvm/ADT/ScopeExit.h"
23+
#include "llvm/ADT/SmallSet.h"
2224
#include "llvm/ADT/SmallVector.h"
2325
#include "llvm/Analysis/AliasAnalysis.h"
2426
#include "llvm/Analysis/DomTreeUpdater.h"
@@ -997,12 +999,15 @@ class LowerMatrixIntrinsics {
997999
bool Changed = false;
9981000
SmallVector<CallInst *, 16> MaybeFusableInsts;
9991001
SmallVector<Instruction *, 16> MatrixInsts;
1002+
SmallVector<IntrinsicInst *, 16> LifetimeEnds;
10001003

10011004
// First, collect all instructions with shape information and candidates for
10021005
// fusion (currently only matrix multiplies).
10031006
ReversePostOrderTraversal<Function *> RPOT(&Func);
10041007
for (auto *BB : RPOT)
10051008
for (Instruction &I : *BB) {
1009+
if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1010+
LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
10061011
if (ShapeMap.find(&I) == ShapeMap.end())
10071012
continue;
10081013
if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
@@ -1017,7 +1022,7 @@ class LowerMatrixIntrinsics {
10171022

10181023
// Third, try to fuse candidates.
10191024
for (CallInst *CI : MaybeFusableInsts)
1020-
LowerMatrixMultiplyFused(CI, FusedInsts);
1025+
LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
10211026

10221027
Changed = !FusedInsts.empty();
10231028

@@ -1854,8 +1859,10 @@ class LowerMatrixIntrinsics {
18541859
///
18551860
/// Call finalizeLowering on lowered instructions. Instructions that are
18561861
/// completely eliminated by fusion are added to \p FusedInsts.
1857-
void LowerMatrixMultiplyFused(CallInst *MatMul,
1858-
SmallPtrSetImpl<Instruction *> &FusedInsts) {
1862+
void
1863+
LowerMatrixMultiplyFused(CallInst *MatMul,
1864+
SmallPtrSetImpl<Instruction *> &FusedInsts,
1865+
SmallVector<IntrinsicInst *, 16> &LifetimeEnds) {
18591866
if (!FuseMatrix || !DT)
18601867
return;
18611868

@@ -1944,6 +1951,55 @@ class LowerMatrixIntrinsics {
19441951
for (Instruction *I : ToHoist)
19451952
I->moveBefore(MatMul);
19461953

1954+
// Deal with lifetime.end calls that might be between Load0/Load1 and the
1955+
// store. To avoid introducing loads to dead objects (i.e. after the
1956+
// lifetime has been terminated by @llvm.lifetime.end), either sink them
1957+
// after the store if in the same block, or remove the lifetime.end marker
1958+
// otherwise. This might pessimize further optimizations, by extending the
1959+
// lifetime of the object until the function returns, but should be
1960+
// conservatively correct.
1961+
MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
1962+
MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
1963+
BasicBlock *StoreParent = Store->getParent();
1964+
bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
1965+
LoadOp1->getParent() == StoreParent;
1966+
for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
1967+
IntrinsicInst *End = LifetimeEnds[Idx];
1968+
auto Inc = make_scope_exit([&Idx]() { Idx++; });
1969+
// If the lifetime.end is guaranteed to be before the loads or after the
1970+
// store, it won't interfere with fusion.
1971+
if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))
1972+
continue;
1973+
if (DT->dominates(Store, End))
1974+
continue;
1975+
// If all fusable ops are in the same block and the lifetime.end is in a
1976+
// different block, it won't interfere with fusion.
1977+
if (FusableOpsInSameBlock && End->getParent() != StoreParent)
1978+
continue;
1979+
1980+
// If the loads don't alias the lifetime.end, it won't interfere with
1981+
// fusion.
1982+
MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
1983+
if (!EndLoc.Ptr)
1984+
continue;
1985+
if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
1986+
continue;
1987+
1988+
// If both lifetime.end and the store are in the same block, extend the
1989+
// lifetime until after the store, so the new lifetime covers the loads
1990+
// we introduce later.
1991+
if (End->getParent() == StoreParent) {
1992+
End->moveAfter(Store);
1993+
continue;
1994+
}
1995+
1996+
// Otherwise remove the conflicting lifetime.end marker.
1997+
ToRemove.push_back(End);
1998+
std::swap(LifetimeEnds[Idx], LifetimeEnds.back());
1999+
LifetimeEnds.pop_back();
2000+
Inc.release();
2001+
}
2002+
19472003
emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
19482004
return;
19492005
}

0 commit comments

Comments
 (0)