Skip to content

Commit bd92e46

Browse files
authored
[HLSL] Implement RWBuffer::operator[] via __builtin_hlsl_resource_getpointer (#117017)
This introduces `__builtin_hlsl_resource_getpointer`, which lowers to `llvm.dx.resource.getpointer` and is used to implement indexing into resources. This will only work through the backend for typed buffers at this point, but the changes to structured buffers should be correct as far as the frontend is concerned. Note: We probably want this to return a reference in the HLSL device address space, but for now we're just using address space 0. Creating a device address space and updating this code can be done later as necessary. Fixes #95956
1 parent 1250a1d commit bd92e46

22 files changed

+297
-231
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4738,6 +4738,12 @@ def GetDeviceSideMangledName : LangBuiltin<"CUDA_LANG"> {
47384738
}
47394739

47404740
// HLSL
4741+
def HLSLResourceGetPointer : LangBuiltin<"HLSL_LANG"> {
4742+
let Spellings = ["__builtin_hlsl_resource_getpointer"];
4743+
let Attributes = [NoThrow];
4744+
let Prototype = "void(...)";
4745+
}
4746+
47414747
def HLSLAll : LangBuiltin<"HLSL_LANG"> {
47424748
let Spellings = ["__builtin_hlsl_all"];
47434749
let Attributes = [NoThrow, Const];

clang/lib/AST/Type.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4723,7 +4723,9 @@ LinkageInfo LinkageComputer::computeTypeLinkageInfo(const Type *T) {
47234723
case Type::Pipe:
47244724
return computeTypeLinkageInfo(cast<PipeType>(T)->getElementType());
47254725
case Type::HLSLAttributedResource:
4726-
llvm_unreachable("not yet implemented");
4726+
return computeTypeLinkageInfo(cast<HLSLAttributedResourceType>(T)
4727+
->getContainedType()
4728+
->getCanonicalTypeInternal());
47274729
}
47284730

47294731
llvm_unreachable("unhandled type class");

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19060,6 +19060,16 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1906019060
return nullptr;
1906119061

1906219062
switch (BuiltinID) {
19063+
case Builtin::BI__builtin_hlsl_resource_getpointer: {
19064+
Value *HandleOp = EmitScalarExpr(E->getArg(0));
19065+
Value *IndexOp = EmitScalarExpr(E->getArg(1));
19066+
19067+
// TODO: Map to an hlsl_device address space.
19068+
llvm::Type *RetTy = llvm::PointerType::getUnqual(getLLVMContext());
19069+
19070+
return Builder.CreateIntrinsic(RetTy, Intrinsic::dx_resource_getpointer,
19071+
ArrayRef<Value *>{HandleOp, IndexOp});
19072+
}
1906319073
case Builtin::BI__builtin_hlsl_all: {
1906419074
Value *Op0 = EmitScalarExpr(E->getArg(0));
1906519075
return Builder.CreateIntrinsic(

clang/lib/Sema/HLSLExternalSemaSource.cpp

Lines changed: 87 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,8 @@ struct BuiltinTypeDeclBuilder {
123123
assert(!Record->isCompleteDefinition() && "record is already complete");
124124

125125
ASTContext &Ctx = SemaRef.getASTContext();
126-
TypeSourceInfo *ElementTypeInfo = nullptr;
127-
128-
QualType ElemTy = Ctx.Char8Ty;
129-
if (Template)
130-
ElemTy = getFirstTemplateTypeParam();
131-
ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
126+
TypeSourceInfo *ElementTypeInfo =
127+
Ctx.getTrivialTypeSourceInfo(getHandleElementType(), SourceLocation());
132128

133129
// add handle member with resource type attributes
134130
QualType AttributedResTy = QualType();
@@ -171,80 +167,12 @@ struct BuiltinTypeDeclBuilder {
171167
}
172168

173169
BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
174-
addArraySubscriptOperator(true);
175-
addArraySubscriptOperator(false);
176-
return *this;
177-
}
178-
179-
BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
180-
assert(!Record->isCompleteDefinition() && "record is already complete");
181-
182170
ASTContext &AST = Record->getASTContext();
183-
QualType ElemTy = AST.Char8Ty;
184-
if (Template)
185-
ElemTy = getFirstTemplateTypeParam();
186-
QualType ReturnTy = ElemTy;
187-
188-
FunctionProtoType::ExtProtoInfo ExtInfo;
189-
190-
// Subscript operators return references to elements, const makes the
191-
// reference and method const so that the underlying data is not mutable.
192-
if (IsConst) {
193-
ExtInfo.TypeQuals.addConst();
194-
ReturnTy.addConst();
195-
}
196-
ReturnTy = AST.getLValueReferenceType(ReturnTy);
197-
198-
QualType MethodTy =
199-
AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
200-
auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
201-
auto *MethodDecl = CXXMethodDecl::Create(
202-
AST, Record, SourceLocation(),
203-
DeclarationNameInfo(
204-
AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
205-
SourceLocation()),
206-
MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
207-
SourceLocation());
208-
209-
IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
210-
auto *IdxParam = ParmVarDecl::Create(
211-
AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
212-
&II, AST.UnsignedIntTy,
213-
AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
214-
SC_None, nullptr);
215-
MethodDecl->setParams({IdxParam});
216-
217-
// Also add the parameter to the function prototype.
218-
auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
219-
FnProtoLoc.setParam(0, IdxParam);
220-
221-
// FIXME: Placeholder to make sure we return the correct type - create
222-
// field of element_type and return reference to it. This field will go
223-
// away once indexing into resources is properly implemented in
224-
// llvm/llvm-project#95956.
225-
if (Fields.count("e") == 0) {
226-
addMemberVariable("e", ElemTy, {});
227-
}
228-
FieldDecl *ElemFieldDecl = Fields["e"];
229-
230-
auto *This =
231-
CXXThisExpr::Create(AST, SourceLocation(),
232-
MethodDecl->getFunctionObjectParameterType(), true);
233-
Expr *ElemField = MemberExpr::CreateImplicit(
234-
AST, This, false, ElemFieldDecl, ElemFieldDecl->getType(), VK_LValue,
235-
OK_Ordinary);
236-
auto *Return =
237-
ReturnStmt::Create(AST, SourceLocation(), ElemField, nullptr);
238-
239-
MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
240-
SourceLocation(),
241-
SourceLocation()));
242-
MethodDecl->setLexicalDeclContext(Record);
243-
MethodDecl->setAccess(AccessSpecifier::AS_public);
244-
MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
245-
AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
246-
Record->addDecl(MethodDecl);
171+
DeclarationName Subscript =
172+
AST.DeclarationNames.getCXXOperatorName(OO_Subscript);
247173

174+
addHandleAccessFunction(Subscript, /*IsConst=*/true, /*IsRef=*/true);
175+
addHandleAccessFunction(Subscript, /*IsConst=*/false, /*IsRef=*/true);
248176
return *this;
249177
}
250178

@@ -265,6 +193,13 @@ struct BuiltinTypeDeclBuilder {
265193
return QualType();
266194
}
267195

196+
QualType getHandleElementType() {
197+
if (Template)
198+
return getFirstTemplateTypeParam();
199+
// TODO: Should we default to VoidTy? Using `i8` is arguably ambiguous.
200+
return SemaRef.getASTContext().Char8Ty;
201+
}
202+
268203
BuiltinTypeDeclBuilder &startDefinition() {
269204
assert(!Record->isCompleteDefinition() && "record is already complete");
270205
Record->startDefinition();
@@ -294,6 +229,8 @@ struct BuiltinTypeDeclBuilder {
294229
// Builtin types methods
295230
BuiltinTypeDeclBuilder &addIncrementCounterMethod();
296231
BuiltinTypeDeclBuilder &addDecrementCounterMethod();
232+
BuiltinTypeDeclBuilder &addHandleAccessFunction(DeclarationName &Name,
233+
bool IsConst, bool IsRef);
297234
};
298235

299236
struct TemplateParameterListBuilder {
@@ -453,7 +390,7 @@ struct TemplateParameterListBuilder {
453390
// Builder for methods of builtin types. Allows adding methods to builtin types
454391
// using the builder pattern like this:
455392
//
456-
// BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
393+
// BuiltinTypeMethodBuilder(RecordBuilder, "MethodName", ReturnType)
457394
// .addParam("param_name", Type, InOutModifier)
458395
// .callBuiltin("builtin_name", BuiltinParams...)
459396
// .finalizeMethod();
@@ -486,6 +423,7 @@ struct BuiltinTypeMethodBuilder {
486423
DeclarationNameInfo NameInfo;
487424
QualType ReturnTy;
488425
CXXMethodDecl *Method;
426+
bool IsConst;
489427
llvm::SmallVector<MethodParam> Params;
490428
llvm::SmallVector<Stmt *> StmtsList;
491429

@@ -508,11 +446,16 @@ struct BuiltinTypeMethodBuilder {
508446
Expr *convertPlaceholder(Expr *E) { return E; }
509447

510448
public:
511-
BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
512-
QualType ReturnTy)
513-
: DeclBuilder(DB), ReturnTy(ReturnTy), Method(nullptr) {
449+
BuiltinTypeMethodBuilder(BuiltinTypeDeclBuilder &DB, DeclarationName &Name,
450+
QualType ReturnTy, bool IsConst = false)
451+
: DeclBuilder(DB), NameInfo(DeclarationNameInfo(Name, SourceLocation())),
452+
ReturnTy(ReturnTy), Method(nullptr), IsConst(IsConst) {}
453+
454+
BuiltinTypeMethodBuilder(BuiltinTypeDeclBuilder &DB, StringRef Name,
455+
QualType ReturnTy, bool IsConst = false)
456+
: DeclBuilder(DB), ReturnTy(ReturnTy), Method(nullptr), IsConst(IsConst) {
514457
const IdentifierInfo &II =
515-
S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
458+
DB.SemaRef.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
516459
NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
517460
}
518461

@@ -535,8 +478,12 @@ struct BuiltinTypeMethodBuilder {
535478
SmallVector<QualType> ParamTypes;
536479
for (MethodParam &MP : Params)
537480
ParamTypes.emplace_back(MP.Ty);
538-
QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
539-
FunctionProtoType::ExtProtoInfo());
481+
482+
FunctionProtoType::ExtProtoInfo ExtInfo;
483+
if (IsConst)
484+
ExtInfo.TypeQuals.addConst();
485+
486+
QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes, ExtInfo);
540487

541488
// create method decl
542489
auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
@@ -586,7 +533,8 @@ struct BuiltinTypeMethodBuilder {
586533
}
587534

588535
template <typename... Ts>
589-
BuiltinTypeMethodBuilder &callBuiltin(StringRef BuiltinName, Ts... ArgSpecs) {
536+
BuiltinTypeMethodBuilder &callBuiltin(StringRef BuiltinName,
537+
QualType ReturnType, Ts... ArgSpecs) {
590538
std::array<Expr *, sizeof...(ArgSpecs)> Args{
591539
convertPlaceholder(std::forward<Ts>(ArgSpecs))...};
592540

@@ -599,15 +547,32 @@ struct BuiltinTypeMethodBuilder {
599547
FunctionDecl *FD = lookupBuiltinFunction(DeclBuilder.SemaRef, BuiltinName);
600548
DeclRefExpr *DRE = DeclRefExpr::Create(
601549
AST, NestedNameSpecifierLoc(), SourceLocation(), FD, false,
602-
FD->getNameInfo(), FD->getType(), VK_PRValue);
550+
FD->getNameInfo(), AST.BuiltinFnTy, VK_PRValue);
551+
552+
if (ReturnType.isNull())
553+
ReturnType = FD->getReturnType();
603554

604-
Expr *Call =
605-
CallExpr::Create(AST, DRE, Args, FD->getReturnType(), VK_PRValue,
606-
SourceLocation(), FPOptionsOverride());
555+
Expr *Call = CallExpr::Create(AST, DRE, Args, ReturnType, VK_PRValue,
556+
SourceLocation(), FPOptionsOverride());
607557
StmtsList.push_back(Call);
608558
return *this;
609559
}
610560

561+
BuiltinTypeMethodBuilder &dereference() {
562+
assert(!StmtsList.empty() && "Nothing to dereference");
563+
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
564+
565+
Expr *LastExpr = dyn_cast<Expr>(StmtsList.back());
566+
assert(LastExpr && "No expression to dereference");
567+
Expr *Deref = UnaryOperator::Create(
568+
AST, LastExpr, UO_Deref, LastExpr->getType()->getPointeeType(),
569+
VK_PRValue, OK_Ordinary, SourceLocation(),
570+
/*CanOverflow=*/false, FPOptionsOverride());
571+
StmtsList.pop_back();
572+
StmtsList.push_back(Deref);
573+
return *this;
574+
}
575+
611576
BuiltinTypeDeclBuilder &finalizeMethod() {
612577
assert(!DeclBuilder.Record->isCompleteDefinition() &&
613578
"record is already complete");
@@ -621,11 +586,8 @@ struct BuiltinTypeMethodBuilder {
621586
"nothing to return from non-void method");
622587
if (ReturnTy != AST.VoidTy) {
623588
if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
624-
assert(AST.hasSameUnqualifiedType(
625-
isa<CallExpr>(LastExpr)
626-
? cast<CallExpr>(LastExpr)->getCallReturnType(AST)
627-
: LastExpr->getType(),
628-
ReturnTy) &&
589+
assert(AST.hasSameUnqualifiedType(LastExpr->getType(),
590+
ReturnTy.getNonReferenceType()) &&
629591
"Return type of the last statement must match the return type "
630592
"of the method");
631593
if (!isa<ReturnStmt>(LastExpr)) {
@@ -672,19 +634,43 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(ArrayRef<StringRef> Names,
672634

673635
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addIncrementCounterMethod() {
674636
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
675-
return BuiltinTypeMethodBuilder(SemaRef, *this, "IncrementCounter",
637+
return BuiltinTypeMethodBuilder(*this, "IncrementCounter",
676638
SemaRef.getASTContext().UnsignedIntTy)
677-
.callBuiltin("__builtin_hlsl_buffer_update_counter", PH::Handle,
678-
getConstantIntExpr(1))
639+
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
640+
PH::Handle, getConstantIntExpr(1))
679641
.finalizeMethod();
680642
}
681643

682644
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDecrementCounterMethod() {
683645
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
684-
return BuiltinTypeMethodBuilder(SemaRef, *this, "DecrementCounter",
646+
return BuiltinTypeMethodBuilder(*this, "DecrementCounter",
685647
SemaRef.getASTContext().UnsignedIntTy)
686-
.callBuiltin("__builtin_hlsl_buffer_update_counter", PH::Handle,
687-
getConstantIntExpr(-1))
648+
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
649+
PH::Handle, getConstantIntExpr(-1))
650+
.finalizeMethod();
651+
}
652+
653+
BuiltinTypeDeclBuilder &
654+
BuiltinTypeDeclBuilder::addHandleAccessFunction(DeclarationName &Name,
655+
bool IsConst, bool IsRef) {
656+
assert(!Record->isCompleteDefinition() && "record is already complete");
657+
ASTContext &AST = SemaRef.getASTContext();
658+
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
659+
660+
QualType ElemTy = getHandleElementType();
661+
// TODO: Map to an hlsl_device address space.
662+
QualType ElemPtrTy = AST.getPointerType(ElemTy);
663+
QualType ReturnTy = ElemTy;
664+
if (IsConst)
665+
ReturnTy.addConst();
666+
if (IsRef)
667+
ReturnTy = AST.getLValueReferenceType(ReturnTy);
668+
669+
return BuiltinTypeMethodBuilder(*this, Name, ReturnTy, IsConst)
670+
.addParam("Index", AST.UnsignedIntTy)
671+
.callBuiltin("__builtin_hlsl_resource_getpointer", ElemPtrTy, PH::Handle,
672+
PH::_0)
673+
.dereference()
688674
.finalizeMethod();
689675
}
690676

clang/lib/Sema/SemaExpr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,9 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
980980
if (Ty->isObjCObjectType())
981981
return VAK_Invalid;
982982

983+
if (getLangOpts().HLSL && Ty->getAs<HLSLAttributedResourceType>())
984+
return VAK_Valid;
985+
983986
if (getLangOpts().MSVCCompat)
984987
return VAK_MSVCUndefined;
985988

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1908,7 +1908,7 @@ static bool CheckResourceHandle(
19081908
const HLSLAttributedResourceType *ResTy =
19091909
ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
19101910
if (!ResTy) {
1911-
S->Diag(TheCall->getArg(0)->getBeginLoc(),
1911+
S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
19121912
diag::err_typecheck_expect_hlsl_resource)
19131913
<< ArgType;
19141914
return true;
@@ -1926,6 +1926,22 @@ static bool CheckResourceHandle(
19261926
// returning an ExprError
19271927
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
19281928
switch (BuiltinID) {
1929+
case Builtin::BI__builtin_hlsl_resource_getpointer: {
1930+
if (SemaRef.checkArgCount(TheCall, 2) ||
1931+
CheckResourceHandle(&SemaRef, TheCall, 0) ||
1932+
CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
1933+
SemaRef.getASTContext().UnsignedIntTy))
1934+
return true;
1935+
1936+
auto *ResourceTy =
1937+
TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
1938+
QualType ContainedTy = ResourceTy->getContainedType();
1939+
// TODO: Map to an hlsl_device address space.
1940+
TheCall->setType(getASTContext().getPointerType(ContainedTy));
1941+
TheCall->setValueKind(VK_LValue);
1942+
1943+
break;
1944+
}
19291945
case Builtin::BI__builtin_hlsl_all:
19301946
case Builtin::BI__builtin_hlsl_any: {
19311947
if (SemaRef.checkArgCount(TheCall, 1))

0 commit comments

Comments
 (0)