[AMDGPU] Dynamic VGPR support for llvm.amdgcn.cs.chain #130094

Merged: 24 commits, Mar 20, 2025
Commits (24)
3cacd07 [AMDGPU] Add GFX12 S_ALLOC_VGPR instruction (jasilvanus, Mar 30, 2023)
b2a7bdc [AMDGPU] Add SubtargetFeature for dynamic VGPR mode (rovka, Oct 23, 2023)
c29d820 [AMDGPU] Deallocate VGPRs before exiting in dynamic VGPR mode (rovka, Oct 23, 2023)
aff1e13 [AMDGPU] Dynamic VGPR support for llvm.amdgcn.cs.chain (rovka, Oct 10, 2023)
84fe000 Update llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (rovka, Mar 10, 2025)
21758f3 Update llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (rovka, Mar 10, 2025)
1a99cab Update llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (rovka, Mar 10, 2025)
296a9db Update llvm/lib/Target/AMDGPU/SILateBranchLowering.cpp (rovka, Mar 10, 2025)
eb9955e Remove wave size mattr from tests (rovka, Mar 10, 2025)
7decd1d Diagnose unsupported (rovka, Mar 10, 2025)
e063564 debug loc & s/unsigned/int (rovka, Mar 10, 2025)
51d1111 Fix tablegen indent (rovka, Mar 11, 2025)
9ea5a9c Explain removal of op flags (rovka, Mar 11, 2025)
cd932e6 Use specific ISD node for dvgpr case (rovka, Mar 11, 2025)
8e17cbe Update comment in llvm/lib/Target/AMDGPU/SILateBranchLowering.cpp (rovka, Mar 11, 2025)
776eb73 Fixup formatting (rovka, Mar 11, 2025)
c7a45e6 Update error msg in llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (rovka, Mar 12, 2025)
6f26c9f Update err msg in llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (rovka, Mar 12, 2025)
a30443c s/addOperand/add llvm/lib/Target/AMDGPU/SILateBranchLowering.cpp (rovka, Mar 12, 2025)
f685823 Update dbg string llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (rovka, Mar 13, 2025)
1d610fd Test for unsupported dgvpr on w64 (rovka, Mar 13, 2025)
b848934 Capitalization (rovka, Mar 14, 2025)
d6c4844 Style issues in llvm/lib/Target/AMDGPU/SIISelLowering.cpp (rovka, Mar 19, 2025)
586679e Merge branch 'main' into users/rovka/dvgpr-4 (rovka, Mar 19, 2025)
29 changes: 15 additions & 14 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -328,20 +328,21 @@ class SelectionDAGISel {
};

enum {
OPFL_None = 0, // Node has no chain or glue input and isn't variadic.
OPFL_Chain = 1, // Node has a chain input.
OPFL_GlueInput = 2, // Node has a glue input.
OPFL_GlueOutput = 4, // Node has a glue output.
OPFL_MemRefs = 8, // Node gets accumulated MemRefs.
OPFL_Variadic0 = 1<<4, // Node is variadic, root has 0 fixed inputs.
OPFL_Variadic1 = 2<<4, // Node is variadic, root has 1 fixed inputs.
OPFL_Variadic2 = 3<<4, // Node is variadic, root has 2 fixed inputs.
OPFL_Variadic3 = 4<<4, // Node is variadic, root has 3 fixed inputs.
OPFL_Variadic4 = 5<<4, // Node is variadic, root has 4 fixed inputs.
OPFL_Variadic5 = 6<<4, // Node is variadic, root has 5 fixed inputs.
OPFL_Variadic6 = 7<<4, // Node is variadic, root has 6 fixed inputs.

OPFL_VariadicInfo = OPFL_Variadic6
OPFL_None = 0, // Node has no chain or glue input and isn't variadic.
OPFL_Chain = 1, // Node has a chain input.
OPFL_GlueInput = 2, // Node has a glue input.
OPFL_GlueOutput = 4, // Node has a glue output.
OPFL_MemRefs = 8, // Node gets accumulated MemRefs.
OPFL_Variadic0 = 1 << 4, // Node is variadic, root has 0 fixed inputs.
OPFL_Variadic1 = 2 << 4, // Node is variadic, root has 1 fixed inputs.
OPFL_Variadic2 = 3 << 4, // Node is variadic, root has 2 fixed inputs.
OPFL_Variadic3 = 4 << 4, // Node is variadic, root has 3 fixed inputs.
OPFL_Variadic4 = 5 << 4, // Node is variadic, root has 4 fixed inputs.
OPFL_Variadic5 = 6 << 4, // Node is variadic, root has 5 fixed inputs.
OPFL_Variadic6 = 7 << 4, // Node is variadic, root has 6 fixed inputs.
OPFL_Variadic7 = 8 << 4, // Node is variadic, root has 7 fixed inputs.

OPFL_VariadicInfo = 15 << 4 // Mask for extracting the OPFL_VariadicN bits.
};

/// getNumFixedFromVariadicInfo - Transform an EmitNode flags word into the
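For context, the function whose doc comment is truncated at the end of this hunk, getNumFixedFromVariadicInfo, decodes these bits back into a fixed-operand count. Widening the mask to 15 << 4 lets the new OPFL_Variadic7 encoding decode without any other change. A minimal sketch of that decoder, paraphrased from SelectionDAGISel.h, so treat the exact body as an approximation:

static inline int getNumFixedFromVariadicInfo(unsigned Flags) {
  // Encodings 1..8 in bits 4-7 correspond to 0..7 fixed inputs on the root.
  return ((Flags & OPFL_VariadicInfo) >> 4) - 1;
}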
13 changes: 9 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7946,10 +7946,6 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::amdgcn_cs_chain: {
assert(I.arg_size() == 5 && "Additional args not supported yet");
assert(cast<ConstantInt>(I.getOperand(4))->isZero() &&
"Non-zero flags not supported yet");

// At this point we don't care if it's amdgpu_cs_chain or
// amdgpu_cs_chain_preserve.
CallingConv::ID CC = CallingConv::AMDGPU_CS_Chain;
@@ -7976,6 +7972,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
assert(!Args[1].IsInReg && "VGPR args should not be marked inreg");
Args[2].IsInReg = true; // EXEC should be inreg

// Forward the flags and any additional arguments.
for (unsigned Idx = 4; Idx < I.arg_size(); ++Idx) {
TargetLowering::ArgListEntry Arg;
Arg.Node = getValue(I.getOperand(Idx));
Arg.Ty = I.getOperand(Idx)->getType();
Arg.setAttributes(&I, Idx);
Args.push_back(Arg);
}

TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(getCurSDLoc())
.setChain(getRoot())
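The new loop forwards the flags operand and any trailing operands to the call lowering instead of asserting they are absent. For orientation, here is a sketch of how a lowering might test for the dynamic-VGPR form, using the operand indices that the ChainCallArgIdx enum (added in AMDGPUCallLowering.cpp below) assigns; the helper name and the standalone function are illustrative, not part of the patch:

#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

// llvm.amdgcn.cs.chain operands: 0 callee, 1 EXEC, 2 SGPR args, 3 VGPR args,
// 4 flags, and, only when flags bit 0 (dynamic VGPRs) is set:
// 5 VGPR count, 6 fallback EXEC, 7 fallback callee.
static bool isDynamicVGPRChainCall(const CallBase &CB) {
  const auto *Flags = cast<ConstantInt>(CB.getArgOperand(4));
  return Flags->getValue().isOneBitSet(0);
}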
127 changes: 96 additions & 31 deletions llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
@@ -953,17 +953,22 @@ getAssignFnsForCC(CallingConv::ID CC, const SITargetLowering &TLI) {
}

static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
bool IsTailCall, bool isWave32,
CallingConv::ID CC) {
bool IsTailCall, bool IsWave32,
CallingConv::ID CC,
bool IsDynamicVGPRChainCall = false) {
// For calls to amdgpu_cs_chain functions, the address is known to be uniform.
assert((AMDGPU::isChainCC(CC) || !IsIndirect || !IsTailCall) &&
"Indirect calls can't be tail calls, "
"because the address can be divergent");
if (!IsTailCall)
return AMDGPU::G_SI_CALL;

if (AMDGPU::isChainCC(CC))
return isWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
if (AMDGPU::isChainCC(CC)) {
if (IsDynamicVGPRChainCall)
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32_DVGPR
: AMDGPU::SI_CS_CHAIN_TC_W64_DVGPR;
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
}

return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
AMDGPU::SI_TCRETURN;
@@ -972,7 +977,8 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
// Add operands to call instruction to track the callee.
static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
MachineIRBuilder &MIRBuilder,
AMDGPUCallLowering::CallLoweringInfo &Info) {
AMDGPUCallLowering::CallLoweringInfo &Info,
bool IsDynamicVGPRChainCall = false) {
if (Info.Callee.isReg()) {
CallInst.addReg(Info.Callee.getReg());
CallInst.addImm(0);
@@ -983,7 +989,12 @@ static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
auto Ptr = MIRBuilder.buildGlobalValue(
LLT::pointer(GV->getAddressSpace(), 64), GV);
CallInst.addReg(Ptr.getReg(0));
CallInst.add(Info.Callee);

if (IsDynamicVGPRChainCall) {
// DynamicVGPR chain calls are always indirect.
CallInst.addImm(0);
} else
CallInst.add(Info.Callee);
} else
return false;

@@ -1177,6 +1188,18 @@ void AMDGPUCallLowering::handleImplicitCallArguments(
}
}

namespace {
// Chain calls have special arguments that we need to handle. These have the
// same index as they do in the llvm.amdgcn.cs.chain intrinsic.
enum ChainCallArgIdx {
Exec = 1,
Flags = 4,
NumVGPRs = 5,
FallbackExec = 6,
FallbackCallee = 7,
};
} // anonymous namespace

bool AMDGPUCallLowering::lowerTailCall(
MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info,
SmallVectorImpl<ArgInfo> &OutArgs) const {
@@ -1185,6 +1208,8 @@ bool AMDGPUCallLowering::lowerTailCall(
SIMachineFunctionInfo *FuncInfo = MF.getInfo<SIMachineFunctionInfo>();
const Function &F = MF.getFunction();
MachineRegisterInfo &MRI = MF.getRegInfo();
const SIInstrInfo *TII = ST.getInstrInfo();
const SIRegisterInfo *TRI = ST.getRegisterInfo();
const SITargetLowering &TLI = *getTLI<SITargetLowering>();

// True when we're tail calling, but without -tailcallopt.
@@ -1200,34 +1225,79 @@ bool AMDGPUCallLowering::lowerTailCall(
if (!IsSibCall)
CallSeqStart = MIRBuilder.buildInstr(AMDGPU::ADJCALLSTACKUP);

unsigned Opc =
getCallOpcode(MF, Info.Callee.isReg(), true, ST.isWave32(), CalleeCC);
bool IsChainCall = AMDGPU::isChainCC(Info.CallConv);
bool IsDynamicVGPRChainCall = false;

if (IsChainCall) {
ArgInfo FlagsArg = Info.OrigArgs[ChainCallArgIdx::Flags];
const APInt &FlagsValue = cast<ConstantInt>(FlagsArg.OrigValue)->getValue();
if (FlagsValue.isZero()) {
if (Info.OrigArgs.size() != 5) {
LLVM_DEBUG(dbgs() << "No additional args allowed if flags == 0\n");
return false;
}
} else if (FlagsValue.isOneBitSet(0)) {
IsDynamicVGPRChainCall = true;

if (Info.OrigArgs.size() != 8) {
LLVM_DEBUG(dbgs() << "Expected 3 additional args\n");
return false;
}

// On GFX12, we can only change the VGPR allocation for wave32.
if (!ST.isWave32()) {
F.getContext().diagnose(DiagnosticInfoUnsupported(
F, "dynamic VGPR mode is only supported for wave32"));
return false;
}

ArgInfo FallbackExecArg = Info.OrigArgs[ChainCallArgIdx::FallbackExec];
assert(FallbackExecArg.Regs.size() == 1 &&
"Expected single register for fallback EXEC");
if (!FallbackExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
LLVM_DEBUG(dbgs() << "Bad type for fallback EXEC\n");
return false;
}
}
}

unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
auto MIB = MIRBuilder.buildInstrNoInsert(Opc);
if (!addCallTargetOperands(MIB, MIRBuilder, Info))
if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
return false;

// Byte offset for the tail call. When we are sibcalling, this will always
// be 0.
MIB.addImm(0);

// If this is a chain call, we need to pass in the EXEC mask.
const SIRegisterInfo *TRI = ST.getRegisterInfo();
if (AMDGPU::isChainCC(Info.CallConv)) {
ArgInfo ExecArg = Info.OrigArgs[1];
// If this is a chain call, we need to pass in the EXEC mask as well as any
// other special args.
if (IsChainCall) {
auto AddRegOrImm = [&](const ArgInfo &Arg) {
if (auto CI = dyn_cast<ConstantInt>(Arg.OrigValue)) {
MIB.addImm(CI->getSExtValue());
} else {
MIB.addReg(Arg.Regs[0]);
unsigned Idx = MIB->getNumOperands() - 1;
MIB->getOperand(Idx).setReg(constrainOperandRegClass(
MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
MIB->getOperand(Idx), Idx));
}
};

ArgInfo ExecArg = Info.OrigArgs[ChainCallArgIdx::Exec];
assert(ExecArg.Regs.size() == 1 && "Too many regs for EXEC");

if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize()))
if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
LLVM_DEBUG(dbgs() << "Bad type for EXEC");
return false;

if (const auto *CI = dyn_cast<ConstantInt>(ExecArg.OrigValue)) {
MIB.addImm(CI->getSExtValue());
} else {
MIB.addReg(ExecArg.Regs[0]);
unsigned Idx = MIB->getNumOperands() - 1;
MIB->getOperand(Idx).setReg(constrainOperandRegClass(
MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
MIB->getDesc(), MIB->getOperand(Idx), Idx));
}

AddRegOrImm(ExecArg);
if (IsDynamicVGPRChainCall)
std::for_each(Info.OrigArgs.begin() + ChainCallArgIdx::NumVGPRs,
Info.OrigArgs.end(), AddRegOrImm);
}

// Tell the call which registers are clobbered.
@@ -1329,9 +1399,9 @@ bool AMDGPUCallLowering::lowerTailCall(
// FIXME: We should define regbankselectable call instructions to handle
// divergent call targets.
if (MIB->getOperand(0).isReg()) {
MIB->getOperand(0).setReg(constrainOperandRegClass(
MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
MIB->getDesc(), MIB->getOperand(0), 0));
MIB->getOperand(0).setReg(
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
}

MF.getFrameInfo().setHasTailCall();
@@ -1345,11 +1415,6 @@ bool AMDGPUCallLowering::lowerChainCall(MachineIRBuilder &MIRBuilder,
ArgInfo Callee = Info.OrigArgs[0];
ArgInfo SGPRArgs = Info.OrigArgs[2];
ArgInfo VGPRArgs = Info.OrigArgs[3];
ArgInfo Flags = Info.OrigArgs[4];

assert(cast<ConstantInt>(Flags.OrigValue)->isZero() &&
"Non-zero flags aren't supported yet.");
assert(Info.OrigArgs.size() == 5 && "Additional args aren't supported yet.");

MachineFunction &MF = MIRBuilder.getMF();
const Function &F = MF.getFunction();
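Summing up the checks added to lowerTailCall: the intrinsic must carry exactly its five base operands when the flags are zero, and exactly eight operands (on a wave32 target) when bit 0 requests dynamic VGPRs. A compact restatement, as a sketch only; the patch itself reports these failures through LLVM_DEBUG and DiagnosticInfoUnsupported rather than a helper like this, and flag bits other than bit 0 are assumed unsupported here:

#include "llvm/ADT/APInt.h"

// Operand-count contract enforced for amdgcn_cs_chain tail calls.
static bool isSupportedChainCallShape(const llvm::APInt &Flags,
                                      unsigned NumOperands, bool IsWave32) {
  if (Flags.isZero())
    return NumOperands == 5; // No extra operands allowed without flags.
  if (Flags.isOneBitSet(0))  // Dynamic-VGPR mode (GFX12, wave32 only).
    return NumOperands == 8 && IsWave32;
  return false; // Other flag combinations are not handled by this patch.
}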
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -5492,6 +5492,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(TC_RETURN)
NODE_NAME_CASE(TC_RETURN_GFX)
NODE_NAME_CASE(TC_RETURN_CHAIN)
NODE_NAME_CASE(TC_RETURN_CHAIN_DVGPR)
NODE_NAME_CASE(TRAP)
NODE_NAME_CASE(RET_GLUE)
NODE_NAME_CASE(WAVE_ADDRESS)
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -402,6 +402,7 @@ enum NodeType : unsigned {
TC_RETURN,
TC_RETURN_GFX,
TC_RETURN_CHAIN,
TC_RETURN_CHAIN_DVGPR,
TRAP,

// Masked control flow nodes.
6 changes: 6 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
@@ -99,6 +99,12 @@ def AMDGPUtc_return_chain: SDNode<"AMDGPUISD::TC_RETURN_CHAIN",
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
>;

// With dynamic VGPRs.
def AMDGPUtc_return_chain_dvgpr: SDNode<"AMDGPUISD::TC_RETURN_CHAIN_DVGPR",
SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
>;

def AMDGPUtrap : SDNode<"AMDGPUISD::TRAP",
SDTypeProfile<0, 1, [SDTCisVT<0, i16>]>,
[SDNPHasChain, SDNPVariadic, SDNPSideEffect, SDNPOptInGlue]