Skip to content

Commit 08964d6

Browse files
authored
[HLSL][SPIRV] Handle uint type for spec constant (#145577)
The testing only tried `unsigned int` and not `uint`. We want to correctly handle these surgared types as specialization constants.
1 parent 4233ca1 commit 08964d6

File tree

3 files changed

+47
-17
lines changed

3 files changed

+47
-17
lines changed

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ static ResourceClass getResourceClass(RegisterType RT) {
120120
llvm_unreachable("unexpected RegisterType value");
121121
}
122122

123-
static Builtin::ID getSpecConstBuiltinId(QualType Type) {
123+
static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
124124
const auto *BT = dyn_cast<BuiltinType>(Type);
125125
if (!BT) {
126126
if (!Type->isEnumeralType())
@@ -654,7 +654,8 @@ SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
654654

655655
auto *VD = cast<VarDecl>(D);
656656

657-
if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
657+
if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
658+
Builtin::NotBuiltin) {
658659
Diag(VD->getLocation(), diag::err_specialization_const);
659660
return nullptr;
660661
}
@@ -3972,7 +3973,8 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
39723973
return false;
39733974
}
39743975

3975-
Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
3976+
Builtin::ID BID =
3977+
getSpecConstBuiltinId(VDecl->getType()->getUnqualifiedDesugaredType());
39763978

39773979
// Argument 1: The ID from the attribute
39783980
int ConstantID = ConstIdAttr->getId();

clang/test/AST/HLSL/vk.spec-constant.usage.hlsl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ const unsigned short ushort_const = 10;
6464
[[vk::constant_id(6)]]
6565
const unsigned int uint_const = 12;
6666

67+
// CHECK: VarDecl {{.*}} uint_const_2 'const hlsl_private uint':'const hlsl_private unsigned int' static cinit
68+
// CHECK-NEXT: CallExpr {{.*}} 'unsigned int'
69+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int (*)(unsigned int, unsigned int) noexcept' <FunctionToPointerDecay>
70+
// CHECK-NEXT: DeclRefExpr {{.*}} 'unsigned int (unsigned int, unsigned int) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_uint' 'unsigned int (unsigned int, unsigned int) noexcept'
71+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
72+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 6
73+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
74+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 12
75+
[[vk::constant_id(6)]]
76+
const uint uint_const_2 = 12;
77+
6778

6879
// CHECK: VarDecl {{.*}} ulong_const 'const hlsl_private unsigned long long' static cinit
6980
// CHECK-NEXT: CallExpr {{.*}} 'unsigned long long'

clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ const unsigned short ushort_const = 10;
2121
[[vk::constant_id(6)]]
2222
const unsigned int uint_const = 12;
2323

24+
[[vk::constant_id(6)]]
25+
const uint uint_const_2 = 12;
26+
2427
[[vk::constant_id(7)]]
2528
const unsigned long long ulong_const = 25;
2629

@@ -50,6 +53,7 @@ void main() {
5053
long long l = long_const;
5154
unsigned short us = ushort_const;
5255
unsigned int ui = uint_const;
56+
uint ui2 = uint_const_2;
5357
unsigned long long ul = ulong_const;
5458
half h = half_const;
5559
float f = float_const;
@@ -63,6 +67,7 @@ void main() {
6367
// CHECK: @_ZL10long_const = internal addrspace(10) global i64 0, align 8
6468
// CHECK: @_ZL12ushort_const = internal addrspace(10) global i16 0, align 2
6569
// CHECK: @_ZL10uint_const = internal addrspace(10) global i32 0, align 4
70+
// CHECK: @_ZL12uint_const_2 = internal addrspace(10) global i32 0, align 4
6671
// CHECK: @_ZL11ulong_const = internal addrspace(10) global i64 0, align 8
6772
// CHECK: @_ZL10half_const = internal addrspace(10) global float 0.000000e+00, align 4
6873
// CHECK: @_ZL11float_const = internal addrspace(10) global float 0.000000e+00, align 4
@@ -79,6 +84,7 @@ void main() {
7984
// CHECK-NEXT: [[L:%.*]] = alloca i64, align 8
8085
// CHECK-NEXT: [[US:%.*]] = alloca i16, align 2
8186
// CHECK-NEXT: [[UI:%.*]] = alloca i32, align 4
87+
// CHECK-NEXT: [[UI2:%.*]] = alloca i32, align 4
8288
// CHECK-NEXT: [[UL:%.*]] = alloca i64, align 8
8389
// CHECK-NEXT: [[H:%.*]] = alloca float, align 4
8490
// CHECK-NEXT: [[F:%.*]] = alloca float, align 4
@@ -98,16 +104,18 @@ void main() {
98104
// CHECK-NEXT: store i16 [[TMP5]], ptr [[US]], align 2
99105
// CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(10) @_ZL10uint_const, align 4
100106
// CHECK-NEXT: store i32 [[TMP6]], ptr [[UI]], align 4
101-
// CHECK-NEXT: [[TMP7:%.*]] = load i64, ptr addrspace(10) @_ZL11ulong_const, align 8
102-
// CHECK-NEXT: store i64 [[TMP7]], ptr [[UL]], align 8
103-
// CHECK-NEXT: [[TMP8:%.*]] = load float, ptr addrspace(10) @_ZL10half_const, align 4
104-
// CHECK-NEXT: store float [[TMP8]], ptr [[H]], align 4
105-
// CHECK-NEXT: [[TMP9:%.*]] = load float, ptr addrspace(10) @_ZL11float_const, align 4
106-
// CHECK-NEXT: store float [[TMP9]], ptr [[F]], align 4
107-
// CHECK-NEXT: [[TMP10:%.*]] = load double, ptr addrspace(10) @_ZL12double_const, align 8
108-
// CHECK-NEXT: store double [[TMP10]], ptr [[D]], align 8
109-
// CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr addrspace(10) @_ZL10enum_const, align 4
110-
// CHECK-NEXT: store i32 [[TMP11]], ptr [[E]], align 4
107+
// CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr addrspace(10) @_ZL12uint_const_2, align 4
108+
// CHECK-NEXT: store i32 [[TMP7]], ptr [[UI2]], align 4
109+
// CHECK-NEXT: [[TMP8:%.*]] = load i64, ptr addrspace(10) @_ZL11ulong_const, align 8
110+
// CHECK-NEXT: store i64 [[TMP8]], ptr [[UL]], align 8
111+
// CHECK-NEXT: [[TMP9:%.*]] = load float, ptr addrspace(10) @_ZL10half_const, align 4
112+
// CHECK-NEXT: store float [[TMP9]], ptr [[H]], align 4
113+
// CHECK-NEXT: [[TMP10:%.*]] = load float, ptr addrspace(10) @_ZL11float_const, align 4
114+
// CHECK-NEXT: store float [[TMP10]], ptr [[F]], align 4
115+
// CHECK-NEXT: [[TMP11:%.*]] = load double, ptr addrspace(10) @_ZL12double_const, align 8
116+
// CHECK-NEXT: store double [[TMP11]], ptr [[D]], align 8
117+
// CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(10) @_ZL10enum_const, align 4
118+
// CHECK-NEXT: store i32 [[TMP12]], ptr [[E]], align 4
111119
// CHECK-NEXT: ret void
112120
//
113121
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init(
@@ -169,12 +177,21 @@ void main() {
169177
// CHECK-SAME: ) #[[ATTR3]] {
170178
// CHECK-NEXT: [[ENTRY:.*:]]
171179
// CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
180+
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @_Z20__spirv_SpecConstantij(i32 6, i32 12)
181+
// CHECK-NEXT: store i32 [[TMP1]], ptr addrspace(10) @_ZL12uint_const_2, align 4
182+
// CHECK-NEXT: ret void
183+
//
184+
//
185+
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.7(
186+
// CHECK-SAME: ) #[[ATTR3]] {
187+
// CHECK-NEXT: [[ENTRY:.*:]]
188+
// CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
172189
// CHECK-NEXT: [[TMP1:%.*]] = call i64 @_Z20__spirv_SpecConstantiy(i32 7, i64 25)
173190
// CHECK-NEXT: store i64 [[TMP1]], ptr addrspace(10) @_ZL11ulong_const, align 8
174191
// CHECK-NEXT: ret void
175192
//
176193
//
177-
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.7(
194+
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.8(
178195
// CHECK-SAME: ) #[[ATTR3]] {
179196
// CHECK-NEXT: [[ENTRY:.*:]]
180197
// CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
@@ -183,7 +200,7 @@ void main() {
183200
// CHECK-NEXT: ret void
184201
//
185202
//
186-
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.8(
203+
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.9(
187204
// CHECK-SAME: ) #[[ATTR3]] {
188205
// CHECK-NEXT: [[ENTRY:.*:]]
189206
// CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
@@ -192,7 +209,7 @@ void main() {
192209
// CHECK-NEXT: ret void
193210
//
194211
//
195-
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.9(
212+
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.10(
196213
// CHECK-SAME: ) #[[ATTR3]] {
197214
// CHECK-NEXT: [[ENTRY:.*:]]
198215
// CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
@@ -201,7 +218,7 @@ void main() {
201218
// CHECK-NEXT: ret void
202219
//
203220
//
204-
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.10(
221+
// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.11(
205222
// CHECK-SAME: ) #[[ATTR3]] {
206223
// CHECK-NEXT: [[ENTRY:.*:]]
207224
// CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()

0 commit comments

Comments
 (0)