@@ -952,17 +952,22 @@ getAssignFnsForCC(CallingConv::ID CC, const SITargetLowering &TLI) {
 }
 
 static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
-                              bool IsTailCall, bool isWave32,
-                              CallingConv::ID CC) {
+                              bool IsTailCall, bool IsWave32,
+                              CallingConv::ID CC,
+                              bool IsDynamicVGPRChainCall = false) {
   // For calls to amdgpu_cs_chain functions, the address is known to be uniform.
   assert((AMDGPU::isChainCC(CC) || !IsIndirect || !IsTailCall) &&
          "Indirect calls can't be tail calls, "
          "because the address can be divergent");
   if (!IsTailCall)
     return AMDGPU::G_SI_CALL;
 
-  if (AMDGPU::isChainCC(CC))
-    return isWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
+  if (AMDGPU::isChainCC(CC)) {
+    if (IsDynamicVGPRChainCall)
+      return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32_DVGPR
+                      : AMDGPU::SI_CS_CHAIN_TC_W64_DVGPR;
+    return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
+  }
 
   return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
                                          AMDGPU::SI_TCRETURN;
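Note on the opcode selection: non-tail calls still lower to G_SI_CALL, chain calls now pick between the existing pseudos and the new _DVGPR variants, and all other tail calls keep SI_TCRETURN(_GFX). As a sketch, this is how lowerTailCall (further down in this patch) exercises the new parameter on a wave32 target with the dynamic-VGPR bit set:

  // Sketch only; mirrors the call site added later in this patch.
  unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
                               ST.isWave32(), CalleeCC,
                               /*IsDynamicVGPRChainCall=*/true);
  // On wave32 this selects AMDGPU::SI_CS_CHAIN_TC_W32_DVGPR.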
@@ -971,7 +976,8 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
 // Add operands to call instruction to track the callee.
 static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
                                   MachineIRBuilder &MIRBuilder,
-                                  AMDGPUCallLowering::CallLoweringInfo &Info) {
+                                  AMDGPUCallLowering::CallLoweringInfo &Info,
+                                  bool IsDynamicVGPRChainCall = false) {
   if (Info.Callee.isReg()) {
     CallInst.addReg(Info.Callee.getReg());
     CallInst.addImm(0);
@@ -982,7 +988,12 @@ static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
     auto Ptr = MIRBuilder.buildGlobalValue(
         LLT::pointer(GV->getAddressSpace(), 64), GV);
     CallInst.addReg(Ptr.getReg(0));
-    CallInst.add(Info.Callee);
+
+    if (IsDynamicVGPRChainCall) {
+      // DynamicVGPR chain calls are always indirect.
+      CallInst.addImm(0);
+    } else
+      CallInst.add(Info.Callee);
   } else
     return false;
 
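For the dynamic-VGPR case the eventual jump target (the main callee or, if VGPR reallocation fails, the fallback callee) is only resolved at run time, so the address is kept in a register and a zero immediate is added where the symbolic callee operand would normally go. A sketch of the two operand sequences this function now emits for a global callee (illustration only):

  // Fixed VGPRs:    CallInst.addReg(Ptr), CallInst.add(Info.Callee)
  // Dynamic VGPRs:  CallInst.addReg(Ptr), CallInst.addImm(0)  // indirect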
@@ -1176,6 +1187,18 @@ void AMDGPUCallLowering::handleImplicitCallArguments(
   }
 }
 
+namespace {
+// Chain calls have special arguments that we need to handle. These have the
+// same index as they do in the llvm.amdgcn.cs.chain intrinsic.
+enum ChainCallArgIdx {
+  Exec = 1,
+  Flags = 4,
+  NumVGPRs = 5,
+  FallbackExec = 6,
+  FallbackCallee = 7,
+};
+} // anonymous namespace
+
 bool AMDGPUCallLowering::lowerTailCall(
     MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info,
     SmallVectorImpl<ArgInfo> &OutArgs) const {
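For reference, these indices follow the operand order of the llvm.amdgcn.cs.chain intrinsic; the last three operands are only present in the dynamic-VGPR form, i.e. when bit 0 of the flags operand is set:

  // OrigArgs[0]: callee
  // OrigArgs[1]: EXEC mask           ChainCallArgIdx::Exec
  // OrigArgs[2]: SGPR args
  // OrigArgs[3]: VGPR args
  // OrigArgs[4]: flags               ChainCallArgIdx::Flags
  // OrigArgs[5]: number of VGPRs     ChainCallArgIdx::NumVGPRs
  // OrigArgs[6]: fallback EXEC mask  ChainCallArgIdx::FallbackExec
  // OrigArgs[7]: fallback callee     ChainCallArgIdx::FallbackCallee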
@@ -1184,6 +1207,8 @@ bool AMDGPUCallLowering::lowerTailCall(
   SIMachineFunctionInfo *FuncInfo = MF.getInfo<SIMachineFunctionInfo>();
   const Function &F = MF.getFunction();
   MachineRegisterInfo &MRI = MF.getRegInfo();
+  const SIInstrInfo *TII = ST.getInstrInfo();
+  const SIRegisterInfo *TRI = ST.getRegisterInfo();
   const SITargetLowering &TLI = *getTLI<SITargetLowering>();
 
   // True when we're tail calling, but without -tailcallopt.
@@ -1199,34 +1224,79 @@ bool AMDGPUCallLowering::lowerTailCall(
   if (!IsSibCall)
     CallSeqStart = MIRBuilder.buildInstr(AMDGPU::ADJCALLSTACKUP);
 
-  unsigned Opc =
-      getCallOpcode(MF, Info.Callee.isReg(), true, ST.isWave32(), CalleeCC);
+  bool IsChainCall = AMDGPU::isChainCC(Info.CallConv);
+  bool IsDynamicVGPRChainCall = false;
+
+  if (IsChainCall) {
+    ArgInfo FlagsArg = Info.OrigArgs[ChainCallArgIdx::Flags];
+    const APInt &FlagsValue = cast<ConstantInt>(FlagsArg.OrigValue)->getValue();
+    if (FlagsValue.isZero()) {
+      if (Info.OrigArgs.size() != 5) {
+        LLVM_DEBUG(dbgs() << "No additional args allowed if flags == 0\n");
+        return false;
+      }
+    } else if (FlagsValue.isOneBitSet(0)) {
+      IsDynamicVGPRChainCall = true;
+
+      if (Info.OrigArgs.size() != 8) {
+        LLVM_DEBUG(dbgs() << "Expected 3 additional args");
+        return false;
+      }
+
+      // On GFX12, we can only change the VGPR allocation for wave32.
+      if (!ST.isWave32()) {
+        F.getContext().diagnose(DiagnosticInfoUnsupported(
+            F, "Dynamic VGPR mode is only supported for wave32\n"));
+        return false;
+      }
+
+      ArgInfo FallbackExecArg = Info.OrigArgs[ChainCallArgIdx::FallbackExec];
+      assert(FallbackExecArg.Regs.size() == 1 &&
+             "Expected single register for fallback EXEC");
+      if (!FallbackExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
+        LLVM_DEBUG(dbgs() << "Bad type for fallback EXEC");
+        return false;
+      }
+    }
+  }
+
+  unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
+                               ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
   auto MIB = MIRBuilder.buildInstrNoInsert(Opc);
-  if (!addCallTargetOperands(MIB, MIRBuilder, Info))
+  if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
     return false;
 
   // Byte offset for the tail call. When we are sibcalling, this will always
   // be 0.
   MIB.addImm(0);
 
-  // If this is a chain call, we need to pass in the EXEC mask.
-  const SIRegisterInfo *TRI = ST.getRegisterInfo();
-  if (AMDGPU::isChainCC(Info.CallConv)) {
-    ArgInfo ExecArg = Info.OrigArgs[1];
+  // If this is a chain call, we need to pass in the EXEC mask as well as any
+  // other special args.
+  if (IsChainCall) {
+    auto AddRegOrImm = [&](const ArgInfo &Arg) {
+      if (auto CI = dyn_cast<ConstantInt>(Arg.OrigValue)) {
+        MIB.addImm(CI->getSExtValue());
+      } else {
+        MIB.addReg(Arg.Regs[0]);
+        unsigned Idx = MIB->getNumOperands() - 1;
+        MIB->getOperand(Idx).setReg(constrainOperandRegClass(
+            MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
+            MIB->getOperand(Idx), Idx));
+      }
+    };
+
+    ArgInfo ExecArg = Info.OrigArgs[ChainCallArgIdx::Exec];
     assert(ExecArg.Regs.size() == 1 && "Too many regs for EXEC");
 
-    if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize()))
+    if (!ExecArg.Ty->isIntegerTy(ST.getWavefrontSize())) {
+      LLVM_DEBUG(dbgs() << "Bad type for EXEC");
       return false;
-
-    if (const auto *CI = dyn_cast<ConstantInt>(ExecArg.OrigValue)) {
-      MIB.addImm(CI->getSExtValue());
-    } else {
-      MIB.addReg(ExecArg.Regs[0]);
-      unsigned Idx = MIB->getNumOperands() - 1;
-      MIB->getOperand(Idx).setReg(constrainOperandRegClass(
-          MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
-          MIB->getDesc(), MIB->getOperand(Idx), Idx));
     }
+
+    AddRegOrImm(ExecArg);
+    if (IsDynamicVGPRChainCall)
+      std::for_each(Info.OrigArgs.begin() + ChainCallArgIdx::NumVGPRs,
+                    Info.OrigArgs.end(), AddRegOrImm);
   }
 
   // Tell the call which registers are clobbered.
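Summing up the validation above, the lowering now accepts exactly two shapes of the intrinsic, and AddRegOrImm appends each special argument either as an immediate (when it is a ConstantInt) or as a register constrained to the class the pseudo expects:

  // flags == 0:      5 args  (callee, exec, sgpr_args, vgpr_args, flags)
  // flags bit 0 set: 8 args  (+ num_vgprs, fallback_exec, fallback_callee),
  //                  wave32 only; rejected with a diagnostic otherwise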
@@ -1328,9 +1398,9 @@ bool AMDGPUCallLowering::lowerTailCall(
   // FIXME: We should define regbankselectable call instructions to handle
   // divergent call targets.
   if (MIB->getOperand(0).isReg()) {
-    MIB->getOperand(0).setReg(constrainOperandRegClass(
-        MF, *TRI, MRI, *ST.getInstrInfo(), *ST.getRegBankInfo(), *MIB,
-        MIB->getDesc(), MIB->getOperand(0), 0));
+    MIB->getOperand(0).setReg(
+        constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
+                                 *MIB, MIB->getDesc(), MIB->getOperand(0), 0));
   }
 
   MF.getFrameInfo().setHasTailCall();
@@ -1344,11 +1414,6 @@ bool AMDGPUCallLowering::lowerChainCall(MachineIRBuilder &MIRBuilder,
   ArgInfo Callee = Info.OrigArgs[0];
   ArgInfo SGPRArgs = Info.OrigArgs[2];
   ArgInfo VGPRArgs = Info.OrigArgs[3];
-  ArgInfo Flags = Info.OrigArgs[4];
-
-  assert(cast<ConstantInt>(Flags.OrigValue)->isZero() &&
-         "Non-zero flags aren't supported yet.");
-  assert(Info.OrigArgs.size() == 5 && "Additional args aren't supported yet.");
 
   MachineFunction &MF = MIRBuilder.getMF();
   const Function &F = MF.getFunction();