@@ -119,9 +119,22 @@ class SIFoldOperandsImpl {
119
119
MachineOperand *OpToFold) const ;
120
120
bool isUseSafeToFold (const MachineInstr &MI,
121
121
const MachineOperand &UseMO) const ;
122
- bool
122
+
123
+ const TargetRegisterClass *getRegSeqInit (
124
+ MachineInstr &RegSeq,
125
+ SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs) const ;
126
+
127
+ const TargetRegisterClass *
123
128
getRegSeqInit (SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
124
- Register UseReg, uint8_t OpTy) const ;
129
+ Register UseReg) const ;
130
+
131
+ std::pair<MachineOperand *, const TargetRegisterClass *>
132
+ isRegSeqSplat (MachineInstr &RegSeg) const ;
133
+
134
+ MachineOperand *tryFoldRegSeqSplat (MachineInstr *UseMI, unsigned UseOpIdx,
135
+ MachineOperand *SplatVal,
136
+ const TargetRegisterClass *SplatRC) const ;
137
+
125
138
bool tryToFoldACImm (MachineOperand &OpToFold, MachineInstr *UseMI,
126
139
unsigned UseOpIdx,
127
140
SmallVectorImpl<FoldCandidate> &FoldList) const ;
@@ -825,19 +838,24 @@ static MachineOperand *lookUpCopyChain(const SIInstrInfo &TII,
825
838
return Sub;
826
839
}
827
840
828
- // Find a def of the UseReg, check if it is a reg_sequence and find initializers
829
- // for each subreg, tracking it to foldable inline immediate if possible.
830
- // Returns true on success.
831
- bool SIFoldOperandsImpl::getRegSeqInit (
832
- SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
833
- Register UseReg, uint8_t OpTy) const {
834
- MachineInstr *Def = MRI->getVRegDef (UseReg);
835
- if (!Def || !Def->isRegSequence ())
836
- return false ;
841
+ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit (
842
+ MachineInstr &RegSeq,
843
+ SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs) const {
837
844
838
- for (unsigned I = 1 , E = Def->getNumExplicitOperands (); I != E; I += 2 ) {
839
- MachineOperand &SrcOp = Def->getOperand (I);
840
- unsigned SubRegIdx = Def->getOperand (I + 1 ).getImm ();
845
+ assert (RegSeq.isRegSequence ());
846
+
847
+ const TargetRegisterClass *RC = nullptr ;
848
+
849
+ for (unsigned I = 1 , E = RegSeq.getNumExplicitOperands (); I != E; I += 2 ) {
850
+ MachineOperand &SrcOp = RegSeq.getOperand (I);
851
+ unsigned SubRegIdx = RegSeq.getOperand (I + 1 ).getImm ();
852
+
853
+ // Only accept reg_sequence with uniform reg class inputs for simplicity.
854
+ const TargetRegisterClass *OpRC = getRegOpRC (*MRI, *TRI, SrcOp);
855
+ if (!RC)
856
+ RC = OpRC;
857
+ else if (!TRI->getCommonSubClass (RC, OpRC))
858
+ return nullptr ;
841
859
842
860
if (SrcOp.getSubReg ()) {
843
861
// TODO: Handle subregister compose
@@ -846,16 +864,73 @@ bool SIFoldOperandsImpl::getRegSeqInit(
846
864
}
847
865
848
866
MachineOperand *DefSrc = lookUpCopyChain (*TII, *MRI, SrcOp.getReg ());
849
- if (DefSrc && (DefSrc->isReg () ||
850
- (DefSrc->isImm () && TII->isInlineConstant (*DefSrc, OpTy)))) {
867
+ if (DefSrc && (DefSrc->isReg () || DefSrc->isImm ())) {
851
868
Defs.emplace_back (DefSrc, SubRegIdx);
852
869
continue ;
853
870
}
854
871
855
872
Defs.emplace_back (&SrcOp, SubRegIdx);
856
873
}
857
874
858
- return true ;
875
+ return RC;
876
+ }
877
+
878
+ // Find a def of the UseReg, check if it is a reg_sequence and find initializers
879
+ // for each subreg, tracking it to an immediate if possible. Returns the
880
+ // register class of the inputs on success.
881
+ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit (
882
+ SmallVectorImpl<std::pair<MachineOperand *, unsigned >> &Defs,
883
+ Register UseReg) const {
884
+ MachineInstr *Def = MRI->getVRegDef (UseReg);
885
+ if (!Def || !Def->isRegSequence ())
886
+ return nullptr ;
887
+
888
+ return getRegSeqInit (*Def, Defs);
889
+ }
890
+
891
+ std::pair<MachineOperand *, const TargetRegisterClass *>
892
+ SIFoldOperandsImpl::isRegSeqSplat (MachineInstr &RegSeq) const {
893
+ SmallVector<std::pair<MachineOperand *, unsigned >, 32 > Defs;
894
+ const TargetRegisterClass *SrcRC = getRegSeqInit (RegSeq, Defs);
895
+ if (!SrcRC)
896
+ return {};
897
+
898
+ int64_t Imm;
899
+ for (unsigned I = 0 , E = Defs.size (); I != E; ++I) {
900
+ const MachineOperand *Op = Defs[I].first ;
901
+ if (!Op->isImm ())
902
+ return {};
903
+
904
+ int64_t SubImm = Op->getImm ();
905
+ if (!I) {
906
+ Imm = SubImm;
907
+ continue ;
908
+ }
909
+ if (Imm != SubImm)
910
+ return {}; // Can only fold splat constants
911
+ }
912
+
913
+ return {Defs[0 ].first , SrcRC};
914
+ }
915
+
916
+ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat (
917
+ MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
918
+ const TargetRegisterClass *SplatRC) const {
919
+ const MCInstrDesc &Desc = UseMI->getDesc ();
920
+ if (UseOpIdx >= Desc.getNumOperands ())
921
+ return nullptr ;
922
+
923
+ // Filter out unhandled pseudos.
924
+ if (!AMDGPU::isSISrcOperand (Desc, UseOpIdx))
925
+ return nullptr ;
926
+
927
+ // FIXME: Verify SplatRC is compatible with the use operand
928
+ uint8_t OpTy = Desc.operands ()[UseOpIdx].OperandType ;
929
+ if (!TII->isInlineConstant (*SplatVal, OpTy) ||
930
+ !TII->isOperandLegal (*UseMI, UseOpIdx, SplatVal))
931
+ return nullptr ;
932
+
933
+ return SplatVal;
859
934
}
860
935
861
936
bool SIFoldOperandsImpl::tryToFoldACImm (
@@ -869,7 +944,6 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
869
944
if (!AMDGPU::isSISrcOperand (Desc, UseOpIdx))
870
945
return false ;
871
946
872
- uint8_t OpTy = Desc.operands ()[UseOpIdx].OperandType ;
873
947
MachineOperand &UseOp = UseMI->getOperand (UseOpIdx);
874
948
if (OpToFold.isImm ()) {
875
949
if (unsigned UseSubReg = UseOp.getSubReg ()) {
@@ -916,31 +990,7 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
916
990
}
917
991
}
918
992
919
- SmallVector<std::pair<MachineOperand*, unsigned >, 32 > Defs;
920
- if (!getRegSeqInit (Defs, UseReg, OpTy))
921
- return false ;
922
-
923
- int32_t Imm;
924
- for (unsigned I = 0 , E = Defs.size (); I != E; ++I) {
925
- const MachineOperand *Op = Defs[I].first ;
926
- if (!Op->isImm ())
927
- return false ;
928
-
929
- auto SubImm = Op->getImm ();
930
- if (!I) {
931
- Imm = SubImm;
932
- if (!TII->isInlineConstant (*Op, OpTy) ||
933
- !TII->isOperandLegal (*UseMI, UseOpIdx, Op))
934
- return false ;
935
-
936
- continue ;
937
- }
938
- if (Imm != SubImm)
939
- return false ; // Can only fold splat constants
940
- }
941
-
942
- appendFoldCandidate (FoldList, UseMI, UseOpIdx, Defs[0 ].first );
943
- return true ;
993
+ return false ;
944
994
}
945
995
946
996
void SIFoldOperandsImpl::foldOperand (
@@ -970,14 +1020,26 @@ void SIFoldOperandsImpl::foldOperand(
970
1020
Register RegSeqDstReg = UseMI->getOperand (0 ).getReg ();
971
1021
unsigned RegSeqDstSubReg = UseMI->getOperand (UseOpIdx + 1 ).getImm ();
972
1022
1023
+ MachineOperand *SplatVal;
1024
+ const TargetRegisterClass *SplatRC;
1025
+ std::tie (SplatVal, SplatRC) = isRegSeqSplat (*UseMI);
1026
+
973
1027
// Grab the use operands first
974
1028
SmallVector<MachineOperand *, 4 > UsesToProcess (
975
1029
llvm::make_pointer_range (MRI->use_nodbg_operands (RegSeqDstReg)));
976
1030
for (auto *RSUse : UsesToProcess) {
977
1031
MachineInstr *RSUseMI = RSUse->getParent ();
1032
+ unsigned OpNo = RSUseMI->getOperandNo (RSUse);
978
1033
979
- if (tryToFoldACImm (UseMI->getOperand (0 ), RSUseMI,
980
- RSUseMI->getOperandNo (RSUse), FoldList))
1034
+ if (SplatVal) {
1035
+ if (MachineOperand *Foldable =
1036
+ tryFoldRegSeqSplat (RSUseMI, OpNo, SplatVal, SplatRC)) {
1037
+ appendFoldCandidate (FoldList, RSUseMI, OpNo, Foldable);
1038
+ continue ;
1039
+ }
1040
+ }
1041
+
1042
+ if (tryToFoldACImm (UseMI->getOperand (0 ), RSUseMI, OpNo, FoldList))
981
1043
continue ;
982
1044
983
1045
if (RSUse->getSubReg () != RegSeqDstSubReg)
@@ -986,6 +1048,7 @@ void SIFoldOperandsImpl::foldOperand(
986
1048
foldOperand (OpToFold, RSUseMI, RSUseMI->getOperandNo (RSUse), FoldList,
987
1049
CopiesToReplace);
988
1050
}
1051
+
989
1052
return ;
990
1053
}
991
1054
@@ -2137,7 +2200,7 @@ bool SIFoldOperandsImpl::tryFoldRegSequence(MachineInstr &MI) {
2137
2200
return false ;
2138
2201
2139
2202
SmallVector<std::pair<MachineOperand*, unsigned >, 32 > Defs;
2140
- if (!getRegSeqInit (Defs, Reg, MCOI::OPERAND_REGISTER ))
2203
+ if (!getRegSeqInit (Defs, Reg))
2141
2204
return false ;
2142
2205
2143
2206
for (auto &[Op, SubIdx] : Defs) {
0 commit comments