@@ -226,12 +226,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;
 
-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeq) const;
 
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;
 
   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
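One subtlety in the new signatures: with the splat carried as a plain `int64_t`, the value 0 is itself a legitimate splat, so callers can no longer test the value for "no splat". The `return {}` on the failure paths value-initializes the pair, making the null `TargetRegisterClass` the sentinel — which is why the caller changes at the end of this patch test `SplatRC` rather than `SplatVal`. A minimal standalone sketch of that convention (the struct here is a stand-in, not the LLVM type):

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>

// Stand-in for llvm::TargetRegisterClass, for illustration only.
struct TargetRegisterClass {};

// 'return {}' value-initializes the pair to {0, nullptr}: the null register
// class, not the integer, signals "no splat", since 0 is a valid splat value.
static std::pair<int64_t, const TargetRegisterClass *> noSplat() { return {}; }

int main() {
  auto [SplatVal, SplatRC] = noSplat();
  if (!SplatRC)
    std::printf("no splat (value field %lld is meaningless here)\n",
                (long long)SplatVal);
  return 0;
}
```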
@@ -964,15 +964,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }
 
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};
 
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
@@ -984,38 +984,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fallback to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }
 
-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }
 
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;
 
   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;
 
   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
 
   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only cases where a splat has the same
+  // interpretation for 32-bit and 64-bit splats.
+  if (SplatVal != 0 && SplatVal != -1) {
    // We need to figure out the scalar type read by the operand. e.g. the MFMA
    // operand will be AReg_128, and we want to check if it's compatible with an
    // AReg_32 constant.
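The fallback loop is the core of the new recognition: consecutive 32-bit immediates are merged into 64-bit lanes with `Make_64(High, Low)` (which computes `High << 32 | Low`), and every lane must carry the same value. Below is a self-contained sketch of the same matching logic; `SeqElt` and `matchSplat64` are hypothetical stand-ins for the `(MachineOperand *, unsigned)` defs and are not part of the patch:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// One reg_sequence input: the channel of its subreg index plus its 32-bit
// immediate value.
struct SeqElt {
  unsigned Channel;
  uint32_t Imm;
};

// Mirrors the fallback loop: merge consecutive 32-bit halves into 64-bit
// lanes and require every lane to agree.
std::optional<int64_t> matchSplat64(const std::vector<SeqElt> &Defs) {
  if (Defs.empty() || (Defs.size() & 1) != 0)
    return std::nullopt; // Need an even number of 32-bit pieces.
  int64_t SplatVal64 = 0;
  for (size_t I = 0, E = Defs.size(); I != E; I += 2) {
    // Reject unsorted/non-consecutive channels, as the patch does.
    if (Defs[I].Channel + 1 != Defs[I + 1].Channel)
      return std::nullopt;
    // Make_64(High, Low): the second element of the pair is the high half.
    int64_t MergedVal =
        (int64_t)((uint64_t)Defs[I + 1].Imm << 32 | Defs[I].Imm);
    if (I == 0)
      SplatVal64 = MergedVal;
    else if (SplatVal64 != MergedVal)
      return std::nullopt;
  }
  return SplatVal64;
}
```

For example, an f64 1.0 splat arrives as alternating pieces 0x00000000 and 0x3FF00000, and every pair merges to 0x3FF0000000000000, so the match succeeds.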
@@ -1029,17 +1066,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }
 
     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }
 
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;
 
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
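The 0/-1 special case in `tryFoldRegSeqSplat` rests on a bit-level fact: those are the only 32-bit splat values that read back identically when the register is reinterpreted as 64-bit elements, because each 64-bit lane is then `V << 32 | V`, and that equals the sign-extended immediate only for all-zeros or all-ones. A quick standalone check of the claim (`splat32ReadsAsSplat64` is an illustrative helper, not patch code):

```cpp
#include <cstdint>
#include <cstdio>

// A register holding a splat of the 32-bit value V, reinterpreted as 64-bit
// elements, has lanes of (V << 32 | V). That matches the sign-extended 64-bit
// immediate V only for V == 0 and V == -1 (all bits set).
static bool splat32ReadsAsSplat64(int32_t V) {
  uint64_t Lane = (uint64_t)(uint32_t)V << 32 | (uint32_t)V;
  return Lane == (uint64_t)(int64_t)V;
}

int main() {
  for (int32_t V : {0, -1, 1, -2, 0x7FFFFFFF})
    std::printf("%11d -> %s\n", V, splat32ReadsAsSplat64(V) ? "yes" : "no");
  return 0;
}
```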
@@ -1117,7 +1155,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
 
@@ -1128,10 +1166,9 @@ void SIFoldOperandsImpl::foldOperand(
       MachineInstr *RSUseMI = RSUse->getParent();
       unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          FoldableDef SplatDef(*Foldable, SplatRC);
+      if (SplatRC) {
+        if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          FoldableDef SplatDef(SplatVal, SplatRC);
           appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
           continue;
         }