Skip to content

Commit f431347

Browse files
committed
[clang][HLSL][SPRI-V] Add convergence intrinsics
HLSL has wave operations and other kind of function which required the control flow to either be converged, or respect certain constraints as where and how to re-converge. At the HLSL level, the convergence are mostly obvious: the control flow is expected to re-converge at the end of a scope. Once translated to IR, HLSL scopes disapear. This means we need a way to communicate convergence restrictions down to the backend. For this, the SPIR-V backend uses convergence intrinsics. So this commit adds some code to generate convergence intrinsics when required. This commit is not to be submitted as-is (lacks testing), but should serve as a basis for an upcoming RFC. Signed-off-by: Nathan Gauër <[email protected]>
1 parent e88c255 commit f431347

File tree

5 files changed

+145
-1
lines changed

5 files changed

+145
-1
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,8 +1129,97 @@ struct BitTest {
11291129

11301130
static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
11311131
};
1132+
1133+
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
1134+
// std::nullopt otherwise.
1135+
std::optional<llvm::IntrinsicInst *> getConvergenceToken(llvm::BasicBlock *BB) {
1136+
for (auto &I : *BB) {
1137+
auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
1138+
if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
1139+
return II;
1140+
}
1141+
return std::nullopt;
1142+
}
1143+
11321144
} // namespace
11331145

1146+
llvm::CallBase *
1147+
CodeGenFunction::AddConvergenceControlAttr(llvm::CallBase *Input,
1148+
llvm::Value *ParentToken) {
1149+
llvm::Value *bundleArgs[] = {ParentToken};
1150+
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
1151+
auto Output = llvm::CallBase::addOperandBundle(
1152+
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
1153+
Input->replaceAllUsesWith(Output);
1154+
Input->eraseFromParent();
1155+
return Output;
1156+
}
1157+
1158+
llvm::IntrinsicInst *
1159+
CodeGenFunction::EmitConvergenceLoop(llvm::BasicBlock *BB,
1160+
llvm::Value *ParentToken) {
1161+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1162+
Builder.SetInsertPoint(&BB->front());
1163+
auto CB = Builder.CreateIntrinsic(
1164+
llvm::Intrinsic::experimental_convergence_loop, {}, {});
1165+
Builder.restoreIP(IP);
1166+
1167+
auto I = AddConvergenceControlAttr(CB, ParentToken);
1168+
// Controlled convergence is incompatible with uncontrolled convergence.
1169+
// Removing any old attributes.
1170+
I->setNotConvergent();
1171+
1172+
assert(isa<llvm::IntrinsicInst>(I));
1173+
return dyn_cast<llvm::IntrinsicInst>(I);
1174+
}
1175+
1176+
llvm::IntrinsicInst *
1177+
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
1178+
auto *BB = &F->getEntryBlock();
1179+
auto token = getConvergenceToken(BB);
1180+
if (token.has_value())
1181+
return token.value();
1182+
1183+
// Adding a convergence token requires the function to be marked as
1184+
// convergent.
1185+
F->setConvergent();
1186+
1187+
CGBuilderTy::InsertPoint IP = Builder.saveIP();
1188+
Builder.SetInsertPoint(&BB->front());
1189+
auto I = Builder.CreateIntrinsic(
1190+
llvm::Intrinsic::experimental_convergence_entry, {}, {});
1191+
assert(isa<llvm::IntrinsicInst>(I));
1192+
Builder.restoreIP(IP);
1193+
1194+
return dyn_cast<llvm::IntrinsicInst>(I);
1195+
}
1196+
1197+
llvm::IntrinsicInst *
1198+
CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
1199+
assert(LI != nullptr);
1200+
1201+
auto token = getConvergenceToken(LI->getHeader());
1202+
if (token.has_value())
1203+
return *token;
1204+
1205+
llvm::IntrinsicInst *PII =
1206+
LI->getParent()
1207+
? EmitConvergenceLoop(LI->getHeader(),
1208+
getOrEmitConvergenceLoopToken(LI->getParent()))
1209+
: getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
1210+
1211+
return EmitConvergenceLoop(LI->getHeader(), PII);
1212+
}
1213+
1214+
llvm::CallBase *
1215+
CodeGenFunction::AddControlledConvergenceAttr(llvm::CallBase *Input) {
1216+
llvm::Value *ParentToken =
1217+
LoopStack.hasInfo()
1218+
? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
1219+
: getOrEmitConvergenceEntryToken(Input->getFunction());
1220+
return AddConvergenceControlAttr(Input, ParentToken);
1221+
}
1222+
11341223
BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
11351224
switch (BuiltinID) {
11361225
// Main portable variants.
@@ -5696,6 +5785,19 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
56965785
{NDRange, Kernel, Block}));
56975786
}
56985787

5788+
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
5789+
llvm::Type *BoolTy = llvm::IntegerType::get(getLLVMContext(), 1);
5790+
llvm::Value *Src0 = EmitScalarExpr(E->getArg(0));
5791+
auto *CI =
5792+
EmitRuntimeCall(CGM.CreateRuntimeFunction(
5793+
llvm::FunctionType::get(IntTy, {BoolTy}, false),
5794+
"__hlsl_wave_active_count_bits", {}),
5795+
{Src0});
5796+
if (getTarget().getTriple().isSPIRVLogical())
5797+
CI = dyn_cast<CallInst>(AddControlledConvergenceAttr(CI));
5798+
return RValue::get(CI);
5799+
}
5800+
56995801
case Builtin::BI__builtin_store_half:
57005802
case Builtin::BI__builtin_store_halff: {
57015803
Value *Val = EmitScalarExpr(E->getArg(0));

clang/lib/CodeGen/CGCall.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5687,6 +5687,10 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
56875687
if (!CI->getType()->isVoidTy())
56885688
CI->setName("call");
56895689

5690+
if (getTarget().getTriple().isSPIRVLogical() &&
5691+
CI->getCalledFunction()->isConvergent())
5692+
CI = AddControlledConvergenceAttr(CI);
5693+
56905694
// Update largest vector width from the return type.
56915695
LargestVectorWidth =
56925696
std::max(LargestVectorWidth, getMaxVectorWidth(CI->getType()));

clang/lib/CodeGen/CGLoopInfo.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class LoopInfo {
110110
/// been processed.
111111
void finish();
112112

113+
/// Returns the first outer loop containing this loop if any, nullptr
114+
/// otherwise.
115+
const LoopInfo *getParent() const { return Parent; }
116+
113117
private:
114118
/// Loop ID metadata.
115119
llvm::TempMDTuple TempLoopID;
@@ -291,12 +295,14 @@ class LoopInfoStack {
291295
/// Set no progress for the next loop pushed.
292296
void setMustProgress(bool P) { StagedAttrs.MustProgress = P; }
293297

294-
private:
295298
/// Returns true if there is LoopInfo on the stack.
296299
bool hasInfo() const { return !Active.empty(); }
300+
297301
/// Return the LoopInfo for the current loop. HasInfo should be called
298302
/// first to ensure LoopInfo is present.
299303
const LoopInfo &getInfo() const { return *Active.back(); }
304+
305+
private:
300306
/// The set of attributes that will be applied to the next pushed loop.
301307
LoopAttributes StagedAttrs;
302308
/// Stack of active loops.

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4154,6 +4154,25 @@ class CodeGenFunction : public CodeGenTypeCache {
41544154
void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl);
41554155
void checkTargetFeatures(SourceLocation Loc, const FunctionDecl *TargetDecl);
41564156

4157+
// Adds a convergence_ctrl attribute to |Input| and emits the required parent
4158+
// convergence instructions.
4159+
llvm::CallBase *AddControlledConvergenceAttr(llvm::CallBase *Input);
4160+
4161+
// Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
4162+
// as it's parent convergence instr.
4163+
llvm::IntrinsicInst *EmitConvergenceLoop(llvm::BasicBlock *BB,
4164+
llvm::Value *ParentToken);
4165+
// Adds a convergence_ctrl attribute with |ParentToken| as parent convergence
4166+
// instr to the call |Input|.
4167+
llvm::CallBase *AddConvergenceControlAttr(llvm::CallBase *Input,
4168+
llvm::Value *ParentToken);
4169+
// Find the convergence_entry instruction |F|, or emits ones if none exists.
4170+
// Returns the convergence instruction.
4171+
llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
4172+
// Find the convergence_loop instruction for the loop defined by |LI|, or
4173+
// emits one if none exists. Returns the convergence instruction.
4174+
llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
4175+
41574176
llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,
41584177
const Twine &name = "");
41594178
llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,

llvm/include/llvm/IR/IntrinsicInst.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,6 +1746,19 @@ class ConvergenceControlInst : public IntrinsicInst {
17461746
static bool classof(const Value *V) {
17471747
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
17481748
}
1749+
1750+
// Returns the convergence intrinsic referenced by |I|'s convergencectrl
1751+
// attribute if any.
1752+
static IntrinsicInst *getParentConvergenceToken(Instruction *I) {
1753+
auto *CI = dyn_cast<llvm::CallInst>(I);
1754+
if (!CI)
1755+
return nullptr;
1756+
1757+
auto Bundle = CI->getOperandBundle(llvm::LLVMContext::OB_convergencectrl);
1758+
assert(Bundle->Inputs.size() == 1 &&
1759+
Bundle->Inputs[0]->getType()->isTokenTy());
1760+
return dyn_cast<llvm::IntrinsicInst>(Bundle->Inputs[0].get());
1761+
}
17491762
};
17501763

17511764
} // end namespace llvm

0 commit comments

Comments
 (0)