Skip to content

Commit 94d76dc

Browse files
committed
[clang][SPIR-V] Always add convervence intrinsics
PR llvm#80680 added bits in the codegen to lazily add convergence intrinsics when required. This logic relied on the LoopStack. The issue is when parsing the condition, the loopstack doesn't yet reflect the correct values, as expected since we are not yet in the loop. However, convergence tokens should sometimes already be available. The solution which seemed the simplest is to greedily generate the tokens when we generate SPIR-V. Fixes llvm#88144 Signed-off-by: Nathan Gauër <[email protected]>
1 parent 40327a6 commit 94d76dc

File tree

9 files changed

+445
-89
lines changed

9 files changed

+445
-89
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,91 +1133,8 @@ struct BitTest {
11331133
static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
11341134
};
11351135

1136-
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
1137-
// std::nullptr otherwise.
1138-
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
1139-
for (auto &I : *BB) {
1140-
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
1141-
if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
1142-
return II;
1143-
}
1144-
return nullptr;
1145-
}
1146-
11471136
} // namespace
11481137

1149-
llvm::CallBase *
1150-
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
1151-
llvm::Value *ParentToken) {
1152-
llvm::Value *bundleArgs[] = {ParentToken};
1153-
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
1154-
auto Output = llvm::CallBase::addOperandBundle(
1155-
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
1156-
Input->replaceAllUsesWith(Output);
1157-
Input->eraseFromParent();
1158-
return Output;
1159-
}
1160-
1161-
llvm::IntrinsicInst *
1162-
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
1163-
llvm::Value *ParentToken) {
1164-
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1165-
Builder.SetInsertPoint(&BB->front());
1166-
auto CB = Builder.CreateIntrinsic(
1167-
llvm::Intrinsic::experimental_convergence_loop, {}, {});
1168-
Builder.restoreIP(IP);
1169-
1170-
auto I = addConvergenceControlToken(CB, ParentToken);
1171-
return cast<llvm::IntrinsicInst>(I);
1172-
}
1173-
1174-
llvm::IntrinsicInst *
1175-
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
1176-
auto *BB = &F->getEntryBlock();
1177-
auto *token = getConvergenceToken(BB);
1178-
if (token)
1179-
return token;
1180-
1181-
// Adding a convergence token requires the function to be marked as
1182-
// convergent.
1183-
F->setConvergent();
1184-
1185-
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1186-
Builder.SetInsertPoint(&BB->front());
1187-
auto I = Builder.CreateIntrinsic(
1188-
llvm::Intrinsic::experimental_convergence_entry, {}, {});
1189-
assert(isa<llvm::IntrinsicInst>(I));
1190-
Builder.restoreIP(IP);
1191-
1192-
return cast<llvm::IntrinsicInst>(I);
1193-
}
1194-
1195-
llvm::IntrinsicInst *
1196-
CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
1197-
assert(LI != nullptr);
1198-
1199-
auto *token = getConvergenceToken(LI->getHeader());
1200-
if (token)
1201-
return token;
1202-
1203-
llvm::IntrinsicInst *PII =
1204-
LI->getParent()
1205-
? emitConvergenceLoopToken(
1206-
LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
1207-
: getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
1208-
1209-
return emitConvergenceLoopToken(LI->getHeader(), PII);
1210-
}
1211-
1212-
llvm::CallBase *
1213-
CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
1214-
llvm::Value *ParentToken =
1215-
LoopStack.hasInfo()
1216-
? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
1217-
: getOrEmitConvergenceEntryToken(Input->getFunction());
1218-
return addConvergenceControlToken(Input, ParentToken);
1219-
}
1220-
12211138
BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
12221139
switch (BuiltinID) {
12231140
// Main portable variants.
@@ -18306,12 +18223,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1830618223
ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
1830718224
}
1830818225
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
18309-
auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
18226+
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1831018227
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
1831118228
{}, false, true));
18312-
if (getTarget().getTriple().isSPIRVLogical())
18313-
CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
18314-
return CI;
1831518229
}
1831618230
}
1831718231
return nullptr;

clang/lib/CodeGen/CGCall.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4824,6 +4824,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
48244824
llvm::CallInst *call = Builder.CreateCall(
48254825
callee, args, getBundlesForFunclet(callee.getCallee()), name);
48264826
call->setCallingConv(getRuntimeCC());
4827+
4828+
if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
4829+
return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
48274830
return call;
48284831
}
48294832

clang/lib/CodeGen/CGStmt.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
915915
JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
916916
EmitBlock(LoopHeader.getBlock());
917917

918+
if (getTarget().getTriple().isSPIRVLogical())
919+
ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
920+
LoopHeader.getBlock(), ConvergenceTokenStack.back()));
921+
918922
// Create an exit block for when the condition fails, which will
919923
// also become the break target.
920924
JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1017,6 +1021,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
10171021
// block.
10181022
if (llvm::EnableSingleByteCoverage)
10191023
incrementProfileCounter(&S);
1024+
1025+
if (getTarget().getTriple().isSPIRVLogical())
1026+
ConvergenceTokenStack.pop_back();
10201027
}
10211028

10221029
void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1036,6 +1043,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
10361043
EmitBlockWithFallThrough(LoopBody, S.getBody());
10371044
else
10381045
EmitBlockWithFallThrough(LoopBody, &S);
1046+
1047+
if (getTarget().getTriple().isSPIRVLogical())
1048+
ConvergenceTokenStack.push_back(
1049+
emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
1050+
10391051
{
10401052
RunCleanupsScope BodyScope(*this);
10411053
EmitStmt(S.getBody());
@@ -1090,6 +1102,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
10901102
// block.
10911103
if (llvm::EnableSingleByteCoverage)
10921104
incrementProfileCounter(&S);
1105+
1106+
if (getTarget().getTriple().isSPIRVLogical())
1107+
ConvergenceTokenStack.pop_back();
10931108
}
10941109

10951110
void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1109,6 +1124,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
11091124
llvm::BasicBlock *CondBlock = CondDest.getBlock();
11101125
EmitBlock(CondBlock);
11111126

1127+
if (getTarget().getTriple().isSPIRVLogical())
1128+
ConvergenceTokenStack.push_back(
1129+
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
1130+
11121131
Expr::EvalResult Result;
11131132
bool CondIsConstInt =
11141133
!S.getCond() || S.getCond()->EvaluateAsInt(Result, getContext());
@@ -1222,6 +1241,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
12221241
// block.
12231242
if (llvm::EnableSingleByteCoverage)
12241243
incrementProfileCounter(&S);
1244+
1245+
if (getTarget().getTriple().isSPIRVLogical())
1246+
ConvergenceTokenStack.pop_back();
12251247
}
12261248

12271249
void
@@ -1244,6 +1266,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
12441266
llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
12451267
EmitBlock(CondBlock);
12461268

1269+
if (getTarget().getTriple().isSPIRVLogical())
1270+
ConvergenceTokenStack.push_back(
1271+
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
1272+
12471273
const SourceRange &R = S.getSourceRange();
12481274
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
12491275
SourceLocToDebugLoc(R.getBegin()),
@@ -1312,6 +1338,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
13121338
// block.
13131339
if (llvm::EnableSingleByteCoverage)
13141340
incrementProfileCounter(&S);
1341+
1342+
if (getTarget().getTriple().isSPIRVLogical())
1343+
ConvergenceTokenStack.pop_back();
13151344
}
13161345

13171346
void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3101,3 +3130,68 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
31013130

31023131
return F;
31033132
}
3133+
3134+
namespace {
3135+
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
3136+
// std::nullptr otherwise.
3137+
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
3138+
for (auto &I : *BB) {
3139+
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
3140+
if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
3141+
return II;
3142+
}
3143+
return nullptr;
3144+
}
3145+
3146+
} // namespace
3147+
3148+
llvm::CallBase *
3149+
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
3150+
llvm::Value *ParentToken) {
3151+
llvm::Value *bundleArgs[] = {ParentToken};
3152+
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
3153+
auto Output = llvm::CallBase::addOperandBundle(
3154+
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
3155+
Input->replaceAllUsesWith(Output);
3156+
Input->eraseFromParent();
3157+
return Output;
3158+
}
3159+
3160+
llvm::IntrinsicInst *
3161+
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
3162+
llvm::Value *ParentToken) {
3163+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
3164+
3165+
if (BB->empty())
3166+
Builder.SetInsertPoint(BB);
3167+
else
3168+
Builder.SetInsertPoint(&BB->front());
3169+
3170+
auto CB = Builder.CreateIntrinsic(
3171+
llvm::Intrinsic::experimental_convergence_loop, {}, {});
3172+
Builder.restoreIP(IP);
3173+
3174+
auto I = addConvergenceControlToken(CB, ParentToken);
3175+
return cast<llvm::IntrinsicInst>(I);
3176+
}
3177+
3178+
llvm::IntrinsicInst *
3179+
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
3180+
auto *BB = &F->getEntryBlock();
3181+
auto *token = getConvergenceToken(BB);
3182+
if (token)
3183+
return token;
3184+
3185+
// Adding a convergence token requires the function to be marked as
3186+
// convergent.
3187+
F->setConvergent();
3188+
3189+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
3190+
Builder.SetInsertPoint(&BB->front());
3191+
auto I = Builder.CreateIntrinsic(
3192+
llvm::Intrinsic::experimental_convergence_entry, {}, {});
3193+
assert(isa<llvm::IntrinsicInst>(I));
3194+
Builder.restoreIP(IP);
3195+
3196+
return cast<llvm::IntrinsicInst>(I);
3197+
}

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
347347
assert(BreakContinueStack.empty() &&
348348
"mismatched push/pop in break/continue stack!");
349349

350+
if (getTarget().getTriple().isSPIRVLogical()) {
351+
ConvergenceTokenStack.pop_back();
352+
assert(ConvergenceTokenStack.empty() &&
353+
"mismatched push/pop in convergence stack!");
354+
}
355+
350356
bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
351357
&& NumSimpleReturnExprs == NumReturnExprs
352358
&& ReturnBlock.getBlock()->use_empty();
@@ -1271,6 +1277,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
12711277
if (CurFuncDecl)
12721278
if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
12731279
LargestVectorWidth = VecWidth->getVectorWidth();
1280+
1281+
if (getTarget().getTriple().isSPIRVLogical())
1282+
ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
12741283
}
12751284

12761285
void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,9 @@ class CodeGenFunction : public CodeGenTypeCache {
314314
/// Stack to track the Logical Operator recursion nest for MC/DC.
315315
SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
316316

317+
/// Stack to track the controlled convergence tokens.
318+
SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
319+
317320
/// Number of nested loop to be consumed by the last surrounding
318321
/// loop-associated directive.
319322
int ExpectedOMPLoopDepth = 0;
@@ -4987,7 +4990,11 @@ class CodeGenFunction : public CodeGenTypeCache {
49874990
const llvm::Twine &Name = "");
49884991
// Adds a convergence_ctrl token to |Input| and emits the required parent
49894992
// convergence instructions.
4990-
llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
4993+
template <typename CallType>
4994+
CallType *addControlledConvergenceToken(CallType *Input) {
4995+
return dyn_cast<CallType>(
4996+
addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
4997+
}
49914998

49924999
private:
49935000
// Emits a convergence_loop instruction for the given |BB|, with |ParentToken|

clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
21
// RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
32

43
RWBuffer<float> Buf;

0 commit comments

Comments
 (0)