Skip to content

Commit bf7eb48

Browse files
committed
[Matrix] RAUW should only replace an instruction in ShapeMap if supportsShapeInfo
As an instruction is replaced in optimizeTransposes, RAUW will replace it in the ShapeMap (ShapeMap is a ValueMap so that uses are updated). In finalizeLowering, however, we skip updating uses if they are in the ShapeMap, since they will be lowered separately, at which point we pick up the lowered operands.

In the testcase, what happened was that since we replaced the doubled-transpose with the shuffle, the shuffle ended up in the ShapeMap. As we lowered the columnwise-load, the use in the shuffle was not updated. Then, as we removed the original columnwise-load, we changed that use to an undef. I.e. we ended up with:

```
%shuf = shufflevector <8 x double> undef, <8 x double> poison, <6 x i32>
                                   ^^^^^
                                   <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
```

Besides the fix itself, I have fortified this last bit. As we change uses to undef when removing instructions, we track the undefed instructions to make sure we eventually remove those too. This would have caught the issue at compile time.

Differential Revision: https://reviews.llvm.org/D106714
1 parent 02077da commit bf7eb48

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,19 @@ class LowerMatrixIntrinsics {
685685

686686
/// Try moving transposes in order to fold them away or into multiplies.
687687
void optimizeTransposes() {
688+
auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
689+
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
690+
// with New. We should only add New if it supportsShapeInfo so we insert
691+
// it conditionally instead.
692+
auto S = ShapeMap.find(&Old);
693+
if (S != ShapeMap.end()) {
694+
ShapeMap.erase(S);
695+
if (supportsShapeInfo(New))
696+
ShapeMap.insert({New, S->second});
697+
}
698+
Old.replaceAllUsesWith(New);
699+
};
700+
688701
// First sink all transposes inside matmuls, hoping that we end up with NN,
689702
// NT or TN variants.
690703
for (BasicBlock &BB : reverse(Func)) {
@@ -717,7 +730,7 @@ class LowerMatrixIntrinsics {
717730
Value *TATA;
718731
if (match(TA,
719732
m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
720-
I.replaceAllUsesWith(TATA);
733+
ReplaceAllUsesWith(I, TATA);
721734
EraseFromParent(&I);
722735
EraseFromParent(TA);
723736
}
@@ -740,8 +753,7 @@ class LowerMatrixIntrinsics {
740753
NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
741754
K->getZExtValue(),
742755
R->getZExtValue(), "mmul");
743-
setShapeInfo(NewInst, {C, R});
744-
I.replaceAllUsesWith(NewInst);
756+
ReplaceAllUsesWith(I, NewInst);
745757
EraseFromParent(&I);
746758
EraseFromParent(TA);
747759
}
@@ -774,8 +786,7 @@ class LowerMatrixIntrinsics {
774786
setShapeInfo(M, {C, R});
775787
Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(),
776788
C->getZExtValue());
777-
setShapeInfo(NewInst, {C, R});
778-
I->replaceAllUsesWith(NewInst);
789+
ReplaceAllUsesWith(*I, NewInst);
779790
if (I->use_empty())
780791
I->eraseFromParent();
781792
if (A->use_empty())
@@ -879,10 +890,30 @@ class LowerMatrixIntrinsics {
879890

880891
// Delete the instructions backwards, as it has a reduced likelihood of
881892
// having to update as many def-use and use-def chains.
893+
//
894+
// Because we add to ToRemove during fusion we can't guarantee that defs
895+
// are before uses. Change uses to undef temporarily as these should get
896+
// removed as well.
897+
//
898+
// For verification, we keep track of where we changed uses to undefs in
899+
// UndefedInsts and then check that we in fact remove them.
900+
SmallSet<Instruction *, 16> UndefedInsts;
882901
for (auto *Inst : reverse(ToRemove)) {
883-
if (!Inst->use_empty())
884-
Inst->replaceAllUsesWith(UndefValue::get(Inst->getType()));
902+
for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
903+
Use &U = *I++;
904+
if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
905+
UndefedInsts.insert(Undefed);
906+
U.set(UndefValue::get(Inst->getType()));
907+
}
885908
Inst->eraseFromParent();
909+
UndefedInsts.erase(Inst);
910+
}
911+
if (!UndefedInsts.empty()) {
912+
// If we didn't remove all undefed instructions, it's a hard error.
913+
dbgs() << "Undefed but present instructions:\n";
914+
for (auto *I : UndefedInsts)
915+
dbgs() << *I << "\n";
916+
llvm_unreachable("Undefed but instruction not removed");
886917
}
887918

888919
return Changed;

llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,32 @@ entry:
986986
ret <4 x float> %m
987987
}
988988

989+
define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) {
990+
; CHECK-LABEL: @transpose_of_transpose_of_non_matrix_op(
991+
; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[A:%.*]] to <2 x double>*
992+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
993+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[A]], i64 4
994+
; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
995+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
996+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, double* [[A]], i64 8
997+
; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
998+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
999+
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr double, double* [[A]], i64 12
1000+
; CHECK-NEXT: [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
1001+
; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
1002+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
1003+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x double> [[COL_LOAD5]], <2 x double> [[COL_LOAD8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
1004+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> [[TMP2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
1005+
; CHECK-NEXT: [[SHUF:%.*]] = shufflevector <8 x double> [[TMP3]], <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
1006+
; CHECK-NEXT: ret <6 x double> [[SHUF]]
1007+
;
1008+
%load = call <8 x double> @llvm.matrix.column.major.load.v8f64(double* %a, i64 4, i1 false, i32 2, i32 4)
1009+
%shuf = shufflevector <8 x double> %load, <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
1010+
%t = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %shuf, i32 3, i32 2)
1011+
%tt = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %t, i32 2, i32 3)
1012+
ret <6 x double> %tt
1013+
}
1014+
9891015
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
9901016
declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
9911017
declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)
@@ -995,3 +1021,4 @@ declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
9951021
declare <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double>, i32, i32)
9961022
declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32)
9971023
declare <4 x float> @llvm.matrix.transpose.v4f32(<4 x float>, i32, i32)
1024+
declare <8 x double> @llvm.matrix.column.major.load.v8f64(double*, i64, i1, i32, i32)

0 commit comments

Comments
 (0)