Skip to content

Commit 6841520

Browse files
jrbyrnesbcahoon
authored and committed
[AMDGPU] Correctly insert s_nops for dst forwarding hazard (llvm#100276)
MI300 ISA section 4.5 states there is a hazard between "VALU op which uses OPSEL or SDWA with changes the result’s bit position" and "VALU op consumes result of that op". This includes the case where the second op is SDWA with the same dest and dst_sel != DWORD && dst_unused == UNUSED_PRESERVE. In this case, there is an implicit read of the first op's dst and the compiler needs to resolve this hazard. Confirmed with HW team. We model dst_unused == UNUSED_PRESERVE as a tied-def of an implicit operand, so this PR checks for that. MI300_SP_MAS section 1.3.9.2 specifies that CVT_SR_FP8_F32 and CVT_SR_BF8_F32 with opsel[3:2] != 0 have a dest forwarding issue. Currently, we only check for CVT_SR_FP8_F32 with opsel[3] != 0 -- this PR adds support for opsel[2] != 0 as well. (cherry picked from commit 7bcf4d6) Change-Id: Ic29854fd790cdd414b6ab77d7d7cbc4dc5c42815
1 parent 2d493db commit 6841520

File tree

8 files changed

+579
-22
lines changed

8 files changed

+579
-22
lines changed

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 112 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -873,13 +873,78 @@ GCNHazardRecognizer::checkVALUHazardsHelper(const MachineOperand &Def,
873873
return DataIdx >= 0 &&
874874
TRI->regsOverlap(MI.getOperand(DataIdx).getReg(), Reg);
875875
};
876+
876877
int WaitStatesNeededForDef =
877878
VALUWaitStates - getWaitStatesSince(IsHazardFn, VALUWaitStates);
878879
WaitStatesNeeded = std::max(WaitStatesNeeded, WaitStatesNeededForDef);
879880

880881
return WaitStatesNeeded;
881882
}
882883

884+
/// Dest sel forwarding issue occurs if additional logic is needed to swizzle /
885+
/// pack the computed value into correct bit position of the dest register. This
886+
/// occurs if we have SDWA with dst_sel != DWORD or if we have op_sel with
887+
/// dst_sel that is not aligned to the register. This function analayzes the \p
888+
/// MI and \returns an operand with dst forwarding issue, or nullptr if
889+
/// none exists.
890+
static const MachineOperand *
891+
getDstSelForwardingOperand(const MachineInstr &MI, const GCNSubtarget &ST) {
892+
if (!SIInstrInfo::isVALU(MI))
893+
return nullptr;
894+
895+
const SIInstrInfo *TII = ST.getInstrInfo();
896+
897+
unsigned Opcode = MI.getOpcode();
898+
899+
// There are three different types of instructions
900+
// which produce forwarded dest: 1. SDWA with dst_sel != DWORD, 2. VOP3
901+
// which write hi bits (e.g. op_sel[3] == 1), and 3. CVR_SR_FP8_F32 and
902+
// CVT_SR_BF8_F32 with op_sel[3:2]
903+
// != 0
904+
if (SIInstrInfo::isSDWA(MI)) {
905+
// Type 1: SDWA with dst_sel != DWORD
906+
if (auto *DstSel = TII->getNamedOperand(MI, AMDGPU::OpName::dst_sel))
907+
if (DstSel->getImm() == AMDGPU::SDWA::DWORD)
908+
return nullptr;
909+
} else {
910+
// Type 2 && Type 3: (VOP3 which write the hi bits) || (CVT_SR_FP8_F32 and
911+
// CVT_SR_BF8_F32 with op_sel[3:2] != 0)
912+
if (!AMDGPU::hasNamedOperand(Opcode, AMDGPU::OpName::op_sel) ||
913+
!(TII->getNamedOperand(MI, AMDGPU::OpName::src0_modifiers)->getImm() &
914+
SISrcMods::DST_OP_SEL ||
915+
(AMDGPU::isFP8DstSelInst(Opcode) &&
916+
(TII->getNamedOperand(MI, AMDGPU::OpName::src2_modifiers)->getImm() &
917+
SISrcMods::OP_SEL_0))))
918+
return nullptr;
919+
}
920+
921+
return TII->getNamedOperand(MI, AMDGPU::OpName::vdst);
922+
}
923+
924+
/// Checks whether the provided \p MI "consumes" the operand with a Dest sel
925+
/// fowarding issue \p Dst . We may "consume" the Dst via a standard explicit
926+
/// RAW, or through irregular ways (e.g implicit RAW, certain types of WAW)
927+
static bool consumesDstSelForwardingOperand(const MachineInstr *VALU,
928+
const MachineOperand *Dst,
929+
const SIRegisterInfo *TRI) {
930+
// We must consider implicit reads of the VALU. SDWA with dst_sel and
931+
// UNUSED_PRESERVE will implicitly read the result from forwarded dest,
932+
// and we must account for that hazard.
933+
// We also must account for WAW hazards. In particular, WAW with dest
934+
// preserve semantics (e.g. VOP3 with op_sel, VOP2 &&
935+
// !zeroesHigh16BitsOfDest) will read the forwarded dest for parity
936+
// check for ECC. Without accounting for this hazard, the ECC will be
937+
// wrong.
938+
// TODO: limit to RAW (including implicit reads) + problematic WAW (i.e.
939+
// complete zeroesHigh16BitsOfDest)
940+
for (auto &Operand : VALU->operands()) {
941+
if (Operand.isReg() && TRI->regsOverlap(Dst->getReg(), Operand.getReg())) {
942+
return true;
943+
}
944+
}
945+
return false;
946+
}
947+
883948
int GCNHazardRecognizer::checkVALUHazards(MachineInstr *VALU) {
884949
int WaitStatesNeeded = 0;
885950

@@ -910,27 +975,18 @@ int GCNHazardRecognizer::checkVALUHazards(MachineInstr *VALU) {
910975
if (ST.hasDstSelForwardingHazard()) {
911976
const int Shift16DefWaitstates = 1;
912977

913-
auto IsShift16BitDefFn = [this, VALU](const MachineInstr &MI) {
914-
if (!SIInstrInfo::isVALU(MI))
915-
return false;
916-
const SIInstrInfo *TII = ST.getInstrInfo();
917-
if (SIInstrInfo::isSDWA(MI)) {
918-
if (auto *DstSel = TII->getNamedOperand(MI, AMDGPU::OpName::dst_sel))
919-
if (DstSel->getImm() == AMDGPU::SDWA::DWORD)
920-
return false;
921-
} else {
922-
if (!AMDGPU::hasNamedOperand(MI.getOpcode(), AMDGPU::OpName::op_sel) ||
923-
!(TII->getNamedOperand(MI, AMDGPU::OpName::src0_modifiers)
924-
->getImm() &
925-
SISrcMods::DST_OP_SEL))
926-
return false;
927-
}
978+
auto IsShift16BitDefFn = [this, VALU](const MachineInstr &ProducerMI) {
928979
const SIRegisterInfo *TRI = ST.getRegisterInfo();
929-
if (auto *Dst = TII->getNamedOperand(MI, AMDGPU::OpName::vdst)) {
930-
Register Def = Dst->getReg();
980+
const MachineOperand *ForwardedDst =
981+
getDstSelForwardingOperand(ProducerMI, ST);
982+
if (ForwardedDst) {
983+
return consumesDstSelForwardingOperand(VALU, ForwardedDst, TRI);
984+
}
931985

932-
for (const MachineOperand &Use : VALU->explicit_uses()) {
933-
if (Use.isReg() && TRI->regsOverlap(Def, Use.getReg()))
986+
if (ProducerMI.isInlineAsm()) {
987+
// Assume inline asm has dst forwarding hazard
988+
for (auto &Def : ProducerMI.all_defs()) {
989+
if (consumesDstSelForwardingOperand(VALU, &Def, TRI))
934990
return true;
935991
}
936992
}
@@ -1027,7 +1083,7 @@ int GCNHazardRecognizer::checkInlineAsmHazards(MachineInstr *IA) {
10271083
// problematic thus far.
10281084

10291085
// see checkVALUHazards()
1030-
if (!ST.has12DWordStoreHazard())
1086+
if (!ST.has12DWordStoreHazard() && !ST.hasDstSelForwardingHazard())
10311087
return 0;
10321088

10331089
const MachineRegisterInfo &MRI = MF.getRegInfo();
@@ -1036,11 +1092,45 @@ int GCNHazardRecognizer::checkInlineAsmHazards(MachineInstr *IA) {
10361092
for (const MachineOperand &Op :
10371093
llvm::drop_begin(IA->operands(), InlineAsm::MIOp_FirstOperand)) {
10381094
if (Op.isReg() && Op.isDef()) {
1039-
WaitStatesNeeded =
1040-
std::max(WaitStatesNeeded, checkVALUHazardsHelper(Op, MRI));
1095+
if (!TRI.isVectorRegister(MRI, Op.getReg()))
1096+
continue;
1097+
1098+
if (ST.has12DWordStoreHazard()) {
1099+
WaitStatesNeeded =
1100+
std::max(WaitStatesNeeded, checkVALUHazardsHelper(Op, MRI));
1101+
}
10411102
}
10421103
}
10431104

1105+
if (ST.hasDstSelForwardingHazard()) {
1106+
const int Shift16DefWaitstates = 1;
1107+
1108+
auto IsShift16BitDefFn = [this, &IA](const MachineInstr &ProducerMI) {
1109+
const MachineOperand *Dst = getDstSelForwardingOperand(ProducerMI, ST);
1110+
// Assume inline asm reads the dst
1111+
if (Dst)
1112+
return IA->modifiesRegister(Dst->getReg(), &TRI) ||
1113+
IA->readsRegister(Dst->getReg(), &TRI);
1114+
1115+
if (ProducerMI.isInlineAsm()) {
1116+
// If MI is inline asm, assume it has dst forwarding hazard
1117+
for (auto &Def : ProducerMI.all_defs()) {
1118+
if (IA->modifiesRegister(Def.getReg(), &TRI) ||
1119+
IA->readsRegister(Def.getReg(), &TRI)) {
1120+
return true;
1121+
}
1122+
}
1123+
}
1124+
1125+
return false;
1126+
};
1127+
1128+
int WaitStatesNeededForDef =
1129+
Shift16DefWaitstates -
1130+
getWaitStatesSince(IsShift16BitDefFn, Shift16DefWaitstates);
1131+
WaitStatesNeeded = std::max(WaitStatesNeeded, WaitStatesNeededForDef);
1132+
}
1133+
10441134
return WaitStatesNeeded;
10451135
}
10461136

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,6 +2296,8 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
22962296

22972297
field bit IsFP8 = 0;
22982298

2299+
field bit HasFP8DstByteSel = 0;
2300+
22992301
field bit HasDst = !ne(DstVT.Value, untyped.Value);
23002302
field bit HasDst32 = HasDst;
23012303
field bit EmitDst = HasDst; // force dst encoding, see v_movreld_b32 special case
@@ -2879,6 +2881,15 @@ def getVCMPXOpFromVCMP : InstrMapping {
28792881
let ValueCols = [["1"]];
28802882
}
28812883

2884+
// Searchable table built from VOP3_Pseudo records (FilterClass). Each row
// carries the record's Opcode and its HasFP8DstByteSel bit; rows use the C++
// type FP8DstByteSelInfo and are looked up by Opcode through the generated
// getFP8DstByteSelHelper function.
def FP8DstByteSelTable : GenericTable {
2885+
let FilterClass = "VOP3_Pseudo";
2886+
let CppTypeName = "FP8DstByteSelInfo";
2887+
let Fields = ["Opcode", "HasFP8DstByteSel"];
2888+
2889+
let PrimaryKey = ["Opcode"];
2890+
let PrimaryKeyName = "getFP8DstByteSelHelper";
2891+
}
2892+
28822893
def VOPDComponentTable : GenericTable {
28832894
let FilterClass = "VOPD_Component";
28842895
let CppTypeName = "VOPDComponentInfo";

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,13 @@ struct VOPTrue16Info {
367367
bool IsTrue16;
368368
};
369369

370+
// Row type of the TableGen-generated FP8DstByteSelTable (its CppTypeName).
// The members mirror that table's Fields list: "Opcode", "HasFP8DstByteSel".
// NOTE(review): member order presumably has to stay in sync with the Fields
// list in SIInstrInfo.td -- confirm before reordering.
struct FP8DstByteSelInfo {
371+
// Opcode the row describes; primary key of the table.
uint16_t Opcode;
372+
// Value of the HasFP8DstByteSel bit of the instruction.
bool HasFP8DstByteSel;
373+
};
374+
375+
#define GET_FP8DstByteSelTable_DECL
376+
#define GET_FP8DstByteSelTable_IMPL
370377
#define GET_MTBUFInfoTable_DECL
371378
#define GET_MTBUFInfoTable_IMPL
372379
#define GET_MUBUFInfoTable_DECL
@@ -591,6 +598,11 @@ bool isTrue16Inst(unsigned Opc) {
591598
return Info ? Info->IsTrue16 : false;
592599
}
593600

601+
bool isFP8DstSelInst(unsigned Opc) {
602+
const FP8DstByteSelInfo *Info = getFP8DstByteSelHelper(Opc);
603+
return Info ? Info->HasFP8DstByteSel : false;
604+
}
605+
594606
unsigned mapWMMA2AddrTo3AddrOpcode(unsigned Opc) {
595607
const WMMAOpcodeMappingInfo *Info = getWMMAMappingInfoFrom2AddrOpcode(Opc);
596608
return Info ? Info->Opcode3Addr : ~0u;

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,9 @@ getVOPDInstInfo(unsigned VOPDOpcode, const MCInstrInfo *InstrInfo);
805805
LLVM_READONLY
806806
bool isTrue16Inst(unsigned Opc);
807807

808+
LLVM_READONLY
809+
bool isFP8DstSelInst(unsigned Opc);
810+
808811
LLVM_READONLY
809812
unsigned mapWMMA2AddrTo3AddrOpcode(unsigned Opc);
810813

llvm/lib/Target/AMDGPU/VOP3Instructions.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ def VOP3_CVT_SR_F8_F32_Profile : VOP3_Profile<VOPProfile<[i32, f32, i32, f32]>,
564564
let HasSrc2Mods = 1;
565565
let HasExtVOP3DPP = 1;
566566
let HasOpSel = 1;
567+
let HasFP8DstByteSel = 1;
567568
let AsmVOP3OpSel = !subst(", $src2_modifiers", "",
568569
getAsmVOP3OpSel<3, HasClamp, HasOMod,
569570
HasSrc0FloatMods, HasSrc1FloatMods,

llvm/lib/Target/AMDGPU/VOPInstructions.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ class VOP3_Pseudo <string opName, VOPProfile P, list<dag> pattern = [],
126126
let IsWMMA = P.IsWMMA;
127127
let IsSWMMAC = P.IsSWMMAC;
128128

129+
bit HasFP8DstByteSel = P.HasFP8DstByteSel;
130+
129131
let AsmOperands = !if(isVop3OpSel,
130132
P.AsmVOP3OpSel,
131133
!if(!and(isVOP3P, P.IsPacked), P.AsmVOP3P, P.Asm64));

0 commit comments

Comments
 (0)