Skip to content

Commit d87d361

Browse files
committed
[Matrix] Fix shape for factored transpose
The shape of the input is C x R. Differential Revision: https://reviews.llvm.org/D106722
1 parent bf7eb48 commit d87d361

File tree

2 files changed

+123
-2
lines changed

2 files changed

+123
-2
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ class LowerMatrixIntrinsics {
774774
++II;
775775
Value *A, *B, *AT, *BT;
776776
ConstantInt *R, *K, *C;
777+
// A^t * B ^t -> (B * A)^t
777778
if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
778779
m_Value(A), m_Value(B), m_ConstantInt(R),
779780
m_ConstantInt(K), m_ConstantInt(C))) &&
@@ -784,8 +785,8 @@ class LowerMatrixIntrinsics {
784785
Value *M = Builder.CreateMatrixMultiply(
785786
BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
786787
setShapeInfo(M, {C, R});
787-
Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(),
788-
C->getZExtValue());
788+
Instruction *NewInst = Builder.CreateMatrixTranspose(
789+
M, C->getZExtValue(), R->getZExtValue());
789790
ReplaceAllUsesWith(*I, NewInst);
790791
if (I->use_empty())
791792
I->eraseFromParent();

llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,126 @@ define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) {
10121012
ret <6 x double> %tt
10131013
}
10141014

1015+
define <12 x double> @factor_transpose(<6 x double> %a, <8 x double> %b) {
1016+
; CHECK-LABEL: @factor_transpose(
1017+
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
1018+
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
1019+
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <6 x double> [[A:%.*]], <6 x double> poison, <2 x i32> <i32 0, i32 1>
1020+
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 2, i32 3>
1021+
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 4, i32 5>
1022+
; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
1023+
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
1024+
; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x double> poison, double [[TMP1]], i32 0
1025+
; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT]], <2 x double> poison, <2 x i32> zeroinitializer
1026+
; CHECK-NEXT: [[TMP2:%.*]] = fmul <2 x double> [[BLOCK]], [[SPLAT_SPLAT]]
1027+
; CHECK-NEXT: [[BLOCK5:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
1028+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
1029+
; CHECK-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x double> poison, double [[TMP3]], i32 0
1030+
; CHECK-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT6]], <2 x double> poison, <2 x i32> zeroinitializer
1031+
; CHECK-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[BLOCK5]], [[SPLAT_SPLAT7]]
1032+
; CHECK-NEXT: [[TMP5:%.*]] = fadd <2 x double> [[TMP2]], [[TMP4]]
1033+
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
1034+
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP6]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
1035+
; CHECK-NEXT: [[BLOCK8:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
1036+
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
1037+
; CHECK-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x double> poison, double [[TMP8]], i32 0
1038+
; CHECK-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT9]], <2 x double> poison, <2 x i32> zeroinitializer
1039+
; CHECK-NEXT: [[TMP9:%.*]] = fmul <2 x double> [[BLOCK8]], [[SPLAT_SPLAT10]]
1040+
; CHECK-NEXT: [[BLOCK11:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
1041+
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
1042+
; CHECK-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i32 0
1043+
; CHECK-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT12]], <2 x double> poison, <2 x i32> zeroinitializer
1044+
; CHECK-NEXT: [[TMP11:%.*]] = fmul <2 x double> [[BLOCK11]], [[SPLAT_SPLAT13]]
1045+
; CHECK-NEXT: [[TMP12:%.*]] = fadd <2 x double> [[TMP9]], [[TMP11]]
1046+
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x double> [[TMP12]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
1047+
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
1048+
; CHECK-NEXT: [[BLOCK14:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
1049+
; CHECK-NEXT: [[TMP15:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
1050+
; CHECK-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <2 x double> poison, double [[TMP15]], i32 0
1051+
; CHECK-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT15]], <2 x double> poison, <2 x i32> zeroinitializer
1052+
; CHECK-NEXT: [[TMP16:%.*]] = fmul <2 x double> [[BLOCK14]], [[SPLAT_SPLAT16]]
1053+
; CHECK-NEXT: [[BLOCK17:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
1054+
; CHECK-NEXT: [[TMP17:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
1055+
; CHECK-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i32 0
1056+
; CHECK-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT18]], <2 x double> poison, <2 x i32> zeroinitializer
1057+
; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x double> [[BLOCK17]], [[SPLAT_SPLAT19]]
1058+
; CHECK-NEXT: [[TMP19:%.*]] = fadd <2 x double> [[TMP16]], [[TMP18]]
1059+
; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x double> [[TMP19]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
1060+
; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP20]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
1061+
; CHECK-NEXT: [[BLOCK20:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
1062+
; CHECK-NEXT: [[TMP22:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
1063+
; CHECK-NEXT: [[SPLAT_SPLATINSERT21:%.*]] = insertelement <2 x double> poison, double [[TMP22]], i32 0
1064+
; CHECK-NEXT: [[SPLAT_SPLAT22:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT21]], <2 x double> poison, <2 x i32> zeroinitializer
1065+
; CHECK-NEXT: [[TMP23:%.*]] = fmul <2 x double> [[BLOCK20]], [[SPLAT_SPLAT22]]
1066+
; CHECK-NEXT: [[BLOCK23:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
1067+
; CHECK-NEXT: [[TMP24:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
1068+
; CHECK-NEXT: [[SPLAT_SPLATINSERT24:%.*]] = insertelement <2 x double> poison, double [[TMP24]], i32 0
1069+
; CHECK-NEXT: [[SPLAT_SPLAT25:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT24]], <2 x double> poison, <2 x i32> zeroinitializer
1070+
; CHECK-NEXT: [[TMP25:%.*]] = fmul <2 x double> [[BLOCK23]], [[SPLAT_SPLAT25]]
1071+
; CHECK-NEXT: [[TMP26:%.*]] = fadd <2 x double> [[TMP23]], [[TMP25]]
1072+
; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x double> [[TMP26]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
1073+
; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP27]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
1074+
; CHECK-NEXT: [[BLOCK26:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
1075+
; CHECK-NEXT: [[TMP29:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0
1076+
; CHECK-NEXT: [[SPLAT_SPLATINSERT27:%.*]] = insertelement <2 x double> poison, double [[TMP29]], i32 0
1077+
; CHECK-NEXT: [[SPLAT_SPLAT28:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT27]], <2 x double> poison, <2 x i32> zeroinitializer
1078+
; CHECK-NEXT: [[TMP30:%.*]] = fmul <2 x double> [[BLOCK26]], [[SPLAT_SPLAT28]]
1079+
; CHECK-NEXT: [[BLOCK29:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
1080+
; CHECK-NEXT: [[TMP31:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1
1081+
; CHECK-NEXT: [[SPLAT_SPLATINSERT30:%.*]] = insertelement <2 x double> poison, double [[TMP31]], i32 0
1082+
; CHECK-NEXT: [[SPLAT_SPLAT31:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT30]], <2 x double> poison, <2 x i32> zeroinitializer
1083+
; CHECK-NEXT: [[TMP32:%.*]] = fmul <2 x double> [[BLOCK29]], [[SPLAT_SPLAT31]]
1084+
; CHECK-NEXT: [[TMP33:%.*]] = fadd <2 x double> [[TMP30]], [[TMP32]]
1085+
; CHECK-NEXT: [[TMP34:%.*]] = shufflevector <2 x double> [[TMP33]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
1086+
; CHECK-NEXT: [[TMP35:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP34]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
1087+
; CHECK-NEXT: [[BLOCK32:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
1088+
; CHECK-NEXT: [[TMP36:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0
1089+
; CHECK-NEXT: [[SPLAT_SPLATINSERT33:%.*]] = insertelement <2 x double> poison, double [[TMP36]], i32 0
1090+
; CHECK-NEXT: [[SPLAT_SPLAT34:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT33]], <2 x double> poison, <2 x i32> zeroinitializer
1091+
; CHECK-NEXT: [[TMP37:%.*]] = fmul <2 x double> [[BLOCK32]], [[SPLAT_SPLAT34]]
1092+
; CHECK-NEXT: [[BLOCK35:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
1093+
; CHECK-NEXT: [[TMP38:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1
1094+
; CHECK-NEXT: [[SPLAT_SPLATINSERT36:%.*]] = insertelement <2 x double> poison, double [[TMP38]], i32 0
1095+
; CHECK-NEXT: [[SPLAT_SPLAT37:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT36]], <2 x double> poison, <2 x i32> zeroinitializer
1096+
; CHECK-NEXT: [[TMP39:%.*]] = fmul <2 x double> [[BLOCK35]], [[SPLAT_SPLAT37]]
1097+
; CHECK-NEXT: [[TMP40:%.*]] = fadd <2 x double> [[TMP37]], [[TMP39]]
1098+
; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <2 x double> [[TMP40]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
1099+
; CHECK-NEXT: [[TMP42:%.*]] = shufflevector <4 x double> [[TMP35]], <4 x double> [[TMP41]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
1100+
; CHECK-NEXT: [[TMP43:%.*]] = extractelement <4 x double> [[TMP14]], i64 0
1101+
; CHECK-NEXT: [[TMP44:%.*]] = insertelement <3 x double> undef, double [[TMP43]], i64 0
1102+
; CHECK-NEXT: [[TMP45:%.*]] = extractelement <4 x double> [[TMP28]], i64 0
1103+
; CHECK-NEXT: [[TMP46:%.*]] = insertelement <3 x double> [[TMP44]], double [[TMP45]], i64 1
1104+
; CHECK-NEXT: [[TMP47:%.*]] = extractelement <4 x double> [[TMP42]], i64 0
1105+
; CHECK-NEXT: [[TMP48:%.*]] = insertelement <3 x double> [[TMP46]], double [[TMP47]], i64 2
1106+
; CHECK-NEXT: [[TMP49:%.*]] = extractelement <4 x double> [[TMP14]], i64 1
1107+
; CHECK-NEXT: [[TMP50:%.*]] = insertelement <3 x double> undef, double [[TMP49]], i64 0
1108+
; CHECK-NEXT: [[TMP51:%.*]] = extractelement <4 x double> [[TMP28]], i64 1
1109+
; CHECK-NEXT: [[TMP52:%.*]] = insertelement <3 x double> [[TMP50]], double [[TMP51]], i64 1
1110+
; CHECK-NEXT: [[TMP53:%.*]] = extractelement <4 x double> [[TMP42]], i64 1
1111+
; CHECK-NEXT: [[TMP54:%.*]] = insertelement <3 x double> [[TMP52]], double [[TMP53]], i64 2
1112+
; CHECK-NEXT: [[TMP55:%.*]] = extractelement <4 x double> [[TMP14]], i64 2
1113+
; CHECK-NEXT: [[TMP56:%.*]] = insertelement <3 x double> undef, double [[TMP55]], i64 0
1114+
; CHECK-NEXT: [[TMP57:%.*]] = extractelement <4 x double> [[TMP28]], i64 2
1115+
; CHECK-NEXT: [[TMP58:%.*]] = insertelement <3 x double> [[TMP56]], double [[TMP57]], i64 1
1116+
; CHECK-NEXT: [[TMP59:%.*]] = extractelement <4 x double> [[TMP42]], i64 2
1117+
; CHECK-NEXT: [[TMP60:%.*]] = insertelement <3 x double> [[TMP58]], double [[TMP59]], i64 2
1118+
; CHECK-NEXT: [[TMP61:%.*]] = extractelement <4 x double> [[TMP14]], i64 3
1119+
; CHECK-NEXT: [[TMP62:%.*]] = insertelement <3 x double> undef, double [[TMP61]], i64 0
1120+
; CHECK-NEXT: [[TMP63:%.*]] = extractelement <4 x double> [[TMP28]], i64 3
1121+
; CHECK-NEXT: [[TMP64:%.*]] = insertelement <3 x double> [[TMP62]], double [[TMP63]], i64 1
1122+
; CHECK-NEXT: [[TMP65:%.*]] = extractelement <4 x double> [[TMP42]], i64 3
1123+
; CHECK-NEXT: [[TMP66:%.*]] = insertelement <3 x double> [[TMP64]], double [[TMP65]], i64 2
1124+
; CHECK-NEXT: [[TMP67:%.*]] = shufflevector <3 x double> [[TMP48]], <3 x double> [[TMP54]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
1125+
; CHECK-NEXT: [[TMP68:%.*]] = shufflevector <3 x double> [[TMP60]], <3 x double> [[TMP66]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
1126+
; CHECK-NEXT: [[TMP69:%.*]] = shufflevector <6 x double> [[TMP67]], <6 x double> [[TMP68]], <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
1127+
; CHECK-NEXT: ret <12 x double> [[TMP69]]
1128+
;
1129+
%at = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %a, i32 2, i32 3)
1130+
%bt = call <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double> %b, i32 4, i32 2)
1131+
%m = call <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double> %at, <8 x double> %bt, i32 3, i32 2, i32 4)
1132+
ret <12 x double> %m
1133+
}
1134+
10151135
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
10161136
declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
10171137
declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)

0 commit comments

Comments
 (0)