Skip to content

[clang][SPIR-V] Always add convergence intrinsics #88918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 1 addition & 87 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,91 +1141,8 @@ struct BitTest {
static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
};

// Returns the first convergence entry/loop/anchor instruction found in |BB|.
// std::nullptr otherwise.
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
for (auto &I : *BB) {
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
return II;
}
return nullptr;
}

} // namespace

llvm::CallBase *
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
llvm::Value *ParentToken) {
llvm::Value *bundleArgs[] = {ParentToken};
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
auto Output = llvm::CallBase::addOperandBundle(
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
Input->replaceAllUsesWith(Output);
Input->eraseFromParent();
return Output;
}

llvm::IntrinsicInst *
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
llvm::Value *ParentToken) {
CGBuilderTy::InsertPoint IP = Builder.saveIP();
Builder.SetInsertPoint(&BB->front());
auto CB = Builder.CreateIntrinsic(
llvm::Intrinsic::experimental_convergence_loop, {}, {});
Builder.restoreIP(IP);

auto I = addConvergenceControlToken(CB, ParentToken);
return cast<llvm::IntrinsicInst>(I);
}

llvm::IntrinsicInst *
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
auto *BB = &F->getEntryBlock();
auto *token = getConvergenceToken(BB);
if (token)
return token;

// Adding a convergence token requires the function to be marked as
// convergent.
F->setConvergent();

CGBuilderTy::InsertPoint IP = Builder.saveIP();
Builder.SetInsertPoint(&BB->front());
auto I = Builder.CreateIntrinsic(
llvm::Intrinsic::experimental_convergence_entry, {}, {});
assert(isa<llvm::IntrinsicInst>(I));
Builder.restoreIP(IP);

return cast<llvm::IntrinsicInst>(I);
}

llvm::IntrinsicInst *
CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
assert(LI != nullptr);

auto *token = getConvergenceToken(LI->getHeader());
if (token)
return token;

llvm::IntrinsicInst *PII =
LI->getParent()
? emitConvergenceLoopToken(
LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
: getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());

return emitConvergenceLoopToken(LI->getHeader(), PII);
}

llvm::CallBase *
CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
llvm::Value *ParentToken =
LoopStack.hasInfo()
? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
: getOrEmitConvergenceEntryToken(Input->getFunction());
return addConvergenceControlToken(Input, ParentToken);
}

BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
switch (BuiltinID) {
// Main portable variants.
Expand Down Expand Up @@ -18402,12 +18319,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
{}, false, true));
if (getTarget().getTriple().isSPIRVLogical())
CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
return CI;
}
}
return nullptr;
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4830,6 +4830,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
llvm::CallInst *call = Builder.CreateCall(
callee, args, getBundlesForFunclet(callee.getCallee()), name);
call->setCallingConv(getRuntimeCC());

if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
return addControlledConvergenceToken(call);
return call;
}

Expand Down Expand Up @@ -5730,7 +5733,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
if (!CI->getType()->isVoidTy())
CI->setName("call");

if (getTarget().getTriple().isSPIRVLogical() && CI->isConvergent())
if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
CI = addControlledConvergenceToken(CI);

// Update largest vector width from the return type.
Expand Down
93 changes: 93 additions & 0 deletions clang/lib/CodeGen/CGStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
EmitBlock(LoopHeader.getBlock());

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
LoopHeader.getBlock(), ConvergenceTokenStack.back()));

// Create an exit block for when the condition fails, which will
// also become the break target.
JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
Expand Down Expand Up @@ -1079,6 +1083,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
// block.
if (llvm::EnableSingleByteCoverage)
incrementProfileCounter(&S);

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.pop_back();
}

void CodeGenFunction::EmitDoStmt(const DoStmt &S,
Expand All @@ -1098,6 +1105,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
EmitBlockWithFallThrough(LoopBody, S.getBody());
else
EmitBlockWithFallThrough(LoopBody, &S);

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.push_back(
emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));

{
RunCleanupsScope BodyScope(*this);
EmitStmt(S.getBody());
Expand Down Expand Up @@ -1151,6 +1163,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
// block.
if (llvm::EnableSingleByteCoverage)
incrementProfileCounter(&S);

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.pop_back();
}

void CodeGenFunction::EmitForStmt(const ForStmt &S,
Expand All @@ -1170,6 +1185,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
llvm::BasicBlock *CondBlock = CondDest.getBlock();
EmitBlock(CondBlock);

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.push_back(
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));

const SourceRange &R = S.getSourceRange();
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
SourceLocToDebugLoc(R.getBegin()),
Expand Down Expand Up @@ -1279,6 +1298,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
// block.
if (llvm::EnableSingleByteCoverage)
incrementProfileCounter(&S);

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.pop_back();
}

void
Expand All @@ -1301,6 +1323,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
EmitBlock(CondBlock);

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.push_back(
emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));

const SourceRange &R = S.getSourceRange();
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
SourceLocToDebugLoc(R.getBegin()),
Expand Down Expand Up @@ -1369,6 +1395,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
// block.
if (llvm::EnableSingleByteCoverage)
incrementProfileCounter(&S);

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.pop_back();
}

void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
Expand Down Expand Up @@ -3158,3 +3187,67 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {

return F;
}

namespace {
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
// std::nullptr otherwise.
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
for (auto &I : *BB) {
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
return II;
}
return nullptr;
}

} // namespace

llvm::CallBase *
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
llvm::Value *ParentToken) {
llvm::Value *bundleArgs[] = {ParentToken};
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
auto Output = llvm::CallBase::addOperandBundle(
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
Input->replaceAllUsesWith(Output);
Input->eraseFromParent();
return Output;
}

llvm::IntrinsicInst *
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
llvm::Value *ParentToken) {
CGBuilderTy::InsertPoint IP = Builder.saveIP();
if (BB->empty())
Builder.SetInsertPoint(BB);
else
Builder.SetInsertPoint(BB->getFirstInsertionPt());

llvm::CallBase *CB = Builder.CreateIntrinsic(
llvm::Intrinsic::experimental_convergence_loop, {}, {});
Builder.restoreIP(IP);

llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
return cast<llvm::IntrinsicInst>(I);
}

llvm::IntrinsicInst *
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
llvm::BasicBlock *BB = &F->getEntryBlock();
llvm::IntrinsicInst *Token = getConvergenceToken(BB);
if (Token)
return Token;

// Adding a convergence token requires the function to be marked as
// convergent.
F->setConvergent();

CGBuilderTy::InsertPoint IP = Builder.saveIP();
Builder.SetInsertPoint(&BB->front());
llvm::CallBase *I = Builder.CreateIntrinsic(
llvm::Intrinsic::experimental_convergence_entry, {}, {});
assert(isa<llvm::IntrinsicInst>(I));
Builder.restoreIP(IP);

return cast<llvm::IntrinsicInst>(I);
}
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
assert(DeferredDeactivationCleanupStack.empty() &&
"mismatched activate/deactivate of cleanups!");

if (CGM.shouldEmitConvergenceTokens()) {
ConvergenceTokenStack.pop_back();
assert(ConvergenceTokenStack.empty() &&
"mismatched push/pop in convergence stack!");
}

bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
&& NumSimpleReturnExprs == NumReturnExprs
&& ReturnBlock.getBlock()->use_empty();
Expand Down Expand Up @@ -1277,6 +1283,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
if (CurFuncDecl)
if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
LargestVectorWidth = VecWidth->getVectorWidth();

if (CGM.shouldEmitConvergenceTokens())
ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
}

void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {
Expand Down
9 changes: 8 additions & 1 deletion clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ class CodeGenFunction : public CodeGenTypeCache {
/// Stack to track the Logical Operator recursion nest for MC/DC.
SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;

/// Stack to track the controlled convergence tokens.
SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;

/// Number of nested loop to be consumed by the last surrounding
/// loop-associated directive.
int ExpectedOMPLoopDepth = 0;
Expand Down Expand Up @@ -5076,7 +5079,11 @@ class CodeGenFunction : public CodeGenTypeCache {
const llvm::Twine &Name = "");
// Adds a convergence_ctrl token to |Input| and emits the required parent
// convergence instructions.
llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
template <typename CallType>
CallType *addControlledConvergenceToken(CallType *Input) {
return cast<CallType>(
addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
}

private:
// Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,14 @@ class CodeGenModule : public CodeGenTypeCache {
void AddGlobalDtor(llvm::Function *Dtor, int Priority = 65535,
bool IsDtorAttrFunc = false);

// Return whether structured convergence intrinsics should be generated for
// this target.
bool shouldEmitConvergenceTokens() const {
// TODO: this should probably become unconditional once the controlled
// convergence becomes the norm.
return getTriple().isSPIRVLogical();
}

private:
llvm::Constant *GetOrCreateLLVMFunction(
StringRef MangledName, llvm::Type *Ty, GlobalDecl D, bool ForVTable,
Expand Down
1 change: 0 additions & 1 deletion clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
// RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV

RWBuffer<float> Buf;
Expand Down
Loading
Loading