Commit b68d880

AMDGPU: Handle folding vector splats of split f64 inline immediates

Recognize a reg_sequence of 32-bit elements that together produce a 64-bit splat value. This enables folding f64 constants into mfma operands.
1 parent 4ec9c56 commit b68d880
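The core idea can be sketched outside of LLVM. The snippet below is a simplified, standalone model under stated assumptions, not the SIFoldOperands code itself: the names isSplat64 and make64 and the plain std::vector input are illustrative. Walk the 32-bit reg_sequence inputs; if they are not all identical, pair consecutive low/high halves into 64-bit values and require every pair to merge to the same constant.

// Standalone sketch: recognize a 64-bit splat that arrives as 32-bit pieces,
// e.g. {lo, hi, lo, hi, ...}. Names here (isSplat64, make64) are illustrative
// and not part of the patch.
#include <cstdint>
#include <optional>
#include <vector>

static int64_t make64(uint32_t Hi, uint32_t Lo) {
  // Mirrors the idea of Make_64(Hi, Lo): high half in the upper 32 bits.
  return (static_cast<int64_t>(Hi) << 32) | Lo;
}

std::optional<int64_t> isSplat64(const std::vector<uint32_t> &Elts) {
  // A 64-bit splat needs an even number of 32-bit halves.
  if (Elts.size() < 2 || (Elts.size() & 1) != 0)
    return std::nullopt;

  const int64_t Splat = make64(Elts[1], Elts[0]);
  for (size_t I = 2; I != Elts.size(); I += 2) {
    if (make64(Elts[I + 1], Elts[I]) != Splat)
      return std::nullopt; // not a splat
  }
  return Splat;
}

For the MFMA tests touched below, the pieces for f64 1.0 are {0x00000000, 0x3ff00000}, so the recognized splat value is 0x3ff0000000000000.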

File tree

2 files changed: +76 −68 lines

2 files changed

+76
-68
lines changed

llvm/lib/Target/AMDGPU/SIFoldOperands.cpp

Lines changed: 70 additions & 33 deletions
@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;
 
-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeg) const;
 
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;
 
   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
@@ -967,15 +967,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }
 
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};
 
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
@@ -987,38 +987,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fallback to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }
 
-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }
 
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;
 
   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;
 
   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
 
   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only cases where a splat has the same
+  // interpretation for 32-bit and 64-bit splats.
+  if (SplatVal != 0 && SplatVal != -1) {
     // We need to figure out the scalar type read by the operand. e.g. the MFMA
     // operand will be AReg_128, and we want to check if it's compatible with an
     // AReg_32 constant.
@@ -1032,17 +1069,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }
 
     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }
 
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;
 
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -1120,7 +1158,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
 
@@ -1131,10 +1169,9 @@ void SIFoldOperandsImpl::foldOperand(
       MachineInstr *RSUseMI = RSUse->getParent();
       unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          FoldableDef SplatDef(*Foldable, SplatRC);
+      if (SplatRC) {
+        if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          FoldableDef SplatDef(SplatVal, SplatRC);
           appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
           continue;
         }
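
As a worked example of what the fallback loop above computes (a standalone C++20 snippet, not part of the change): the f64 constant 1.0 has the bit pattern 0x3FF0000000000000, so the reg_sequence carries the halves 0x00000000 and 0x3FF00000 for each 64-bit lane; merging each consecutive pair reproduces the value that is then offered to isOperandLegal as a single immediate.

// Worked example (C++20 for std::bit_cast): split f64 1.0 into the 32-bit
// halves a reg_sequence would carry, then merge a pair back into the splat.
#include <bit>
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t Bits = std::bit_cast<uint64_t>(1.0); // 0x3FF0000000000000
  const uint32_t Lo = static_cast<uint32_t>(Bits);       // 0x00000000
  const uint32_t Hi = static_cast<uint32_t>(Bits >> 32); // 0x3FF00000
  const uint64_t Merged = (static_cast<uint64_t>(Hi) << 32) | Lo;
  std::printf("lo=0x%08x hi=0x%08x merged=0x%016llx\n", Lo, Hi,
              static_cast<unsigned long long>(Merged));
  return Merged == Bits ? 0 : 1; // 0: round-trip succeeded
}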

llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll

Lines changed: 6 additions & 35 deletions
@@ -165,19 +165,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0x3ff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN: global_store_dwordx4
 ; GCN: global_store_dwordx4
@@ -190,19 +180,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_neg1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0xbff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN: global_store_dwordx4
 ; GCN: global_store_dwordx4
@@ -215,18 +195,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_int_64:
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 64{{$}}
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], 0
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN: global_store_dwordx4
 ; GCN: global_store_dwordx4
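
The updated checks depend on the splat values being representable as AMDGPU inline constants: 1.0, -1.0, and 64 all are, which is why the v_accvgpr materialization sequences disappear. A simplified sketch of that rule follows; the authoritative check lives in the backend, this version omits target-specific cases (e.g. the 1/(2*pi) constant on newer targets), and isInlineImmF64 is an illustrative name, not an LLVM API.

// Simplified sketch of which 64-bit values act as AMDGPU inline constants,
// which is why the splats 1.0, -1.0, and 64 above fold directly into the
// mfma operand.
#include <bit>
#include <cstdint>

bool isInlineImmF64(uint64_t Bits) {
  // Small signed integers -16..64 are inline constants.
  const int64_t S = static_cast<int64_t>(Bits);
  if (S >= -16 && S <= 64)
    return true;

  // A handful of floating-point values are also inline constants.
  const double D = std::bit_cast<double>(Bits);
  return D == 0.5 || D == -0.5 || D == 1.0 || D == -1.0 ||
         D == 2.0 || D == -2.0 || D == 4.0 || D == -4.0;
}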
