Skip to content

Commit 98e3075

Browse files
authored
[HLSL][SPIRV] Add convergence tokens to entry point wrapper (#112757)
Inlining currently assumes that either all functions use controlled convergence or none of them do. This is why we need to have the entry point wrapper use controlled convergence. https://github.com/llvm/llvm-project/blob/c85611e8583e6392d56075ebdfa60893b6284813/llvm/lib/Transforms/Utils/InlineFunction.cpp#L2431-L2439
1 parent e517cfc commit 98e3075

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,16 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
404404
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
405405
IRBuilder<> B(BB);
406406
llvm::SmallVector<Value *> Args;
407+
408+
SmallVector<OperandBundleDef, 1> OB;
409+
if (CGM.shouldEmitConvergenceTokens()) {
410+
assert(EntryFn->isConvergent());
411+
llvm::Value *I = B.CreateIntrinsic(
412+
llvm::Intrinsic::experimental_convergence_entry, {}, {});
413+
llvm::Value *bundleArgs[] = {I};
414+
OB.emplace_back("convergencectrl", bundleArgs);
415+
}
416+
407417
// FIXME: support struct parameters where semantics are on members.
408418
// See: https://github.com/llvm/llvm-project/issues/57874
409419
unsigned SRetOffset = 0;
@@ -419,7 +429,7 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
419429
Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
420430
}
421431

422-
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
432+
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
423433
CI->setCallingConv(Fn->getCallingConv());
424434
// FIXME: Handle codegen for return type semantics.
425435
// See: https://github.com/llvm/llvm-project/issues/57875
@@ -474,14 +484,22 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
474484
for (auto &F : M.functions()) {
475485
if (!F.hasFnAttribute("hlsl.shader"))
476486
continue;
477-
IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
487+
auto *Token = getConvergenceToken(F.getEntryBlock());
488+
Instruction *IP = &*F.getEntryBlock().begin();
489+
SmallVector<OperandBundleDef, 1> OB;
490+
if (Token) {
491+
llvm::Value *bundleArgs[] = {Token};
492+
OB.emplace_back("convergencectrl", bundleArgs);
493+
IP = Token->getNextNode();
494+
}
495+
IRBuilder<> B(IP);
478496
for (auto *Fn : CtorFns)
479-
B.CreateCall(FunctionCallee(Fn));
497+
B.CreateCall(FunctionCallee(Fn), {}, OB);
480498

481499
// Insert global dtors before the terminator of the last instruction
482500
B.SetInsertPoint(F.back().getTerminator());
483501
for (auto *Fn : DtorFns)
484-
B.CreateCall(FunctionCallee(Fn));
502+
B.CreateCall(FunctionCallee(Fn), {}, OB);
485503
}
486504

487505
// No need to keep global ctors/dtors for non-lib profile after call to
@@ -579,3 +597,18 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
579597
Builder.CreateRetVoid();
580598
return InitResBindingsFunc;
581599
}
600+
601+
// Returns the first convergence-control intrinsic call found in BB, scanning
// from the start of the block. Returns nullptr when
// CGM.shouldEmitConvergenceTokens() is false. When tokens are enabled and BB
// contains no such intrinsic, this hits llvm_unreachable.
llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
602+
// Convergence tokens are not in use: callers get nullptr and attach no bundle.
if (!CGM.shouldEmitConvergenceTokens())
603+
return nullptr;
604+
605+
auto E = BB.end();
606+
for (auto I = BB.begin(); I != E; ++I) {
607+
// Only intrinsic calls can be convergence-control tokens; dyn_cast yields
// null for any other instruction.
auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
608+
if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
609+
return II;
610+
}
611+
}
612+
// Reaching this point means tokens are enabled but none was found in BB.
llvm_unreachable("Convergence token should have been emitted.");
613+
// NOTE(review): this return follows llvm_unreachable; presumably kept to
// satisfy compilers that expect a return on every path - confirm.
return nullptr;
614+
}

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ class CGHLSLRuntime {
143143

144144
bool needsResourceBindingInitFn();
145145
llvm::Function *createResourceBindingInitFn();
146+
llvm::Instruction *getConvergenceToken(llvm::BasicBlock &BB);
146147

147148
private:
148149
void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @main()
4+
// CHECK-NEXT: entry:
5+
// CHECK-NEXT: [[token:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
6+
// CHECK-NEXT: call spir_func void @_Z4mainv() [ "convergencectrl"(token [[token]]) ]
7+
8+
[numthreads(1,1,1)]
9+
void main() {
10+
}
11+

0 commit comments

Comments
 (0)