Skip to content

Commit 01dac79

Browse files
committed
[Matrix] Adjust lifetime.ends during multiply fusion.
At the moment, loads introduced by multiply fusion may be placed after an object's lifetime has been terminated by lifetime.end. This introduces reads from dead objects. To avoid this, first collect all lifetime.end calls in the function. During fusion, we deal with any lifetime.end calls that may alias any of the loads. Such lifetime.end calls are either moved when possible (both the lifetime.end and the store are in the same block) or deleted.
1 parent 9f7ed36 commit 01dac79

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -990,12 +990,15 @@ class LowerMatrixIntrinsics {
990990
bool Changed = false;
991991
SmallVector<CallInst *, 16> MaybeFusableInsts;
992992
SmallVector<Instruction *, 16> MatrixInsts;
993+
SmallSetVector<IntrinsicInst *, 16> LifetimeEnds;
993994

994995
// First, collect all instructions with shape information and candidates for
995996
// fusion (currently only matrix multiplies).
996997
ReversePostOrderTraversal<Function *> RPOT(&Func);
997998
for (auto *BB : RPOT)
998999
for (Instruction &I : *BB) {
1000+
if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
1001+
LifetimeEnds.insert(cast<IntrinsicInst>(&I));
9991002
if (ShapeMap.find(&I) == ShapeMap.end())
10001003
continue;
10011004
if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
@@ -1010,7 +1013,7 @@ class LowerMatrixIntrinsics {
10101013

10111014
// Third, try to fuse candidates.
10121015
for (CallInst *CI : MaybeFusableInsts)
1013-
LowerMatrixMultiplyFused(CI, FusedInsts);
1016+
LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
10141017

10151018
Changed = !FusedInsts.empty();
10161019

@@ -1856,8 +1859,10 @@ class LowerMatrixIntrinsics {
18561859
///
18571860
/// Call finalizeLowering on lowered instructions. Instructions that are
18581861
/// completely eliminated by fusion are added to \p FusedInsts.
1859-
void LowerMatrixMultiplyFused(CallInst *MatMul,
1860-
SmallPtrSetImpl<Instruction *> &FusedInsts) {
1862+
void
1863+
LowerMatrixMultiplyFused(CallInst *MatMul,
1864+
SmallPtrSetImpl<Instruction *> &FusedInsts,
1865+
SmallSetVector<IntrinsicInst *, 16> &LifetimeEnds) {
18611866
if (!FuseMatrix || !DT)
18621867
return;
18631868

@@ -1946,6 +1951,35 @@ class LowerMatrixIntrinsics {
19461951
for (Instruction *I : ToHoist)
19471952
I->moveBefore(MatMul);
19481953

1954+
// Deal with lifetime.end calls that might be between Load0/Load1 and the
1955+
// store. To avoid introducing loads to dead objects (i.e. after their
1956+
// lifetime has been terminated by @llvm.lifetime.end), either sink them
1957+
// after the store if in the same block, or remove the lifetime.end marker
1958+
// otherwise. This might pessimize further optimizations, by extending the
1959+
// lifetime of the object until the function returns, but should be
1960+
// conservatively correct.
1961+
MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
1962+
MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
1963+
for (IntrinsicInst *End : make_early_inc_range(LifetimeEnds)) {
1964+
if (DT->dominates(Store, End))
1965+
continue;
1966+
MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
1967+
if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
1968+
continue;
1969+
1970+
// If both lifetime.end and the store are in the same block, extend the
1971+
// lifetime until after the store, so the new lifetime covers the loads
1972+
// we introduce later.
1973+
if (Store->getParent() == End->getParent()) {
1974+
End->moveAfter(Store);
1975+
continue;
1976+
}
1977+
1978+
// Otherwise remove the conflicting lifetime.end marker.
1979+
ToRemove.push_back(End);
1980+
LifetimeEnds.remove(End);
1981+
}
1982+
19491983
emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
19501984
return;
19511985
}

llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,11 @@ target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
66
; Tests to make sure no loads are introduced after a lifetime.end by multiply
77
; fusion.
88

9-
; FIXME: Currently the tests are mis-compiled, with loads being introduced after
10-
; llvm.lifetime.end calls.
11-
129
define void @lifetime_for_first_arg_before_multiply(ptr noalias %B, ptr noalias %C) {
1310
; CHECK-LABEL: @lifetime_for_first_arg_before_multiply(
1411
; CHECK-NEXT: entry:
1512
; CHECK-NEXT: [[A:%.*]] = alloca <4 x double>, align 32
1613
; CHECK-NEXT: call void @init(ptr [[A]])
17-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
1814
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
1915
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
2016
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -77,6 +73,7 @@ define void @lifetime_for_first_arg_before_multiply(ptr noalias %B, ptr noalias
7773
; CHECK-NEXT: store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
7874
; CHECK-NEXT: [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
7975
; CHECK-NEXT: store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
76+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
8077
; CHECK-NEXT: ret void
8178
;
8279
entry:
@@ -95,7 +92,6 @@ define void @lifetime_for_second_arg_before_multiply(ptr noalias %A, ptr noalias
9592
; CHECK-NEXT: entry:
9693
; CHECK-NEXT: [[B:%.*]] = alloca <4 x double>, align 32
9794
; CHECK-NEXT: call void @init(ptr [[B]])
98-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
9995
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
10096
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
10197
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -158,6 +154,7 @@ define void @lifetime_for_second_arg_before_multiply(ptr noalias %A, ptr noalias
158154
; CHECK-NEXT: store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
159155
; CHECK-NEXT: [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
160156
; CHECK-NEXT: store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
157+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
161158
; CHECK-NEXT: ret void
162159
;
163160
entry:
@@ -177,7 +174,6 @@ define void @lifetime_for_first_arg_before_multiply_load_from_offset(ptr noalias
177174
; CHECK-NEXT: [[A:%.*]] = alloca <8 x double>, align 64
178175
; CHECK-NEXT: call void @init(ptr [[A]])
179176
; CHECK-NEXT: [[GEP_8:%.*]] = getelementptr i8, ptr [[A]], i64 8
180-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
181177
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[GEP_8]], i64 0
182178
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
183179
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -240,6 +236,7 @@ define void @lifetime_for_first_arg_before_multiply_load_from_offset(ptr noalias
240236
; CHECK-NEXT: store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
241237
; CHECK-NEXT: [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
242238
; CHECK-NEXT: store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
239+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
243240
; CHECK-NEXT: ret void
244241
;
245242
entry:
@@ -261,7 +258,6 @@ define void @lifetime_for_first_arg_before_multiply_lifetime_does_not_dominate(p
261258
; CHECK-NEXT: call void @init(ptr [[A]])
262259
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
263260
; CHECK: then:
264-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
265261
; CHECK-NEXT: br label [[EXIT]]
266262
; CHECK: exit:
267263
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
@@ -352,7 +348,6 @@ define void @lifetime_for_second_arg_before_multiply_lifetime_does_not_dominate(
352348
; CHECK-NEXT: call void @init(ptr [[B]])
353349
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
354350
; CHECK: then:
355-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
356351
; CHECK-NEXT: br label [[EXIT]]
357352
; CHECK: exit:
358353
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
@@ -441,10 +436,9 @@ define void @lifetime_for_ptr_first_arg_before_multiply(ptr noalias %A, ptr noal
441436
; CHECK-NEXT: entry:
442437
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
443438
; CHECK: then:
444-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[A:%.*]])
445439
; CHECK-NEXT: br label [[EXIT]]
446440
; CHECK: exit:
447-
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
441+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
448442
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
449443
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
450444
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
@@ -528,15 +522,13 @@ define void @lifetime_for_both_ptr_args_before_multiply(ptr noalias %A, ptr noal
528522
; CHECK-NEXT: entry:
529523
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
530524
; CHECK: then:
531-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[B:%.*]])
532-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[A:%.*]])
533525
; CHECK-NEXT: br label [[EXIT]]
534526
; CHECK: exit:
535-
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
527+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
536528
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
537529
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
538530
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
539-
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr double, ptr [[B]], i64 0
531+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr double, ptr [[B:%.*]], i64 0
540532
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[TMP1]], align 8
541533
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, ptr [[TMP1]], i64 2
542534
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 8
@@ -618,7 +610,6 @@ define void @lifetime_for_ptr_select_before_multiply(ptr noalias %A, ptr noalias
618610
; CHECK-NEXT: [[P:%.*]] = select i1 [[C_0:%.*]], ptr [[A:%.*]], ptr [[B:%.*]]
619611
; CHECK-NEXT: br i1 [[C_1:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
620612
; CHECK: then:
621-
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[P]])
622613
; CHECK-NEXT: br label [[EXIT]]
623614
; CHECK: exit:
624615
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr double, ptr [[P]], i64 0

0 commit comments

Comments
 (0)