Skip to content

Commit 0f77bdd

Browse files
authored
[HLSL] generate hlsl.wavesize attribute (#107176)
Generate function attribute hlsl.wavesize from [WaveSize]. For #70118
1 parent 7046a9f commit 0f77bdd

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,13 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
338338
NumThreadsAttr->getZ());
339339
Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
340340
}
341+
if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr<HLSLWaveSizeAttr>()) {
342+
const StringRef WaveSizeKindStr = "hlsl.wavesize";
343+
std::string WaveSizeStr =
344+
formatv("{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),
345+
WaveSizeAttr->getPreferred());
346+
Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);
347+
}
341348
Fn->addFnAttr(llvm::Attribute::NoInline);
342349
}
343350

clang/test/CodeGenHLSL/wavesize.hlsl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.6-compute %s -DSM66 -hlsl-entry foo \
3+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s
4+
5+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
6+
// RUN: dxil-pc-shadermodel6.8-compute %s -DNO_PREFERR -hlsl-entry foo \
7+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=NO_PREFERR
8+
9+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
10+
// RUN: dxil-pc-shadermodel6.8-compute %s -hlsl-entry foo \
11+
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=CHECK-SM68
12+
13+
14+
// Make sure wavesize attribute get correct value for sm66 and sm68.
15+
// CHECK:define void @foo()
16+
// CHECK:"hlsl.wavesize"="8,0,0"
17+
18+
// NO_PREFERR:define void @foo()
19+
// NO_PREFERR:"hlsl.wavesize"="8,128,0"
20+
21+
// CHECK-SM68:define void @foo()
22+
// CHECK-SM68:"hlsl.wavesize"="8,128,64"
23+
24+
[numthreads(16,8,1)]
25+
#ifdef SM66
26+
[WaveSize(8)]
27+
#elif NO_PREFERR
28+
[WaveSize(8, 128)]
29+
#else
30+
[WaveSize(8, 128, 64)]
31+
#endif
32+
void foo() {
33+
34+
}

0 commit comments

Comments
 (0)