Skip to content

Commit e17b3cd

Browse files
rovka, mihajlovicana, and arsenm
authored
[AMDGPU] Dynamic VGPR support for llvm.amdgcn.cs.chain (#130094)
The llvm.amdgcn.cs.chain intrinsic has a 'flags' operand which may indicate that we want to reallocate the VGPRs before performing the call. A call with the following arguments: ``` llvm.amdgcn.cs.chain %callee, %exec, %sgpr_args, %vgpr_args, /*flags*/0x1, %num_vgprs, %fallback_exec, %fallback_callee ``` is supposed to do the following: - copy the SGPR and VGPR args into their respective registers - try to change the VGPR allocation - if the allocation has succeeded, set EXEC to %exec and jump to %callee, otherwise set EXEC to %fallback_exec and jump to %fallback_callee This patch implements the dynamic VGPR behaviour by generating an S_ALLOC_VGPR followed by S_CSELECT_B32/64 instructions for the EXEC and callee. The rest of the call sequence is left undisturbed (i.e. identical to the case where the flags are 0 and we don't use dynamic VGPRs). We achieve this by introducing some new pseudos (SI_CS_CHAIN_TC_Wn_DVGPR) which are expanded in the SILateBranchLowering pass, just like the simpler SI_CS_CHAIN_TC_Wn pseudos. The main reason is so that we don't risk other passes (particularly the PostRA scheduler) introducing instructions between the S_ALLOC_VGPR and the jump. Such instructions might end up using VGPRs that have been deallocated, or the wrong EXEC mask. Once the whole backend treats S_ALLOC_VGPR and changes to EXEC as barriers for instructions that use VGPRs, we could in principle move the expansion earlier (but in the absence of a good reason for that my personal preference is to keep it later in order to make debugging easier). Since the expansion happens after register allocation, we're careful to select constants to immediate operands instead of letting ISel generate S_MOVs which could interfere with register allocation (i.e. make it look like we need more registers than we actually do). For GFX12, S_ALLOC_VGPR only works in wave32 mode, so we bail out during ISel in wave64 mode. 
However, we can define the pseudos for wave64 too so it's easy to handle if future generations support it. --------- Co-authored-by: Ana Mihajlovic <[email protected]> Co-authored-by: Matt Arsenault <[email protected]>
1 parent 53a395f commit e17b3cd

15 files changed

+784
-156
lines changed

llvm/include/llvm/CodeGen/SelectionDAGISel.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,21 @@ class SelectionDAGISel {
328328
};
329329

330330
enum {
331-
OPFL_None = 0, // Node has no chain or glue input and isn't variadic.
332-
OPFL_Chain = 1, // Node has a chain input.
333-
OPFL_GlueInput = 2, // Node has a glue input.
334-
OPFL_GlueOutput = 4, // Node has a glue output.
335-
OPFL_MemRefs = 8, // Node gets accumulated MemRefs.
336-
OPFL_Variadic0 = 1<<4, // Node is variadic, root has 0 fixed inputs.
337-
OPFL_Variadic1 = 2<<4, // Node is variadic, root has 1 fixed inputs.
338-
OPFL_Variadic2 = 3<<4, // Node is variadic, root has 2 fixed inputs.
339-
OPFL_Variadic3 = 4<<4, // Node is variadic, root has 3 fixed inputs.
340-
OPFL_Variadic4 = 5<<4, // Node is variadic, root has 4 fixed inputs.
341-
OPFL_Variadic5 = 6<<4, // Node is variadic, root has 5 fixed inputs.
342-
OPFL_Variadic6 = 7<<4, // Node is variadic, root has 6 fixed inputs.
343-
344-
OPFL_VariadicInfo = OPFL_Variadic6
331+
OPFL_None = 0, // Node has no chain or glue input and isn't variadic.
332+
OPFL_Chain = 1, // Node has a chain input.
333+
OPFL_GlueInput = 2, // Node has a glue input.
334+
OPFL_GlueOutput = 4, // Node has a glue output.
335+
OPFL_MemRefs = 8, // Node gets accumulated MemRefs.
336+
OPFL_Variadic0 = 1 << 4, // Node is variadic, root has 0 fixed inputs.
337+
OPFL_Variadic1 = 2 << 4, // Node is variadic, root has 1 fixed inputs.
338+
OPFL_Variadic2 = 3 << 4, // Node is variadic, root has 2 fixed inputs.
339+
OPFL_Variadic3 = 4 << 4, // Node is variadic, root has 3 fixed inputs.
340+
OPFL_Variadic4 = 5 << 4, // Node is variadic, root has 4 fixed inputs.
341+
OPFL_Variadic5 = 6 << 4, // Node is variadic, root has 5 fixed inputs.
342+
OPFL_Variadic6 = 7 << 4, // Node is variadic, root has 6 fixed inputs.
343+
OPFL_Variadic7 = 8 << 4, // Node is variadic, root has 7 fixed inputs.
344+
345+
OPFL_VariadicInfo = 15 << 4 // Mask for extracting the OPFL_VariadicN bits.
345346
};
346347

347348
/// getNumFixedFromVariadicInfo - Transform an EmitNode flags word into the

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7946,10 +7946,6 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79467946
return;
79477947
}
79487948
case Intrinsic::amdgcn_cs_chain: {
7949-
assert(I.arg_size() == 5 && "Additional args not supported yet");
7950-
assert(cast<ConstantInt>(I.getOperand(4))->isZero() &&
7951-
"Non-zero flags not supported yet");
7952-
79537949
// At this point we don't care if it's amdgpu_cs_chain or
79547950
// amdgpu_cs_chain_preserve.
79557951
CallingConv::ID CC = CallingConv::AMDGPU_CS_Chain;
@@ -7976,6 +7972,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79767972
assert(!Args[1].IsInReg && "VGPR args should not be marked inreg");
79777973
Args[2].IsInReg = true; // EXEC should be inreg
79787974

7975+
// Forward the flags and any additional arguments.
7976+
for (unsigned Idx = 4; Idx < I.arg_size(); ++Idx) {
7977+
TargetLowering::ArgListEntry Arg;
7978+
Arg.Node = getValue(I.getOperand(Idx));
7979+
Arg.Ty = I.getOperand(Idx)->getType();
7980+
Arg.setAttributes(&I, Idx);
7981+
Args.push_back(Arg);
7982+
}
7983+
79797984
TargetLowering::CallLoweringInfo CLI(DAG);
79807985
CLI.setDebugLoc(getCurSDLoc())
79817986
.setChain(getRoot())

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -953,17 +953,22 @@ getAssignFnsForCC(CallingConv::ID CC, const SITargetLowering &TLI) {
953953
}
954954

955955
static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
956-
bool IsTailCall, bool isWave32,
957-
CallingConv::ID CC) {
956+
bool IsTailCall, bool IsWave32,
957+
CallingConv::ID CC,
958+
bool IsDynamicVGPRChainCall = false) {
958959
// For calls to amdgpu_cs_chain functions, the address is known to be uniform.
959960
assert((AMDGPU::isChainCC(CC) || !IsIndirect || !IsTailCall) &&
960961
"Indirect calls can't be tail calls, "
961962
"because the address can be divergent");
962963
if (!IsTailCall)
963964
return AMDGPU::G_SI_CALL;
964965

965-
if (AMDGPU::isChainCC(CC))
966-
return isWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
966+
if (AMDGPU::isChainCC(CC)) {
967+
if (IsDynamicVGPRChainCall)
968+
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32_DVGPR
969+
: AMDGPU::SI_CS_CHAIN_TC_W64_DVGPR;
970+
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
971+
}
967972

968973
return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
969974
AMDGPU::SI_TCRETURN;
@@ -972,7 +977,8 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
972977
// Add operands to call instruction to track the callee.
973978
static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
974979
MachineIRBuilder &MIRBuilder,
975-
AMDGPUCallLowering::CallLoweringInfo &Info) {
980+
AMDGPUCallLowering::CallLoweringInfo &Info,
981+
bool IsDynamicVGPRChainCall = false) {
976982
if (Info.Callee.isReg()) {
977983
CallInst.addReg(Info.Callee.getReg());
978984
CallInst.addImm(0);
@@ -983,7 +989,12 @@ static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
983989
auto Ptr = MIRBuilder.buildGlobalValue(
984990
LLT::pointer(GV->getAddressSpace(), 64), GV);
985991
CallInst.addReg(Ptr.getReg(0));
986-
CallInst.add(Info.Callee);
992+
993+
if (IsDynamicVGPRChainCall) {
994+
// DynamicVGPR chain calls are always indirect.
995+
CallInst.addImm(0);
996+
} else
997+
CallInst.add(Info.Callee);
987998
} else
988999
return false;
9891000

@@ -1177,6 +1188,18 @@ void AMDGPUCallLowering::handleImplicitCallArguments(
11771188
}
11781189
}
11791190

1191+
namespace {
1192+
// Chain calls have special arguments that we need to handle. These have the
1193+
// same index as they do in the llvm.amdgcn.cs.chain intrinsic.
1194+
enum ChainCallArgIdx {
1195+
Exec = 1,
1196+
Flags = 4,
1197+
NumVGPRs = 5,
1198+
FallbackExec = 6,
1199+
FallbackCallee = 7,
1200+
};
1201+
} // anonymous namespace
1202+
11801203
bool AMDGPUCallLowering::lowerTailCall(
11811204
MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info,
11821205
SmallVectorImpl<ArgInfo> &OutArgs) const {
@@ -1185,6 +1208,8 @@ bool AMDGPUCallLowering::lowerTailCall(
11851208
SIMachineFunctionInfo *FuncInfo = MF.getInfo<SIMachineFunctionInfo>();
11861209
const Function &F = MF.getFunction();
11871210
MachineRegisterInfo &MRI = MF.getRegInfo();
1211+
const SIInstrInfo *TII = ST.getInstrInfo();
1212+
const SIRegisterInfo *TRI = ST.getRegisterInfo();
11881213
const SITargetLowering &TLI = *getTLI<SITargetLowering>();
11891214

11901215
// True when we're tail calling, but without -tailcallopt.
@@ -1200,34 +1225,79 @@ bool AMDGPUCallLowering::lowerTailCall(
12001225
if (!IsSibCall)
12011226
CallSeqStart = MIRBuilder.buildInstr(AMDGPU::ADJCALLSTACKUP);
12021227

1203-
unsigned Opc =
1204-
getCallOpcode(MF, Info.Callee.isReg(), true, ST.isWave32(), CalleeCC);
1228+
bool IsChainCall = AMDGPU::isChainCC(Info.CallConv);
1229+
bool IsDynamicVGPRChainCall = false;
1230+
1231+
if (IsChainCall) {
1232+
ArgInfo FlagsArg = Info.OrigArgs[ChainCallArgIdx::Flags];
1233+
const APInt &FlagsValue = cast<ConstantInt>(FlagsArg.OrigValue)->getValue();
1234+
if (FlagsValue.isZero()) {
1235+
if (Info.OrigArgs.size() != 5) {
1236+
LLVM_DEBUG(dbgs() << "No additional args allowed if flags == 0\n");
1237+
return false;
1238+
}
1239+
} else if (FlagsValue.isOneBitSet(0)) {
1240+
IsDynamicVGPRChainCall = true;
1241+
1242+
if (Info.OrigArgs.size() != 8) {
1243+
LLVM_DEBUG(dbgs() << "Expected 3 additional args\n");
1244+
return false;
1245+
}
1246+
1247+
// On GFX12, we can only change the VGPR allocation for wave32.
1248+
if (!ST.isWave32()) {
1249+
F.getContext().diagnose(DiagnosticInfoUnsupported(
1250+
F, "dynamic VGPR mode is only supported for wave32"));
1251+
return false;
1252+
}
1253+
1254+
ArgInfo FallbackExecArg = Info.OrigArgs[ChainCallArgIdx::FallbackExec];
1255+
assert(FallbackExecArg.Regs.size() == 1 &&
1256+
"Expected single register for fallback EXEC");
1257+
if (!FallbackExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
1258+
LLVM_DEBUG(dbgs() << "Bad type for fallback EXEC\n");
1259+
return false;
1260+
}
1261+
}
1262+
}
1263+
1264+
unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
1265+
ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
12051266
auto MIB = MIRBuilder.buildInstrNoInsert(Opc);
1206-
if (!addCallTargetOperands(MIB, MIRBuilder, Info))
1267+
if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
12071268
return false;
12081269

12091270
// Byte offset for the tail call. When we are sibcalling, this will always
12101271
// be 0.
12111272
MIB.addImm(0);
12121273

1213-
// If this is a chain call, we need to pass in the EXEC mask.
1214-
const SIRegisterInfo *TRI = ST.getRegisterInfo();
1215-
if (AMDGPU::isChainCC(Info.CallConv)) {
1216-
ArgInfo ExecArg = Info.OrigArgs[1];
1274+
// If this is a chain call, we need to pass in the EXEC mask as well as any
1275+
// other special args.
1276+
if (IsChainCall) {
1277+
auto AddRegOrImm = [&](const ArgInfo &Arg) {
1278+
if (auto CI = dyn_cast<ConstantInt>(Arg.OrigValue)) {
1279+
MIB.addImm(CI->getSExtValue());
1280+
} else {
1281+
MIB.addReg(Arg.Regs[0]);
1282+
unsigned Idx = MIB->getNumOperands() - 1;
1283+
MIB->getOperand(Idx).setReg(constrainOperandRegClass(
1284+
MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
1285+
MIB->getOperand(Idx), Idx));
1286+
}
1287+
};
1288+
1289+
ArgInfo ExecArg = Info.OrigArgs[ChainCallArgIdx::Exec];
12171290
assert(ExecArg.Regs.size() == 1 && "Too many regs for EXEC");
12181291

1219-
if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize()))
1292+
if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
1293+
LLVM_DEBUG(dbgs() << "Bad type for EXEC");
12201294
return false;
1221-
1222-
if (const auto *CI = dyn_cast<ConstantInt>(ExecArg.OrigValue)) {
1223-
MIB.addImm(CI->getSExtValue());
1224-
} else {
1225-
MIB.addReg(ExecArg.Regs[0]);
1226-
unsigned Idx = MIB->getNumOperands() - 1;
1227-
MIB->getOperand(Idx).setReg(constrainOperandRegClass(
1228-
MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
1229-
MIB->getDesc(), MIB->getOperand(Idx), Idx));
12301295
}
1296+
1297+
AddRegOrImm(ExecArg);
1298+
if (IsDynamicVGPRChainCall)
1299+
std::for_each(Info.OrigArgs.begin() + ChainCallArgIdx::NumVGPRs,
1300+
Info.OrigArgs.end(), AddRegOrImm);
12311301
}
12321302

12331303
// Tell the call which registers are clobbered.
@@ -1329,9 +1399,9 @@ bool AMDGPUCallLowering::lowerTailCall(
13291399
// FIXME: We should define regbankselectable call instructions to handle
13301400
// divergent call targets.
13311401
if (MIB->getOperand(0).isReg()) {
1332-
MIB->getOperand(0).setReg(constrainOperandRegClass(
1333-
MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
1334-
MIB->getDesc(), MIB->getOperand(0), 0));
1402+
MIB->getOperand(0).setReg(
1403+
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1404+
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
13351405
}
13361406

13371407
MF.getFrameInfo().setHasTailCall();
@@ -1345,11 +1415,6 @@ bool AMDGPUCallLowering::lowerChainCall(MachineIRBuilder &MIRBuilder,
13451415
ArgInfo Callee = Info.OrigArgs[0];
13461416
ArgInfo SGPRArgs = Info.OrigArgs[2];
13471417
ArgInfo VGPRArgs = Info.OrigArgs[3];
1348-
ArgInfo Flags = Info.OrigArgs[4];
1349-
1350-
assert(cast<ConstantInt>(Flags.OrigValue)->isZero() &&
1351-
"Non-zero flags aren't supported yet.");
1352-
assert(Info.OrigArgs.size() == 5 && "Additional args aren't supported yet.");
13531418

13541419
MachineFunction &MF = MIRBuilder.getMF();
13551420
const Function &F = MF.getFunction();

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5492,6 +5492,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
54925492
NODE_NAME_CASE(TC_RETURN)
54935493
NODE_NAME_CASE(TC_RETURN_GFX)
54945494
NODE_NAME_CASE(TC_RETURN_CHAIN)
5495+
NODE_NAME_CASE(TC_RETURN_CHAIN_DVGPR)
54955496
NODE_NAME_CASE(TRAP)
54965497
NODE_NAME_CASE(RET_GLUE)
54975498
NODE_NAME_CASE(WAVE_ADDRESS)

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ enum NodeType : unsigned {
402402
TC_RETURN,
403403
TC_RETURN_GFX,
404404
TC_RETURN_CHAIN,
405+
TC_RETURN_CHAIN_DVGPR,
405406
TRAP,
406407

407408
// Masked control flow nodes.

llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def AMDGPUtc_return_chain: SDNode<"AMDGPUISD::TC_RETURN_CHAIN",
9999
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
100100
>;
101101

102+
// With dynamic VGPRs.
103+
def AMDGPUtc_return_chain_dvgpr: SDNode<"AMDGPUISD::TC_RETURN_CHAIN_DVGPR",
104+
SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>,
105+
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
106+
>;
107+
102108
def AMDGPUtrap : SDNode<"AMDGPUISD::TRAP",
103109
SDTypeProfile<0, 1, [SDTCisVT<0, i16>]>,
104110
[SDNPHasChain, SDNPVariadic, SDNPSideEffect, SDNPOptInGlue]

0 commit comments

Comments
 (0)