@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;

-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeq) const;

-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;

   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
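The hunk above changes the splat query to return the immediate by value, so a null `MachineOperand *` can no longer signal "no splat". A minimal sketch of the resulting convention, using hypothetical stand-ins (`RegClass`, `demoSplat`) rather than the real LLVM types: `return {}` value-initializes the pair to `{0, nullptr}`, so callers must test the register-class half, since an immediate of 0 is now a legitimate splat value.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

struct RegClass {}; // hypothetical stand-in for TargetRegisterClass

// Failure returns {} (i.e. {0, nullptr}); a success may also carry a 0
// immediate, so the pointer half is the only reliable "found" flag.
std::pair<int64_t, const RegClass *> demoSplat(bool Found) {
  static const RegClass RC;
  if (!Found)
    return {};
  return {0, &RC}; // a splat of zero is a valid, foldable result
}

int main() {
  auto [Val, RC] = demoSplat(false);
  assert(!RC && Val == 0); // not a splat
  auto [Val2, RC2] = demoSplat(true);
  assert(RC2 && Val2 == 0); // genuine splat of 0, still distinguishable
}
```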
@@ -967,15 +967,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }

-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};

-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
@@ -987,38 +987,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fallback to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }

-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }

-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;

   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;

   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);

   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only cases where a splat has the same
+  // interpretation for 32-bit and 64-bit splats.
+  if (SplatVal != 0 && SplatVal != -1) {
     // We need to figure out the scalar type read by the operand. e.g. the MFMA
     // operand will be AReg_128, and we want to check if it's compatible with an
     // AReg_32 constant.
@@ -1032,17 +1069,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }

     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }

-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;

-  return SplatVal;
+  return true;
 }

 bool SIFoldOperandsImpl::tryToFoldACImm(
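A quick check of the 0/-1 special case described in the rewritten comment: a reg_sequence whose 32-bit elements all equal V reads back as the 64-bit element Make_64(V, V), and that agrees with the 64-bit interpretation of V itself only for 0 and -1. A self-contained sketch, where make64 is a local stand-in for Make_64:

```cpp
#include <cstdint>

constexpr uint64_t make64(uint32_t Hi, uint32_t Lo) {
  return (uint64_t(Hi) << 32) | Lo;
}

// A 32-bit splat of V, reinterpreted as one 64-bit element, is make64(V, V).
static_assert(make64(0u, 0u) == uint64_t(int64_t(0)), "0 agrees");
static_assert(make64(~0u, ~0u) == uint64_t(int64_t(-1)), "-1 agrees");
static_assert(make64(1u, 1u) != uint64_t(int64_t(1)), "others do not");
```

Note also that because TII->isOperandLegal expects a MachineOperand, the raw int64_t is wrapped in a temporary built with MachineOperand::CreateImm purely for the legality query.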
@@ -1120,7 +1158,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();

-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);

@@ -1131,10 +1169,9 @@ void SIFoldOperandsImpl::foldOperand(
       MachineInstr *RSUseMI = RSUse->getParent();
       unsigned OpNo = RSUseMI->getOperandNo(RSUse);

-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          FoldableDef SplatDef(*Foldable, SplatRC);
+      if (SplatRC) {
+        if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          FoldableDef SplatDef(SplatVal, SplatRC);
           appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
           continue;
         }