Skip to content

Commit a8bf6fe

Browse files
committed
[clang][SPIR-V] Always add convergence intrinsics
PR #80680 added bits in the codegen to lazily add convergence intrinsics when required. This logic relied on the LoopStack. The issue is that, when parsing the loop condition, the LoopStack doesn't yet reflect the correct values, which is expected since we are not yet in the loop. However, convergence tokens should sometimes already be available at that point. The solution which seemed the simplest is to greedily generate the tokens when we generate SPIR-V. Fixes #88144 Signed-off-by: Nathan Gauër <[email protected]>
1 parent 6aed0ab commit a8bf6fe

File tree

9 files changed

+445
-89
lines changed

9 files changed

+445
-89
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,91 +1141,8 @@ struct BitTest {
11411141
static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
11421142
};
11431143

1144-
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
1145-
// Returns nullptr otherwise.
1146-
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
1147-
for (auto &I : *BB) {
1148-
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
1149-
if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
1150-
return II;
1151-
}
1152-
return nullptr;
1153-
}
1154-
11551144
} // namespace
11561145

1157-
llvm::CallBase *
1158-
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
1159-
llvm::Value *ParentToken) {
1160-
llvm::Value *bundleArgs[] = {ParentToken};
1161-
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
1162-
auto Output = llvm::CallBase::addOperandBundle(
1163-
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
1164-
Input->replaceAllUsesWith(Output);
1165-
Input->eraseFromParent();
1166-
return Output;
1167-
}
1168-
1169-
llvm::IntrinsicInst *
1170-
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
1171-
llvm::Value *ParentToken) {
1172-
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1173-
Builder.SetInsertPoint(&BB->front());
1174-
auto CB = Builder.CreateIntrinsic(
1175-
llvm::Intrinsic::experimental_convergence_loop, {}, {});
1176-
Builder.restoreIP(IP);
1177-
1178-
auto I = addConvergenceControlToken(CB, ParentToken);
1179-
return cast<llvm::IntrinsicInst>(I);
1180-
}
1181-
1182-
llvm::IntrinsicInst *
1183-
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
1184-
auto *BB = &F->getEntryBlock();
1185-
auto *token = getConvergenceToken(BB);
1186-
if (token)
1187-
return token;
1188-
1189-
// Adding a convergence token requires the function to be marked as
1190-
// convergent.
1191-
F->setConvergent();
1192-
1193-
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1194-
Builder.SetInsertPoint(&BB->front());
1195-
auto I = Builder.CreateIntrinsic(
1196-
llvm::Intrinsic::experimental_convergence_entry, {}, {});
1197-
assert(isa<llvm::IntrinsicInst>(I));
1198-
Builder.restoreIP(IP);
1199-
1200-
return cast<llvm::IntrinsicInst>(I);
1201-
}
1202-
1203-
llvm::IntrinsicInst *
1204-
CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
1205-
assert(LI != nullptr);
1206-
1207-
auto *token = getConvergenceToken(LI->getHeader());
1208-
if (token)
1209-
return token;
1210-
1211-
llvm::IntrinsicInst *PII =
1212-
LI->getParent()
1213-
? emitConvergenceLoopToken(
1214-
LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
1215-
: getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
1216-
1217-
return emitConvergenceLoopToken(LI->getHeader(), PII);
1218-
}
1219-
1220-
llvm::CallBase *
1221-
CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
1222-
llvm::Value *ParentToken =
1223-
LoopStack.hasInfo()
1224-
? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
1225-
: getOrEmitConvergenceEntryToken(Input->getFunction());
1226-
return addConvergenceControlToken(Input, ParentToken);
1227-
}
1228-
12291146
BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
12301147
switch (BuiltinID) {
12311148
// Main portable variants.
@@ -18400,12 +18317,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1840018317
ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
1840118318
}
1840218319
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
18403-
auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
18320+
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1840418321
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
1840518322
{}, false, true));
18406-
if (getTarget().getTriple().isSPIRVLogical())
18407-
CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
18408-
return CI;
1840918323
}
1841018324
}
1841118325
return nullptr;

clang/lib/CodeGen/CGCall.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4830,6 +4830,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
48304830
llvm::CallInst *call = Builder.CreateCall(
48314831
callee, args, getBundlesForFunclet(callee.getCallee()), name);
48324832
call->setCallingConv(getRuntimeCC());
4833+
4834+
if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
4835+
return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
48334836
return call;
48344837
}
48354838

clang/lib/CodeGen/CGStmt.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
978978
JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
979979
EmitBlock(LoopHeader.getBlock());
980980

981+
if (getTarget().getTriple().isSPIRVLogical())
982+
ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
983+
LoopHeader.getBlock(), ConvergenceTokenStack.back()));
984+
981985
// Create an exit block for when the condition fails, which will
982986
// also become the break target.
983987
JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1079,6 +1083,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
10791083
// block.
10801084
if (llvm::EnableSingleByteCoverage)
10811085
incrementProfileCounter(&S);
1086+
1087+
if (getTarget().getTriple().isSPIRVLogical())
1088+
ConvergenceTokenStack.pop_back();
10821089
}
10831090

10841091
void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1098,6 +1105,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
10981105
EmitBlockWithFallThrough(LoopBody, S.getBody());
10991106
else
11001107
EmitBlockWithFallThrough(LoopBody, &S);
1108+
1109+
if (getTarget().getTriple().isSPIRVLogical())
1110+
ConvergenceTokenStack.push_back(
1111+
emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
1112+
11011113
{
11021114
RunCleanupsScope BodyScope(*this);
11031115
EmitStmt(S.getBody());
@@ -1151,6 +1163,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
11511163
// block.
11521164
if (llvm::EnableSingleByteCoverage)
11531165
incrementProfileCounter(&S);
1166+
1167+
if (getTarget().getTriple().isSPIRVLogical())
1168+
ConvergenceTokenStack.pop_back();
11541169
}
11551170

11561171
void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1170,6 +1185,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
11701185
llvm::BasicBlock *CondBlock = CondDest.getBlock();
11711186
EmitBlock(CondBlock);
11721187

1188+
if (getTarget().getTriple().isSPIRVLogical())
1189+
ConvergenceTokenStack.push_back(
1190+
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
1191+
11731192
const SourceRange &R = S.getSourceRange();
11741193
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
11751194
SourceLocToDebugLoc(R.getBegin()),
@@ -1279,6 +1298,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
12791298
// block.
12801299
if (llvm::EnableSingleByteCoverage)
12811300
incrementProfileCounter(&S);
1301+
1302+
if (getTarget().getTriple().isSPIRVLogical())
1303+
ConvergenceTokenStack.pop_back();
12821304
}
12831305

12841306
void
@@ -1301,6 +1323,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
13011323
llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
13021324
EmitBlock(CondBlock);
13031325

1326+
if (getTarget().getTriple().isSPIRVLogical())
1327+
ConvergenceTokenStack.push_back(
1328+
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
1329+
13041330
const SourceRange &R = S.getSourceRange();
13051331
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
13061332
SourceLocToDebugLoc(R.getBegin()),
@@ -1369,6 +1395,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
13691395
// block.
13701396
if (llvm::EnableSingleByteCoverage)
13711397
incrementProfileCounter(&S);
1398+
1399+
if (getTarget().getTriple().isSPIRVLogical())
1400+
ConvergenceTokenStack.pop_back();
13721401
}
13731402

13741403
void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3158,3 +3187,68 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
31583187

31593188
return F;
31603189
}
3190+
3191+
namespace {
3192+
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
3193+
// Returns nullptr otherwise.
3194+
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
3195+
for (auto &I : *BB) {
3196+
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
3197+
if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
3198+
return II;
3199+
}
3200+
return nullptr;
3201+
}
3202+
3203+
} // namespace
3204+
3205+
llvm::CallBase *
3206+
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
3207+
llvm::Value *ParentToken) {
3208+
llvm::Value *bundleArgs[] = {ParentToken};
3209+
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
3210+
auto Output = llvm::CallBase::addOperandBundle(
3211+
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
3212+
Input->replaceAllUsesWith(Output);
3213+
Input->eraseFromParent();
3214+
return Output;
3215+
}
3216+
3217+
llvm::IntrinsicInst *
3218+
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
3219+
llvm::Value *ParentToken) {
3220+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
3221+
3222+
if (BB->empty())
3223+
Builder.SetInsertPoint(BB);
3224+
else
3225+
Builder.SetInsertPoint(&BB->front());
3226+
3227+
auto CB = Builder.CreateIntrinsic(
3228+
llvm::Intrinsic::experimental_convergence_loop, {}, {});
3229+
Builder.restoreIP(IP);
3230+
3231+
auto I = addConvergenceControlToken(CB, ParentToken);
3232+
return cast<llvm::IntrinsicInst>(I);
3233+
}
3234+
3235+
llvm::IntrinsicInst *
3236+
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
3237+
auto *BB = &F->getEntryBlock();
3238+
auto *token = getConvergenceToken(BB);
3239+
if (token)
3240+
return token;
3241+
3242+
// Adding a convergence token requires the function to be marked as
3243+
// convergent.
3244+
F->setConvergent();
3245+
3246+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
3247+
Builder.SetInsertPoint(&BB->front());
3248+
auto I = Builder.CreateIntrinsic(
3249+
llvm::Intrinsic::experimental_convergence_entry, {}, {});
3250+
assert(isa<llvm::IntrinsicInst>(I));
3251+
Builder.restoreIP(IP);
3252+
3253+
return cast<llvm::IntrinsicInst>(I);
3254+
}

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
353353
assert(DeferredDeactivationCleanupStack.empty() &&
354354
"mismatched activate/deactivate of cleanups!");
355355

356+
if (getTarget().getTriple().isSPIRVLogical()) {
357+
ConvergenceTokenStack.pop_back();
358+
assert(ConvergenceTokenStack.empty() &&
359+
"mismatched push/pop in convergence stack!");
360+
}
361+
356362
bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
357363
&& NumSimpleReturnExprs == NumReturnExprs
358364
&& ReturnBlock.getBlock()->use_empty();
@@ -1277,6 +1283,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
12771283
if (CurFuncDecl)
12781284
if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
12791285
LargestVectorWidth = VecWidth->getVectorWidth();
1286+
1287+
if (getTarget().getTriple().isSPIRVLogical())
1288+
ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
12801289
}
12811290

12821291
void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ class CodeGenFunction : public CodeGenTypeCache {
315315
/// Stack to track the Logical Operator recursion nest for MC/DC.
316316
SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
317317

318+
/// Stack to track the controlled convergence tokens.
319+
SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
320+
318321
/// Number of nested loop to be consumed by the last surrounding
319322
/// loop-associated directive.
320323
int ExpectedOMPLoopDepth = 0;
@@ -5076,7 +5079,11 @@ class CodeGenFunction : public CodeGenTypeCache {
50765079
const llvm::Twine &Name = "");
50775080
// Adds a convergence_ctrl token to |Input| and emits the required parent
50785081
// convergence instructions.
5079-
llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
5082+
template <typename CallType>
5083+
CallType *addControlledConvergenceToken(CallType *Input) {
5084+
return dyn_cast<CallType>(
5085+
addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
5086+
}
50805087

50815088
private:
50825089
// Emits a convergence_loop instruction for the given |BB|, with |ParentToken|

clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
21
// RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
32

43
RWBuffer<float> Buf;

0 commit comments

Comments
 (0)