Skip to content

Commit cbadf87

Browse files
committed
[AMDGPU] Dynamic VGPR support for llvm.amdgcn.cs.chain llvm#130094
1 parent b9e094e commit cbadf87

11 files changed

+751
-155
lines changed

llvm/include/llvm/CodeGen/SelectionDAGISel.h

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -328,20 +328,21 @@ class SelectionDAGISel {
328328
};
329329

330330
enum {
331-
OPFL_None = 0, // Node has no chain or glue input and isn't variadic.
332-
OPFL_Chain = 1, // Node has a chain input.
333-
OPFL_GlueInput = 2, // Node has a glue input.
334-
OPFL_GlueOutput = 4, // Node has a glue output.
335-
OPFL_MemRefs = 8, // Node gets accumulated MemRefs.
336-
OPFL_Variadic0 = 1<<4, // Node is variadic, root has 0 fixed inputs.
337-
OPFL_Variadic1 = 2<<4, // Node is variadic, root has 1 fixed inputs.
338-
OPFL_Variadic2 = 3<<4, // Node is variadic, root has 2 fixed inputs.
339-
OPFL_Variadic3 = 4<<4, // Node is variadic, root has 3 fixed inputs.
340-
OPFL_Variadic4 = 5<<4, // Node is variadic, root has 4 fixed inputs.
341-
OPFL_Variadic5 = 6<<4, // Node is variadic, root has 5 fixed inputs.
342-
OPFL_Variadic6 = 7<<4, // Node is variadic, root has 6 fixed inputs.
343-
344-
OPFL_VariadicInfo = OPFL_Variadic6
331+
OPFL_None = 0, // Node has no chain or glue input and isn't variadic.
332+
OPFL_Chain = 1, // Node has a chain input.
333+
OPFL_GlueInput = 2, // Node has a glue input.
334+
OPFL_GlueOutput = 4, // Node has a glue output.
335+
OPFL_MemRefs = 8, // Node gets accumulated MemRefs.
336+
OPFL_Variadic0 = 1 << 4, // Node is variadic, root has 0 fixed inputs.
337+
OPFL_Variadic1 = 2 << 4, // Node is variadic, root has 1 fixed inputs.
338+
OPFL_Variadic2 = 3 << 4, // Node is variadic, root has 2 fixed inputs.
339+
OPFL_Variadic3 = 4 << 4, // Node is variadic, root has 3 fixed inputs.
340+
OPFL_Variadic4 = 5 << 4, // Node is variadic, root has 4 fixed inputs.
341+
OPFL_Variadic5 = 6 << 4, // Node is variadic, root has 5 fixed inputs.
342+
OPFL_Variadic6 = 7 << 4, // Node is variadic, root has 6 fixed inputs.
343+
OPFL_Variadic7 = 8 << 4, // Node is variadic, root has 7 fixed inputs.
344+
345+
OPFL_VariadicInfo = 15 << 4 // Mask for extracting the OPFL_VariadicN bits.
345346
};
346347

347348
/// getNumFixedFromVariadicInfo - Transform an EmitNode flags word into the

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7995,10 +7995,6 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79957995
return;
79967996
}
79977997
case Intrinsic::amdgcn_cs_chain: {
7998-
assert(I.arg_size() == 5 && "Additional args not supported yet");
7999-
assert(cast<ConstantInt>(I.getOperand(4))->isZero() &&
8000-
"Non-zero flags not supported yet");
8001-
80027998
// At this point we don't care if it's amdgpu_cs_chain or
80037999
// amdgpu_cs_chain_preserve.
80048000
CallingConv::ID CC = CallingConv::AMDGPU_CS_Chain;
@@ -8025,6 +8021,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
80258021
assert(!Args[1].IsInReg && "VGPR args should not be marked inreg");
80268022
Args[2].IsInReg = true; // EXEC should be inreg
80278023

8024+
// Forward the flags and any additional arguments.
8025+
for (unsigned Idx = 4; Idx < I.arg_size(); ++Idx) {
8026+
TargetLowering::ArgListEntry Arg;
8027+
Arg.Node = getValue(I.getOperand(Idx));
8028+
Arg.Ty = I.getOperand(Idx)->getType();
8029+
Arg.setAttributes(&I, Idx);
8030+
Args.push_back(Arg);
8031+
}
8032+
80288033
TargetLowering::CallLoweringInfo CLI(DAG);
80298034
CLI.setDebugLoc(getCurSDLoc())
80308035
.setChain(getRoot())

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

Lines changed: 96 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -952,17 +952,22 @@ getAssignFnsForCC(CallingConv::ID CC, const SITargetLowering &TLI) {
952952
}
953953

954954
static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
955-
bool IsTailCall, bool isWave32,
956-
CallingConv::ID CC) {
955+
bool IsTailCall, bool IsWave32,
956+
CallingConv::ID CC,
957+
bool IsDynamicVGPRChainCall = false) {
957958
// For calls to amdgpu_cs_chain functions, the address is known to be uniform.
958959
assert((AMDGPU::isChainCC(CC) || !IsIndirect || !IsTailCall) &&
959960
"Indirect calls can't be tail calls, "
960961
"because the address can be divergent");
961962
if (!IsTailCall)
962963
return AMDGPU::G_SI_CALL;
963964

964-
if (AMDGPU::isChainCC(CC))
965-
return isWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
965+
if (AMDGPU::isChainCC(CC)) {
966+
if (IsDynamicVGPRChainCall)
967+
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32_DVGPR
968+
: AMDGPU::SI_CS_CHAIN_TC_W64_DVGPR;
969+
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
970+
}
966971

967972
return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
968973
AMDGPU::SI_TCRETURN;
@@ -971,7 +976,8 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
971976
// Add operands to call instruction to track the callee.
972977
static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
973978
MachineIRBuilder &MIRBuilder,
974-
AMDGPUCallLowering::CallLoweringInfo &Info) {
979+
AMDGPUCallLowering::CallLoweringInfo &Info,
980+
bool IsDynamicVGPRChainCall = false) {
975981
if (Info.Callee.isReg()) {
976982
CallInst.addReg(Info.Callee.getReg());
977983
CallInst.addImm(0);
@@ -982,7 +988,12 @@ static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
982988
auto Ptr = MIRBuilder.buildGlobalValue(
983989
LLT::pointer(GV->getAddressSpace(), 64), GV);
984990
CallInst.addReg(Ptr.getReg(0));
985-
CallInst.add(Info.Callee);
991+
992+
if (IsDynamicVGPRChainCall) {
993+
// DynamicVGPR chain calls are always indirect.
994+
CallInst.addImm(0);
995+
} else
996+
CallInst.add(Info.Callee);
986997
} else
987998
return false;
988999

@@ -1176,6 +1187,18 @@ void AMDGPUCallLowering::handleImplicitCallArguments(
11761187
}
11771188
}
11781189

1190+
namespace {
1191+
// Chain calls have special arguments that we need to handle. These have the
1192+
// same index as they do in the llvm.amdgcn.cs.chain intrinsic.
1193+
enum ChainCallArgIdx {
1194+
Exec = 1,
1195+
Flags = 4,
1196+
NumVGPRs = 5,
1197+
FallbackExec = 6,
1198+
FallbackCallee = 7,
1199+
};
1200+
} // anonymous namespace
1201+
11791202
bool AMDGPUCallLowering::lowerTailCall(
11801203
MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info,
11811204
SmallVectorImpl<ArgInfo> &OutArgs) const {
@@ -1184,6 +1207,8 @@ bool AMDGPUCallLowering::lowerTailCall(
11841207
SIMachineFunctionInfo *FuncInfo = MF.getInfo<SIMachineFunctionInfo>();
11851208
const Function &F = MF.getFunction();
11861209
MachineRegisterInfo &MRI = MF.getRegInfo();
1210+
const SIInstrInfo *TII = ST.getInstrInfo();
1211+
const SIRegisterInfo *TRI = ST.getRegisterInfo();
11871212
const SITargetLowering &TLI = *getTLI<SITargetLowering>();
11881213

11891214
// True when we're tail calling, but without -tailcallopt.
@@ -1199,34 +1224,79 @@ bool AMDGPUCallLowering::lowerTailCall(
11991224
if (!IsSibCall)
12001225
CallSeqStart = MIRBuilder.buildInstr(AMDGPU::ADJCALLSTACKUP);
12011226

1202-
unsigned Opc =
1203-
getCallOpcode(MF, Info.Callee.isReg(), true, ST.isWave32(), CalleeCC);
1227+
bool IsChainCall = AMDGPU::isChainCC(Info.CallConv);
1228+
bool IsDynamicVGPRChainCall = false;
1229+
1230+
if (IsChainCall) {
1231+
ArgInfo FlagsArg = Info.OrigArgs[ChainCallArgIdx::Flags];
1232+
const APInt &FlagsValue = cast<ConstantInt>(FlagsArg.OrigValue)->getValue();
1233+
if (FlagsValue.isZero()) {
1234+
if (Info.OrigArgs.size() != 5) {
1235+
LLVM_DEBUG(dbgs() << "No additional args allowed if flags == 0\n");
1236+
return false;
1237+
}
1238+
} else if (FlagsValue.isOneBitSet(0)) {
1239+
IsDynamicVGPRChainCall = true;
1240+
1241+
if (Info.OrigArgs.size() != 8) {
1242+
LLVM_DEBUG(dbgs() << "Expected 3 additional args");
1243+
return false;
1244+
}
1245+
1246+
// On GFX12, we can only change the VGPR allocation for wave32.
1247+
if (!ST.isWave32()) {
1248+
F.getContext().diagnose(DiagnosticInfoUnsupported(
1249+
F, "Dynamic VGPR mode is only supported for wave32\n"));
1250+
return false;
1251+
}
1252+
1253+
ArgInfo FallbackExecArg = Info.OrigArgs[ChainCallArgIdx::FallbackExec];
1254+
assert(FallbackExecArg.Regs.size() == 1 &&
1255+
"Expected single register for fallback EXEC");
1256+
if (!FallbackExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
1257+
LLVM_DEBUG(dbgs() << "Bad type for fallback EXEC");
1258+
return false;
1259+
}
1260+
}
1261+
}
1262+
1263+
unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
1264+
ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
12041265
auto MIB = MIRBuilder.buildInstrNoInsert(Opc);
1205-
if (!addCallTargetOperands(MIB, MIRBuilder, Info))
1266+
if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
12061267
return false;
12071268

12081269
// Byte offset for the tail call. When we are sibcalling, this will always
12091270
// be 0.
12101271
MIB.addImm(0);
12111272

1212-
// If this is a chain call, we need to pass in the EXEC mask.
1213-
const SIRegisterInfo *TRI = ST.getRegisterInfo();
1214-
if (AMDGPU::isChainCC(Info.CallConv)) {
1215-
ArgInfo ExecArg = Info.OrigArgs[1];
1273+
// If this is a chain call, we need to pass in the EXEC mask as well as any
1274+
// other special args.
1275+
if (IsChainCall) {
1276+
auto AddRegOrImm = [&](const ArgInfo &Arg) {
1277+
if (auto CI = dyn_cast<ConstantInt>(Arg.OrigValue)) {
1278+
MIB.addImm(CI->getSExtValue());
1279+
} else {
1280+
MIB.addReg(Arg.Regs[0]);
1281+
unsigned Idx = MIB->getNumOperands() - 1;
1282+
MIB->getOperand(Idx).setReg(constrainOperandRegClass(
1283+
MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
1284+
MIB->getOperand(Idx), Idx));
1285+
}
1286+
};
1287+
1288+
ArgInfo ExecArg = Info.OrigArgs[ChainCallArgIdx::Exec];
12161289
assert(ExecArg.Regs.size() == 1 && "Too many regs for EXEC");
12171290

1218-
if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize()))
1291+
if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
1292+
LLVM_DEBUG(dbgs() << "Bad type for EXEC");
12191293
return false;
1220-
1221-
if (const auto *CI = dyn_cast<ConstantInt>(ExecArg.OrigValue)) {
1222-
MIB.addImm(CI->getSExtValue());
1223-
} else {
1224-
MIB.addReg(ExecArg.Regs[0]);
1225-
unsigned Idx = MIB->getNumOperands() - 1;
1226-
MIB->getOperand(Idx).setReg(constrainOperandRegClass(
1227-
MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
1228-
MIB->getDesc(), MIB->getOperand(Idx), Idx));
12291294
}
1295+
1296+
AddRegOrImm(ExecArg);
1297+
if (IsDynamicVGPRChainCall)
1298+
std::for_each(Info.OrigArgs.begin() + ChainCallArgIdx::NumVGPRs,
1299+
Info.OrigArgs.end(), AddRegOrImm);
12301300
}
12311301

12321302
// Tell the call which registers are clobbered.
@@ -1328,9 +1398,9 @@ bool AMDGPUCallLowering::lowerTailCall(
13281398
// FIXME: We should define regbankselectable call instructions to handle
13291399
// divergent call targets.
13301400
if (MIB->getOperand(0).isReg()) {
1331-
MIB->getOperand(0).setReg(constrainOperandRegClass(
1332-
MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
1333-
MIB->getDesc(), MIB->getOperand(0), 0));
1401+
MIB->getOperand(0).setReg(
1402+
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1403+
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
13341404
}
13351405

13361406
MF.getFrameInfo().setHasTailCall();
@@ -1344,11 +1414,6 @@ bool AMDGPUCallLowering::lowerChainCall(MachineIRBuilder &MIRBuilder,
13441414
ArgInfo Callee = Info.OrigArgs[0];
13451415
ArgInfo SGPRArgs = Info.OrigArgs[2];
13461416
ArgInfo VGPRArgs = Info.OrigArgs[3];
1347-
ArgInfo Flags = Info.OrigArgs[4];
1348-
1349-
assert(cast<ConstantInt>(Flags.OrigValue)->isZero() &&
1350-
"Non-zero flags aren't supported yet.");
1351-
assert(Info.OrigArgs.size() == 5 && "Additional args aren't supported yet.");
13521417

13531418
MachineFunction &MF = MIRBuilder.getMF();
13541419
const Function &F = MF.getFunction();

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 66 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -3684,6 +3684,19 @@ bool SITargetLowering::mayBeEmittedAsTailCall(const CallInst *CI) const {
36843684
return true;
36853685
}
36863686

3687+
namespace {
3688+
// Chain calls have special arguments that we need to handle. These are
3689+
// tagging along at the end of the arguments list(s), after the SGPR and VGPR
3690+
// arguments (index 0 and 1 respectively).
3691+
enum ChainCallArgIdx {
3692+
Exec = 2,
3693+
Flags,
3694+
NumVGPRs,
3695+
FallbackExec,
3696+
FallbackCallee
3697+
};
3698+
} // anonymous namespace
3699+
36873700
// The wave scratch offset register is used as the global base pointer.
36883701
SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
36893702
SmallVectorImpl<SDValue> &InVals) const {
@@ -3692,37 +3705,67 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
36923705

36933706
SelectionDAG &DAG = CLI.DAG;
36943707

3695-
TargetLowering::ArgListEntry RequestedExec;
3696-
if (IsChainCallConv) {
3697-
// The last argument should be the value that we need to put in EXEC.
3698-
// Pop it out of CLI.Outs and CLI.OutVals before we do any processing so we
3699-
// don't treat it like the rest of the arguments.
3700-
RequestedExec = CLI.Args.back();
3701-
assert(RequestedExec.Node && "No node for EXEC");
3708+
const SDLoc &DL = CLI.DL;
3709+
SDValue Chain = CLI.Chain;
3710+
SDValue Callee = CLI.Callee;
37023711

3703-
if (!RequestedExec.Ty->isIntegerTy(Subtarget->getWavefrontSize()))
3712+
llvm::SmallVector<SDValue, 6> ChainCallSpecialArgs;
3713+
if (IsChainCallConv) {
3714+
// The last arguments should be the value that we need to put in EXEC,
3715+
// followed by the flags and any other arguments with special meanings.
3716+
// Pop them out of CLI.Outs and CLI.OutVals before we do any processing so
3717+
// we don't treat them like the "real" arguments.
3718+
auto RequestedExecIt = std::find_if(
3719+
CLI.Outs.begin(), CLI.Outs.end(),
3720+
[](const ISD::OutputArg &Arg) { return Arg.OrigArgIndex == 2; });
3721+
assert(RequestedExecIt != CLI.Outs.end() && "No node for EXEC");
3722+
3723+
size_t SpecialArgsBeginIdx = RequestedExecIt - CLI.Outs.begin();
3724+
CLI.OutVals.erase(CLI.OutVals.begin() + SpecialArgsBeginIdx,
3725+
CLI.OutVals.end());
3726+
CLI.Outs.erase(RequestedExecIt, CLI.Outs.end());
3727+
3728+
assert(CLI.Outs.back().OrigArgIndex < 2 &&
3729+
"Haven't popped all the special args");
3730+
3731+
TargetLowering::ArgListEntry RequestedExecArg =
3732+
CLI.Args[ChainCallArgIdx::Exec];
3733+
if (!RequestedExecArg.Ty->isIntegerTy(Subtarget->getWavefrontSize()))
37043734
return lowerUnhandledCall(CLI, InVals, "Invalid value for EXEC");
37053735

3706-
assert(CLI.Outs.back().OrigArgIndex == 2 && "Unexpected last arg");
3707-
CLI.Outs.pop_back();
3708-
CLI.OutVals.pop_back();
3736+
// Convert constants into TargetConstants, so they become immediate operands
3737+
// instead of being selected into S_MOV.
3738+
auto PushNodeOrTargetConstant = [&](TargetLowering::ArgListEntry Arg) {
3739+
if (auto ArgNode = dyn_cast<ConstantSDNode>(Arg.Node))
3740+
ChainCallSpecialArgs.push_back(DAG.getTargetConstant(
3741+
ArgNode->getAPIntValue(), DL, ArgNode->getValueType(0)));
3742+
else
3743+
ChainCallSpecialArgs.push_back(Arg.Node);
3744+
};
37093745

3710-
if (RequestedExec.Ty->isIntegerTy(64)) {
3711-
assert(CLI.Outs.back().OrigArgIndex == 2 && "Exec wasn't split up");
3712-
CLI.Outs.pop_back();
3713-
CLI.OutVals.pop_back();
3714-
}
3746+
PushNodeOrTargetConstant(RequestedExecArg);
3747+
3748+
// Process any other special arguments depending on the value of the flags.
3749+
TargetLowering::ArgListEntry Flags = CLI.Args[ChainCallArgIdx::Flags];
3750+
3751+
const APInt &FlagsValue = cast<ConstantSDNode>(Flags.Node)->getAPIntValue();
3752+
if (FlagsValue.isZero()) {
3753+
if (CLI.Args.size() > ChainCallArgIdx::Flags + 1)
3754+
return lowerUnhandledCall(CLI, InVals,
3755+
"No additional args allowed if flags == 0");
3756+
} else if (FlagsValue.isOneBitSet(0)) {
3757+
if (CLI.Args.size() != ChainCallArgIdx::FallbackCallee + 1) {
3758+
return lowerUnhandledCall(CLI, InVals, "Expected 3 additional args");
3759+
}
37153760

3716-
assert(CLI.Outs.back().OrigArgIndex != 2 &&
3717-
"Haven't popped all the pieces of the EXEC mask");
3761+
std::for_each(CLI.Args.begin() + ChainCallArgIdx::NumVGPRs,
3762+
CLI.Args.end(), PushNodeOrTargetConstant);
3763+
}
37183764
}
37193765

3720-
const SDLoc &DL = CLI.DL;
37213766
SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
37223767
SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
37233768
SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
3724-
SDValue Chain = CLI.Chain;
3725-
SDValue Callee = CLI.Callee;
37263769
bool &IsTailCall = CLI.IsTailCall;
37273770
bool IsVarArg = CLI.IsVarArg;
37283771
bool IsSibCall = false;
@@ -4010,7 +4053,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
40104053
}
40114054

40124055
if (IsChainCallConv)
4013-
Ops.push_back(RequestedExec.Node);
4056+
Ops.insert(Ops.end(), ChainCallSpecialArgs.begin(),
4057+
ChainCallSpecialArgs.end());
40144058

40154059
// Add argument registers to the end of the list so that they are known live
40164060
// into the call.

0 commit comments

Comments
 (0)