Commit 4ddab12
AMDGPU: Move reg_sequence splat handling (#140313)
This code clunkily tried to find a splat reg_sequence by looking at every use of the reg_sequence, and then looking back at the reg_sequence to see if it's a splat. Extract this into a separate helper function to help clean this up. This now parses whether the reg_sequence forms a splat once, and defers the legal-inline-immediate check to the use check (which is really use-context dependent).

The one regression is in GlobalISel, which has an extra copy that should have been separately folded out; it was previously getting dealt with by the handling of foldable copies in tryToFoldACImm.

This is preparation for #139908 and #139317.
1 parent 578741b commit 4ddab12
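For illustration, a hedged MIR sketch (the register names and classes below are invented for the example, not taken from the patch): a reg_sequence is a splat when every sub-register input resolves to the same immediate, e.g. both lanes here are 0x3F800000, i.e. 1.0f:

    %imm:sgpr_32 = S_MOV_B32 1065353216
    %vec:sreg_64 = REG_SEQUENCE %imm, %subreg.sub0, %imm, %subreg.sub1

With this change, isRegSeqSplat recognizes that property once on the reg_sequence itself, and tryFoldRegSeqSplat then decides per use whether the scalar is a legal inline immediate for that specific operand before it is folded.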

2 files changed: 112 additions & 47 deletions


llvm/lib/Target/AMDGPU/SIFoldOperands.cpp

Lines changed: 109 additions & 46 deletions
@@ -119,9 +119,22 @@ class SIFoldOperandsImpl {
                        MachineOperand *OpToFold) const;
   bool isUseSafeToFold(const MachineInstr &MI,
                        const MachineOperand &UseMO) const;
-  bool
+
+  const TargetRegisterClass *getRegSeqInit(
+      MachineInstr &RegSeq,
+      SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs) const;
+
+  const TargetRegisterClass *
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
-                Register UseReg, uint8_t OpTy) const;
+                Register UseReg) const;
+
+  std::pair<MachineOperand *, const TargetRegisterClass *>
+  isRegSeqSplat(MachineInstr &RegSeg) const;
+
+  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                                     MachineOperand *SplatVal,
+                                     const TargetRegisterClass *SplatRC) const;
+
   bool tryToFoldACImm(MachineOperand &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
                       SmallVectorImpl<FoldCandidate> &FoldList) const;
@@ -825,19 +838,24 @@ static MachineOperand *lookUpCopyChain(const SIInstrInfo &TII,
   return Sub;
 }
 
-// Find a def of the UseReg, check if it is a reg_sequence and find initializers
-// for each subreg, tracking it to foldable inline immediate if possible.
-// Returns true on success.
-bool SIFoldOperandsImpl::getRegSeqInit(
-    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
-    Register UseReg, uint8_t OpTy) const {
-  MachineInstr *Def = MRI->getVRegDef(UseReg);
-  if (!Def || !Def->isRegSequence())
-    return false;
+const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
+    MachineInstr &RegSeq,
+    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs) const {
 
-  for (unsigned I = 1, E = Def->getNumExplicitOperands(); I != E; I += 2) {
-    MachineOperand &SrcOp = Def->getOperand(I);
-    unsigned SubRegIdx = Def->getOperand(I + 1).getImm();
+  assert(RegSeq.isRegSequence());
+
+  const TargetRegisterClass *RC = nullptr;
+
+  for (unsigned I = 1, E = RegSeq.getNumExplicitOperands(); I != E; I += 2) {
+    MachineOperand &SrcOp = RegSeq.getOperand(I);
+    unsigned SubRegIdx = RegSeq.getOperand(I + 1).getImm();
+
+    // Only accept reg_sequence with uniform reg class inputs for simplicity.
+    const TargetRegisterClass *OpRC = getRegOpRC(*MRI, *TRI, SrcOp);
+    if (!RC)
+      RC = OpRC;
+    else if (!TRI->getCommonSubClass(RC, OpRC))
+      return nullptr;
 
     if (SrcOp.getSubReg()) {
       // TODO: Handle subregister compose
@@ -846,16 +864,73 @@ bool SIFoldOperandsImpl::getRegSeqInit(
     }
 
     MachineOperand *DefSrc = lookUpCopyChain(*TII, *MRI, SrcOp.getReg());
-    if (DefSrc && (DefSrc->isReg() ||
-                   (DefSrc->isImm() && TII->isInlineConstant(*DefSrc, OpTy)))) {
+    if (DefSrc && (DefSrc->isReg() || DefSrc->isImm())) {
       Defs.emplace_back(DefSrc, SubRegIdx);
       continue;
     }
 
     Defs.emplace_back(&SrcOp, SubRegIdx);
   }
 
-  return true;
+  return RC;
+}
+
+// Find a def of the UseReg, check if it is a reg_sequence and find initializers
+// for each subreg, tracking it to an immediate if possible. Returns the
+// register class of the inputs on success.
+const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
+    SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
+    Register UseReg) const {
+  MachineInstr *Def = MRI->getVRegDef(UseReg);
+  if (!Def || !Def->isRegSequence())
+    return nullptr;
+
+  return getRegSeqInit(*Def, Defs);
+}
+
+std::pair<MachineOperand *, const TargetRegisterClass *>
+SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
+  SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
+  const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
+  if (!SrcRC)
+    return {};
+
+  int64_t Imm;
+  for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
+    const MachineOperand *Op = Defs[I].first;
+    if (!Op->isImm())
+      return {};
+
+    int64_t SubImm = Op->getImm();
+    if (!I) {
+      Imm = SubImm;
+      continue;
+    }
+    if (Imm != SubImm)
+      return {}; // Can only fold splat constants
+  }
+
+  return {Defs[0].first, SrcRC};
+}
+
+MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+    const TargetRegisterClass *SplatRC) const {
+  const MCInstrDesc &Desc = UseMI->getDesc();
+  if (UseOpIdx >= Desc.getNumOperands())
+    return nullptr;
+
+  // Filter out unhandled pseudos.
+  if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
+    return nullptr;
+
+  // FIXME: Verify SplatRC is compatible with the use operand
+  uint8_t OpTy = Desc.operands()[UseOpIdx].OperandType;
+  if (!TII->isInlineConstant(*SplatVal, OpTy) ||
+      !TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
+    return nullptr;
+
+  return SplatVal;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -869,7 +944,6 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
     return false;
 
-  uint8_t OpTy = Desc.operands()[UseOpIdx].OperandType;
   MachineOperand &UseOp = UseMI->getOperand(UseOpIdx);
   if (OpToFold.isImm()) {
     if (unsigned UseSubReg = UseOp.getSubReg()) {
@@ -916,31 +990,7 @@ bool SIFoldOperandsImpl::tryToFoldACImm(
     }
   }
 
-  SmallVector<std::pair<MachineOperand*, unsigned>, 32> Defs;
-  if (!getRegSeqInit(Defs, UseReg, OpTy))
-    return false;
-
-  int32_t Imm;
-  for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
-    const MachineOperand *Op = Defs[I].first;
-    if (!Op->isImm())
-      return false;
-
-    auto SubImm = Op->getImm();
-    if (!I) {
-      Imm = SubImm;
-      if (!TII->isInlineConstant(*Op, OpTy) ||
-          !TII->isOperandLegal(*UseMI, UseOpIdx, Op))
-        return false;
-
-      continue;
-    }
-    if (Imm != SubImm)
-      return false; // Can only fold splat constants
-  }
-
-  appendFoldCandidate(FoldList, UseMI, UseOpIdx, Defs[0].first);
-  return true;
+  return false;
 }
 
 void SIFoldOperandsImpl::foldOperand(
@@ -970,14 +1020,26 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
+    MachineOperand *SplatVal;
+    const TargetRegisterClass *SplatRC;
+    std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
+
     // Grab the use operands first
     SmallVector<MachineOperand *, 4> UsesToProcess(
         llvm::make_pointer_range(MRI->use_nodbg_operands(RegSeqDstReg)));
     for (auto *RSUse : UsesToProcess) {
       MachineInstr *RSUseMI = RSUse->getParent();
+      unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (tryToFoldACImm(UseMI->getOperand(0), RSUseMI,
-                         RSUseMI->getOperandNo(RSUse), FoldList))
+      if (SplatVal) {
+        if (MachineOperand *Foldable =
+                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          appendFoldCandidate(FoldList, RSUseMI, OpNo, Foldable);
+          continue;
+        }
+      }
+
+      if (tryToFoldACImm(UseMI->getOperand(0), RSUseMI, OpNo, FoldList))
         continue;
 
       if (RSUse->getSubReg() != RegSeqDstSubReg)
@@ -986,6 +1048,7 @@ void SIFoldOperandsImpl::foldOperand(
       foldOperand(OpToFold, RSUseMI, RSUseMI->getOperandNo(RSUse), FoldList,
                   CopiesToReplace);
     }
+
     return;
   }
 
@@ -2137,7 +2200,7 @@ bool SIFoldOperandsImpl::tryFoldRegSequence(MachineInstr &MI) {
     return false;
 
   SmallVector<std::pair<MachineOperand*, unsigned>, 32> Defs;
-  if (!getRegSeqInit(Defs, Reg, MCOI::OPERAND_REGISTER))
+  if (!getRegSeqInit(Defs, Reg))
     return false;
 
   for (auto &[Op, SubIdx] : Defs) {

llvm/test/CodeGen/AMDGPU/packed-fp32.ll

Lines changed: 3 additions & 1 deletion
@@ -1680,10 +1680,12 @@ define amdgpu_kernel void @fma_v2_v_lit_splat(ptr addrspace(1) %a) {
 ; PACKED-GISEL-NEXT:    s_load_dwordx2 s[0:1], s[4:5], 0x24
 ; PACKED-GISEL-NEXT:    v_and_b32_e32 v0, 0x3ff, v0
 ; PACKED-GISEL-NEXT:    v_lshlrev_b32_e32 v2, 3, v0
+; PACKED-GISEL-NEXT:    s_mov_b32 s2, 1.0
+; PACKED-GISEL-NEXT:    s_mov_b32 s3, s2
 ; PACKED-GISEL-NEXT:    s_waitcnt lgkmcnt(0)
 ; PACKED-GISEL-NEXT:    global_load_dwordx2 v[0:1], v2, s[0:1]
 ; PACKED-GISEL-NEXT:    s_waitcnt vmcnt(0)
-; PACKED-GISEL-NEXT:    v_pk_fma_f32 v[0:1], v[0:1], 4.0, 1.0
+; PACKED-GISEL-NEXT:    v_pk_fma_f32 v[0:1], v[0:1], 4.0, s[2:3]
 ; PACKED-GISEL-NEXT:    global_store_dwordx2 v2, v[0:1], s[0:1]
 ; PACKED-GISEL-NEXT:    s_endpgm
   %id = tail call i32 @llvm.amdgcn.workitem.id.x()
