Skip to content

[clang][HLSL][SPIR-V] Add convergence intrinsics #80680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4599,6 +4599,12 @@ def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Prototype = "unsigned int(bool)";
}

// Builtin record for the HLSL wave lane-index query. It is registered as a
// LangBuiltin for "HLSL_LANG", spelled __builtin_hlsl_wave_get_lane_index,
// takes no arguments and returns an unsigned int.
def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_get_lane_index"];
let Attributes = [NoThrow, Const];
let Prototype = "unsigned int()";
}

def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const];
Expand Down
93 changes: 93 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,8 +1131,92 @@ struct BitTest {

static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
};

// Walks the instructions of |BB| in order and returns the first one that is a
// convergence control intrinsic (entry/loop/anchor). Returns nullptr when |BB|
// contains no such intrinsic.
llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
  for (auto &Inst : *BB) {
    auto *Intrinsic = dyn_cast<llvm::IntrinsicInst>(&Inst);
    // Not an intrinsic call at all: keep scanning.
    if (!Intrinsic)
      continue;
    if (isConvergenceControlIntrinsic(Intrinsic->getIntrinsicID()))
      return Intrinsic;
  }
  return nullptr;
}

} // namespace

// Attaches a "convergencectrl" operand bundle referencing |ParentToken| to the
// call |Input|.
//
// Operand bundles cannot be added to an existing call in place, so a clone of
// |Input| carrying the bundle is created right before |Input|; every use of
// |Input| is redirected to the clone and |Input| is erased. The returned
// pointer is the call that callers must use from now on: |Input| must be
// considered invalid unless it is the value returned.
llvm::CallBase *
CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
                                            llvm::Value *ParentToken) {
  llvm::Value *bundleArgs[] = {ParentToken};
  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
  llvm::CallBase *Output = llvm::CallBase::addOperandBundle(
      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);

  // addOperandBundle returns |Input| unchanged when it already carries a
  // bundle with this ID. Replacing a value with itself and then erasing it
  // would hand a dangling pointer back to the caller, so there is nothing to
  // do in that case.
  if (Output == Input)
    return Input;

  Input->replaceAllUsesWith(Output);
  Input->eraseFromParent();
  return Output;
}

// Emits an llvm.experimental.convergence.loop call at the very beginning of
// |BB| and attaches |ParentToken| to it through a "convergencectrl" bundle.
// The builder's insertion point is left exactly where it was on entry.
llvm::IntrinsicInst *
CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
                                          llvm::Value *ParentToken) {
  // Temporarily move the builder to the front of |BB| to create the bare
  // intrinsic call, then put it back before touching anything else.
  CGBuilderTy::InsertPoint SavedIP = Builder.saveIP();
  Builder.SetInsertPoint(&BB->front());
  auto LoopCall = Builder.CreateIntrinsic(
      llvm::Intrinsic::experimental_convergence_loop, {}, {});
  Builder.restoreIP(SavedIP);

  // Adding the bundle replaces the call, so only the returned value is valid.
  return cast<llvm::IntrinsicInst>(
      addConvergenceControlToken(LoopCall, ParentToken));
}

// Returns the convergence control intrinsic found in the entry block of |F|.
// When the entry block has none, |F| is marked convergent, an
// llvm.experimental.convergence.entry call is emitted at the start of the entry
// block, and that new call is returned. The builder's insertion point is
// preserved.
llvm::IntrinsicInst *
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
  auto *EntryBB = &F->getEntryBlock();
  if (auto *Existing = getConvergenceToken(EntryBB))
    return Existing;

  // Adding a convergence token requires the function to be marked as
  // convergent.
  F->setConvergent();

  CGBuilderTy::InsertPoint SavedIP = Builder.saveIP();
  Builder.SetInsertPoint(&EntryBB->front());
  auto EntryCall = Builder.CreateIntrinsic(
      llvm::Intrinsic::experimental_convergence_entry, {}, {});
  assert(isa<llvm::IntrinsicInst>(EntryCall));
  Builder.restoreIP(SavedIP);

  return cast<llvm::IntrinsicInst>(EntryCall);
}

// Returns the convergence token of the loop described by |LI|, emitting it
// when the loop header does not contain one yet.
//
// The parent of the emitted token is the token of the enclosing loop
// (obtained recursively, so missing tokens of outer loops are emitted too), or
// the function's convergence entry token when |LI| has no enclosing loop.
// Exactly one loop token is emitted in the header of |LI|.
llvm::IntrinsicInst *
CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
  assert(LI != nullptr);

  llvm::BasicBlock *Header = LI->getHeader();
  if (llvm::IntrinsicInst *Existing = getConvergenceToken(Header))
    return Existing;

  // Only *look up* the parent here. Emitting an extra loop token in |Header|
  // to act as the parent would leave two loop tokens in the same block, with
  // the second one inserted in front of the first while using it as operand.
  llvm::IntrinsicInst *ParentToken =
      LI->getParent()
          ? getOrEmitConvergenceLoopToken(LI->getParent())
          : getOrEmitConvergenceEntryToken(Header->getParent());

  return emitConvergenceLoopToken(Header, ParentToken);
}

// Attaches a convergence control token to the call |Input| and returns the
// resulting call. The token used is the one of the current loop when the loop
// stack is not empty, and the entry token of the function containing |Input|
// otherwise; either is emitted on demand.
llvm::CallBase *
CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
  llvm::Value *Parent = nullptr;
  if (LoopStack.hasInfo())
    Parent = getOrEmitConvergenceLoopToken(&LoopStack.getInfo());
  else
    Parent = getOrEmitConvergenceEntryToken(Input->getFunction());

  return addConvergenceControlToken(Input, Parent);
}

BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
switch (BuiltinID) {
// Main portable variants.
Expand Down Expand Up @@ -5803,6 +5887,15 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
{NDRange, Kernel, Block}));
}

case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed this PR. Can you move this to `Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E)`?

auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
{}, false, true));
if (getTarget().getTriple().isSPIRVLogical())
CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
return RValue::get(CI);
}

case Builtin::BI__builtin_store_half:
case Builtin::BI__builtin_store_halff: {
Value *Val = EmitScalarExpr(E->getArg(0));
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5692,6 +5692,9 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
if (!CI->getType()->isVoidTy())
CI->setName("call");

if (getTarget().getTriple().isSPIRVLogical() && CI->isConvergent())
CI = addControlledConvergenceToken(CI);

// Update largest vector width from the return type.
LargestVectorWidth =
std::max(LargestVectorWidth, getMaxVectorWidth(CI->getType()));
Expand Down
7 changes: 6 additions & 1 deletion clang/lib/CodeGen/CGLoopInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ class LoopInfo {
/// been processed.
void finish();

/// Returns the closest outer loop containing this loop, or nullptr when this
/// loop is not nested inside another loop.
const LoopInfo *getParent() const {
  return Parent;
}

private:
/// Loop ID metadata.
llvm::TempMDTuple TempLoopID;
Expand Down Expand Up @@ -291,12 +295,13 @@ class LoopInfoStack {
/// Set no progress for the next loop pushed.
void setMustProgress(bool P) { StagedAttrs.MustProgress = P; }

private:
/// Returns true when at least one LoopInfo is currently on the stack.
bool hasInfo() const {
  const bool NoActiveLoop = Active.empty();
  return !NoActiveLoop;
}
/// Returns the LoopInfo of the current loop, i.e. the last entry of the active
/// stack. hasInfo() should be called first to ensure a LoopInfo is present.
const LoopInfo &getInfo() const {
  const auto &Current = Active.back();
  return *Current;
}

private:
/// The set of attributes that will be applied to the next pushed loop.
LoopAttributes StagedAttrs;
/// Stack of active loops.
Expand Down
19 changes: 19 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4870,6 +4870,25 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *emitBoolVecConversion(llvm::Value *SrcVec,
unsigned NumElementsDst,
const llvm::Twine &Name = "");
// Adds a convergence_ctrl token to |Input| and emits the required parent
// convergence instructions. Returns the call carrying the token; callers must
// use the returned call in place of |Input|.
llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);

private:
// Emits a convergence_loop instruction at the start of the given |BB|, with
// |ParentToken| as its parent convergence instr. Returns the emitted
// instruction.
llvm::IntrinsicInst *emitConvergenceLoopToken(llvm::BasicBlock *BB,
llvm::Value *ParentToken);
// Adds a convergence_ctrl token with |ParentToken| as parent convergence
// instr to the call |Input|. Returns the call carrying the token, which
// replaces |Input|.
llvm::CallBase *addConvergenceControlToken(llvm::CallBase *Input,
llvm::Value *ParentToken);
// Finds the convergence_entry instruction in the entry block of |F|, or emits
// one if none exists. Returns the convergence instruction.
llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
// Finds the convergence_loop instruction in the header of the loop defined by
// |LI|, or emits one if none exists. Returns the convergence instruction.
// |LI| must not be null.
llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);

private:
llvm::MDNode *getRangeForLoadFromType(QualType Ty);
Expand Down
7 changes: 6 additions & 1 deletion clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,12 @@ float4 trunc(float4);
/// true, across all active lanes in the current wave.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_count_bits)
uint WaveActiveCountBits(bool Val);
__attribute__((convergent)) uint WaveActiveCountBits(bool Val);

/// \brief Returns the index of the current lane within the current wave.
///
/// Declared as an alias of __builtin_hlsl_wave_get_lane_index and marked
/// convergent. Availability is annotated as shader model 6.0.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_get_lane_index)
__attribute__((convergent)) uint WaveGetLaneIndex();

} // namespace hlsl
#endif //_HLSL_HLSL_INTRINSICS_H_
40 changes: 40 additions & 0 deletions clang/test/CodeGenHLSL/builtins/wave_get_lane_index_do_while.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s

// CHECK: define spir_func void @main() [[A0:#[0-9]+]] {
// Calls WaveGetLaneIndex() from inside a `while (true)` loop that is left with
// `break`. The FileCheck directives below expect an entry token in the entry
// block, a loop token (parented to the entry token) in the loop condition
// block, and the intrinsic call to reference the loop token.
void main() {
// CHECK: entry:
// CHECK: %[[CT_ENTRY:[0-9]+]] = call token @llvm.experimental.convergence.entry()
// CHECK: br label %[[LABEL_WHILE_COND:.+]]
int cond = 0;

// CHECK: [[LABEL_WHILE_COND]]:
// CHECK: %[[CT_LOOP:[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[CT_ENTRY]]) ]
// CHECK: br label %[[LABEL_WHILE_BODY:.+]]
while (true) {

// CHECK: [[LABEL_WHILE_BODY]]:
// CHECK: br i1 {{%.+}}, label %[[LABEL_IF_THEN:.+]], label %[[LABEL_IF_END:.+]]

// CHECK: [[LABEL_IF_THEN]]:
// CHECK: call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[CT_LOOP]]) ]
// CHECK: br label %[[LABEL_WHILE_END:.+]]
if (cond == 2) {
uint index = WaveGetLaneIndex();
break;
}

// CHECK: [[LABEL_IF_END]]:
// CHECK: br label %[[LABEL_WHILE_COND]]
cond++;
}

// CHECK: [[LABEL_WHILE_END]]:
// CHECK: ret void
}

// CHECK-DAG: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]

// CHECK-DAG: attributes [[A0]] = {{{.*}}convergent{{.*}}}
// CHECK-DAG: attributes [[A1]] = {{{.*}}convergent{{.*}}}

14 changes: 14 additions & 0 deletions clang/test/CodeGenHLSL/builtins/wave_get_lane_index_simple.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s

// CHECK: define spir_func noundef i32 @_Z6test_1v() [[A0:#[0-9]+]] {
// CHECK: %[[CI:[0-9]+]] = call token @llvm.experimental.convergence.entry()
// CHECK: call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[CI]]) ]
// Straight-line use of the intrinsic: the call is expected to reference the
// function's convergence entry token.
uint test_1() {
return WaveGetLaneIndex();
}

// CHECK: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]

// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
// CHECK-DAG: attributes [[A1]] = { {{.*}}convergent{{.*}} }
21 changes: 21 additions & 0 deletions clang/test/CodeGenHLSL/builtins/wave_get_lane_index_subcall.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s

// CHECK: define spir_func noundef i32 @_Z6test_1v() [[A0:#[0-9]+]] {
// CHECK: %[[C1:[0-9]+]] = call token @llvm.experimental.convergence.entry()
// CHECK: call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[C1]]) ]
// Callee that uses the intrinsic directly; the call is expected to reference
// this function's own convergence entry token.
uint test_1() {
return WaveGetLaneIndex();
}

// CHECK-DAG: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]

// CHECK: define spir_func noundef i32 @_Z6test_2v() [[A0]] {
// CHECK: %[[C2:[0-9]+]] = call token @llvm.experimental.convergence.entry()
// CHECK: call spir_func noundef i32 @_Z6test_1v() [ "convergencectrl"(token %[[C2]]) ]
// Caller of test_1(): the call to test_1 is itself expected to carry a
// convergence token, taken from test_2's own entry token.
uint test_2() {
return test_1();
}

// CHECK-DAG: attributes [[A0]] = {{{.*}}convergent{{.*}}}
// CHECK-DAG: attributes [[A1]] = {{{.*}}convergent{{.*}}}
13 changes: 13 additions & 0 deletions llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,19 @@ class ConvergenceControlInst : public IntrinsicInst {
// LLVM-style RTTI hook for generic values: |V| qualifies only when it is an
// IntrinsicInst that the IntrinsicInst overload of classof accepts.
static bool classof(const Value *V) {
  if (!isa<IntrinsicInst>(V))
    return false;
  return classof(cast<IntrinsicInst>(V));
}

// Returns the convergence intrinsic referenced by |I|'s convergencectrl
// operand bundle. Returns nullptr when |I| is not a call instruction, when the
// call has no convergencectrl bundle, or when the bundle's token is not
// produced by an intrinsic.
static IntrinsicInst *getParentConvergenceToken(Instruction *I) {
  auto *CI = dyn_cast<llvm::CallInst>(I);
  if (!CI)
    return nullptr;

  auto Bundle = CI->getOperandBundle(llvm::LLVMContext::OB_convergencectrl);
  // Calls without a convergence token are legal; the lookup result is empty
  // for them and must not be dereferenced.
  if (!Bundle)
    return nullptr;

  assert(Bundle->Inputs.size() == 1 &&
         Bundle->Inputs[0]->getType()->isTokenTy());
  return dyn_cast<llvm::IntrinsicInst>(Bundle->Inputs[0].get());
}
};

} // end namespace llvm
Expand Down