Skip to content

Commit 274637d

Browse files
authored
[HLSL] Implement Append and Consume methods on Append/ConsumeStructuredBuffer (#118536)
The methods are using existing clang builtins `__builtin_hlsl_buffer_update_counter` and `__builtin_hlsl_resource_getpointer` to update the buffer counter and then load or store the value. Fixes #112968
1 parent a0eb794 commit 274637d

File tree

3 files changed

+142
-24
lines changed

3 files changed

+142
-24
lines changed

clang/lib/Sema/HLSLExternalSemaSource.cpp

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ class BuiltinTypeDeclBuilder {
246246
BuiltinTypeDeclBuilder &addDecrementCounterMethod();
247247
BuiltinTypeDeclBuilder &addHandleAccessFunction(DeclarationName &Name,
248248
bool IsConst, bool IsRef);
249+
BuiltinTypeDeclBuilder &addAppendMethod();
250+
BuiltinTypeDeclBuilder &addConsumeMethod();
249251
};
250252

251253
struct TemplateParameterListBuilder {
@@ -443,14 +445,26 @@ struct BuiltinTypeMethodBuilder {
443445
llvm::SmallVector<Stmt *> StmtsList;
444446

445447
// Argument placeholders, inspired by std::placeholder. These are the indices
446-
// of arguments to forward to `callBuiltin`, and additionally `Handle` which
447-
// refers to the resource handle.
448-
enum class PlaceHolder { _0, _1, _2, _3, Handle = 127 };
448+
// of arguments to forward to `callBuiltin` and other method builder methods.
449+
// Additional special values are:
450+
// Handle - refers to the resource handle.
451+
// LastStmt - refers to the last statement in the method body; referencing
452+
// LastStmt will remove the statement from the method body since
453+
// it will be linked from the new expression being constructed.
454+
enum class PlaceHolder { _0, _1, _2, _3, Handle = 128, LastStmt };
449455

450456
Expr *convertPlaceholder(PlaceHolder PH) {
451457
if (PH == PlaceHolder::Handle)
452458
return getResourceHandleExpr();
453459

460+
if (PH == PlaceHolder::LastStmt) {
461+
assert(!StmtsList.empty() && "no statements in the list");
462+
Stmt *LastStmt = StmtsList.pop_back_val();
463+
assert(isa<ValueStmt>(LastStmt) &&
464+
"last statement does not have a value");
465+
return cast<ValueStmt>(LastStmt)->getExprStmt();
466+
}
467+
454468
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
455469
ParmVarDecl *ParamDecl = Method->getParamDecl(static_cast<unsigned>(PH));
456470
return DeclRefExpr::Create(
@@ -573,17 +587,25 @@ struct BuiltinTypeMethodBuilder {
573587
return *this;
574588
}
575589

576-
BuiltinTypeMethodBuilder &dereference() {
577-
assert(!StmtsList.empty() && "Nothing to dereference");
578-
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
590+
template <typename TLHS, typename TRHS>
591+
BuiltinTypeMethodBuilder &assign(TLHS LHS, TRHS RHS) {
592+
Expr *LHSExpr = convertPlaceholder(LHS);
593+
Expr *RHSExpr = convertPlaceholder(RHS);
594+
Stmt *AssignStmt = BinaryOperator::Create(
595+
DeclBuilder.SemaRef.getASTContext(), LHSExpr, RHSExpr, BO_Assign,
596+
LHSExpr->getType(), ExprValueKind::VK_PRValue,
597+
ExprObjectKind::OK_Ordinary, SourceLocation(), FPOptionsOverride());
598+
StmtsList.push_back(AssignStmt);
599+
return *this;
600+
}
579601

580-
Expr *LastExpr = dyn_cast<Expr>(StmtsList.back());
581-
assert(LastExpr && "No expression to dereference");
582-
Expr *Deref = UnaryOperator::Create(
583-
AST, LastExpr, UO_Deref, LastExpr->getType()->getPointeeType(),
584-
VK_PRValue, OK_Ordinary, SourceLocation(),
585-
/*CanOverflow=*/false, FPOptionsOverride());
586-
StmtsList.pop_back();
602+
template <typename T> BuiltinTypeMethodBuilder &dereference(T Ptr) {
603+
Expr *PtrExpr = convertPlaceholder(Ptr);
604+
Expr *Deref =
605+
UnaryOperator::Create(DeclBuilder.SemaRef.getASTContext(), PtrExpr,
606+
UO_Deref, PtrExpr->getType()->getPointeeType(),
607+
VK_PRValue, OK_Ordinary, SourceLocation(),
608+
/*CanOverflow=*/false, FPOptionsOverride());
587609
StmtsList.push_back(Deref);
588610
return *this;
589611
}
@@ -685,7 +707,35 @@ BuiltinTypeDeclBuilder::addHandleAccessFunction(DeclarationName &Name,
685707
.addParam("Index", AST.UnsignedIntTy)
686708
.callBuiltin("__builtin_hlsl_resource_getpointer", ElemPtrTy, PH::Handle,
687709
PH::_0)
688-
.dereference()
710+
.dereference(PH::LastStmt)
711+
.finalizeMethod();
712+
}
713+
714+
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() {
715+
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
716+
ASTContext &AST = SemaRef.getASTContext();
717+
QualType ElemTy = getHandleElementType();
718+
return BuiltinTypeMethodBuilder(*this, "Append", AST.VoidTy)
719+
.addParam("value", ElemTy)
720+
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
721+
PH::Handle, getConstantIntExpr(1))
722+
.callBuiltin("__builtin_hlsl_resource_getpointer",
723+
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
724+
.dereference(PH::LastStmt)
725+
.assign(PH::LastStmt, PH::_0)
726+
.finalizeMethod();
727+
}
728+
729+
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addConsumeMethod() {
730+
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
731+
ASTContext &AST = SemaRef.getASTContext();
732+
QualType ElemTy = getHandleElementType();
733+
return BuiltinTypeMethodBuilder(*this, "Consume", ElemTy)
734+
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
735+
PH::Handle, getConstantIntExpr(-1))
736+
.callBuiltin("__builtin_hlsl_resource_getpointer",
737+
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
738+
.dereference(PH::LastStmt)
689739
.finalizeMethod();
690740
}
691741

@@ -915,6 +965,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
915965
onCompletion(Decl, [this](CXXRecordDecl *Decl) {
916966
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer,
917967
/*IsROV=*/false, /*RawBuffer=*/true)
968+
.addAppendMethod()
918969
.completeDefinition();
919970
});
920971

@@ -925,6 +976,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
925976
onCompletion(Decl, [this](CXXRecordDecl *Decl) {
926977
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer,
927978
/*IsROV=*/false, /*RawBuffer=*/true)
979+
.addConsumeMethod()
928980
.completeDefinition();
929981
});
930982

clang/test/AST/HLSL/StructuredBuffers-AST.hlsl

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
//
2121
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \
2222
// RUN: -DRESOURCE=AppendStructuredBuffer %s | FileCheck -DRESOURCE=AppendStructuredBuffer \
23-
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s
23+
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-APPEND %s
2424
//
2525
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \
2626
// RUN: -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \
2727
// RUN: -check-prefix=EMPTY %s
2828
//
2929
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \
3030
// RUN: -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \
31-
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s
31+
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-CONSUME %s
3232
//
3333
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \
3434
// RUN: -DRESOURCE=RasterizerOrderedStructuredBuffer %s | FileCheck -DRESOURCE=RasterizerOrderedStructuredBuffer \
@@ -135,6 +135,48 @@ RESOURCE<float> Buffer;
135135
// CHECK-COUNTER-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' -1
136136
// CHECK-COUNTER-NEXT: AlwaysInlineAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit always_inline
137137

138+
// CHECK-APPEND: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Append 'void (element_type)'
139+
// CHECK-APPEND-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> value 'element_type'
140+
// CHECK-APPEND-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
141+
// CHECK-APPEND-NEXT: BinaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' '='
142+
// CHECK-APPEND-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow
143+
// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *'
144+
// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept'
145+
// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
146+
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
147+
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]]
148+
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
149+
// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
150+
// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int'
151+
// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) noexcept'
152+
// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
153+
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
154+
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]]
155+
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
156+
// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
157+
// CHECK-APPEND-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' 1
158+
// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' ParmVar 0x{{[0-9A-Fa-f]+}} 'value' 'element_type'
159+
160+
// CHECK-CONSUME: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Consume 'element_type ()'
161+
// CHECK-CONSUME-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
162+
// CHECK-CONSUME-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
163+
// CHECK-CONSUME-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow
164+
// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *'
165+
// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept'
166+
// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
167+
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
168+
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]]
169+
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
170+
// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
171+
// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int'
172+
// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) noexcept'
173+
// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
174+
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
175+
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]]
176+
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
177+
// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
178+
// CHECK-CONSUME-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' -1
179+
138180
// CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> class [[RESOURCE]] definition
139181

140182
// CHECK: TemplateArgument type 'float'

clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,45 @@
55

66
RWStructuredBuffer<float> RWSB1 : register(u0);
77
RWStructuredBuffer<float> RWSB2 : register(u1);
8+
AppendStructuredBuffer<float> ASB : register(u2);
9+
ConsumeStructuredBuffer<float> CSB : register(u3);
810

911
// CHECK: %"class.hlsl::RWStructuredBuffer" = type { target("dx.RawBuffer", float, 1, 0) }
1012

11-
export void TestIncrementCounter() {
12-
RWSB1.IncrementCounter();
13+
export int TestIncrementCounter() {
14+
return RWSB1.IncrementCounter();
1315
}
1416

15-
// CHECK: define void @_Z20TestIncrementCounterv()
16-
// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
17+
// CHECK: define noundef i32 @_Z20TestIncrementCounterv()
18+
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
19+
// CHECK-DXIL: ret i32 %[[INDEX]]
20+
export int TestDecrementCounter() {
21+
return RWSB2.DecrementCounter();
22+
}
23+
24+
// CHECK: define noundef i32 @_Z20TestDecrementCounterv()
25+
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1)
26+
// CHECK-DXIL: ret i32 %[[INDEX]]
27+
28+
export void TestAppend(float value) {
29+
ASB.Append(value);
30+
}
31+
32+
// CHECK: define void @_Z10TestAppendf(float noundef %value)
33+
// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %value.addr, align 4
34+
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
35+
// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i32 %[[INDEX]])
36+
// CHECK-DXIL: store float %[[VALUE]], ptr %[[RESPTR]], align 4
1737

18-
export void TestDecrementCounter() {
19-
RWSB2.DecrementCounter();
38+
export float TestConsume() {
39+
return CSB.Consume();
2040
}
2141

22-
// CHECK: define void @_Z20TestDecrementCounterv()
23-
// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1)
42+
// CHECK: define noundef float @_Z11TestConsumev()
43+
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %1, i8 -1)
44+
// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %0, i32 %[[INDEX]])
45+
// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %[[RESPTR]], align 4
46+
// CHECK-DXIL: ret float %[[VALUE]]
2447

2548
// CHECK: declare i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i8)
49+
// CHECK: declare ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i32)

0 commit comments

Comments
 (0)