Skip to content

[HLSL] Implement RWBuffer::operator[] via __builtin_hlsl_resource_getpointer #117017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4738,6 +4738,12 @@ def GetDeviceSideMangledName : LangBuiltin<"CUDA_LANG"> {
}

// HLSL
def HLSLResourceGetPointer : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_resource_getpointer"];
let Attributes = [NoThrow];
let Prototype = "void(...)";
}

def HLSLAll : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_all"];
let Attributes = [NoThrow, Const];
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4723,7 +4723,9 @@ LinkageInfo LinkageComputer::computeTypeLinkageInfo(const Type *T) {
case Type::Pipe:
return computeTypeLinkageInfo(cast<PipeType>(T)->getElementType());
case Type::HLSLAttributedResource:
llvm_unreachable("not yet implemented");
return computeTypeLinkageInfo(cast<HLSLAttributedResourceType>(T)
->getContainedType()
->getCanonicalTypeInternal());
}

llvm_unreachable("unhandled type class");
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19044,6 +19044,16 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return nullptr;

switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_resource_getpointer: {
Value *HandleOp = EmitScalarExpr(E->getArg(0));
Value *IndexOp = EmitScalarExpr(E->getArg(1));

// TODO: Map to an hlsl_device address space.
llvm::Type *RetTy = llvm::PointerType::getUnqual(getLLVMContext());

return Builder.CreateIntrinsic(RetTy, Intrinsic::dx_resource_getpointer,
ArrayRef<Value *>{HandleOp, IndexOp});
}
case Builtin::BI__builtin_hlsl_all: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
return Builder.CreateIntrinsic(
Expand Down
188 changes: 87 additions & 101 deletions clang/lib/Sema/HLSLExternalSemaSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,8 @@ struct BuiltinTypeDeclBuilder {
assert(!Record->isCompleteDefinition() && "record is already complete");

ASTContext &Ctx = SemaRef.getASTContext();
TypeSourceInfo *ElementTypeInfo = nullptr;

QualType ElemTy = Ctx.Char8Ty;
if (Template)
ElemTy = getFirstTemplateTypeParam();
ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
TypeSourceInfo *ElementTypeInfo =
Ctx.getTrivialTypeSourceInfo(getHandleElementType(), SourceLocation());

// add handle member with resource type attributes
QualType AttributedResTy = QualType();
Expand Down Expand Up @@ -171,80 +167,12 @@ struct BuiltinTypeDeclBuilder {
}

BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
addArraySubscriptOperator(true);
addArraySubscriptOperator(false);
return *this;
}

BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
assert(!Record->isCompleteDefinition() && "record is already complete");

ASTContext &AST = Record->getASTContext();
QualType ElemTy = AST.Char8Ty;
if (Template)
ElemTy = getFirstTemplateTypeParam();
QualType ReturnTy = ElemTy;

FunctionProtoType::ExtProtoInfo ExtInfo;

// Subscript operators return references to elements, const makes the
// reference and method const so that the underlying data is not mutable.
if (IsConst) {
ExtInfo.TypeQuals.addConst();
ReturnTy.addConst();
}
ReturnTy = AST.getLValueReferenceType(ReturnTy);

QualType MethodTy =
AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
auto *MethodDecl = CXXMethodDecl::Create(
AST, Record, SourceLocation(),
DeclarationNameInfo(
AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
SourceLocation()),
MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
SourceLocation());

IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
auto *IdxParam = ParmVarDecl::Create(
AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
&II, AST.UnsignedIntTy,
AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
SC_None, nullptr);
MethodDecl->setParams({IdxParam});

// Also add the parameter to the function prototype.
auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
FnProtoLoc.setParam(0, IdxParam);

// FIXME: Placeholder to make sure we return the correct type - create
// field of element_type and return reference to it. This field will go
// away once indexing into resources is properly implemented in
// llvm/llvm-project#95956.
if (Fields.count("e") == 0) {
addMemberVariable("e", ElemTy, {});
}
FieldDecl *ElemFieldDecl = Fields["e"];

auto *This =
CXXThisExpr::Create(AST, SourceLocation(),
MethodDecl->getFunctionObjectParameterType(), true);
Expr *ElemField = MemberExpr::CreateImplicit(
AST, This, false, ElemFieldDecl, ElemFieldDecl->getType(), VK_LValue,
OK_Ordinary);
auto *Return =
ReturnStmt::Create(AST, SourceLocation(), ElemField, nullptr);

MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
SourceLocation(),
SourceLocation()));
MethodDecl->setLexicalDeclContext(Record);
MethodDecl->setAccess(AccessSpecifier::AS_public);
MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
Record->addDecl(MethodDecl);
DeclarationName Subscript =
AST.DeclarationNames.getCXXOperatorName(OO_Subscript);

addHandleAccessFunction(Subscript, /*IsConst=*/true, /*IsRef=*/true);
addHandleAccessFunction(Subscript, /*IsConst=*/false, /*IsRef=*/true);
return *this;
}

Expand All @@ -265,6 +193,13 @@ struct BuiltinTypeDeclBuilder {
return QualType();
}

QualType getHandleElementType() {
if (Template)
return getFirstTemplateTypeParam();
// TODO: Should we default to VoidTy? Using `i8` is arguably ambiguous.
return SemaRef.getASTContext().Char8Ty;
}

BuiltinTypeDeclBuilder &startDefinition() {
assert(!Record->isCompleteDefinition() && "record is already complete");
Record->startDefinition();
Expand Down Expand Up @@ -294,6 +229,8 @@ struct BuiltinTypeDeclBuilder {
// Builtin types methods
BuiltinTypeDeclBuilder &addIncrementCounterMethod();
BuiltinTypeDeclBuilder &addDecrementCounterMethod();
BuiltinTypeDeclBuilder &addHandleAccessFunction(DeclarationName &Name,
bool IsConst, bool IsRef);
};

struct TemplateParameterListBuilder {
Expand Down Expand Up @@ -453,7 +390,7 @@ struct TemplateParameterListBuilder {
// Builder for methods of builtin types. Allows adding methods to builtin types
// using the builder pattern like this:
//
// BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
// BuiltinTypeMethodBuilder(RecordBuilder, "MethodName", ReturnType)
// .addParam("param_name", Type, InOutModifier)
// .callBuiltin("builtin_name", BuiltinParams...)
// .finalizeMethod();
Expand Down Expand Up @@ -486,6 +423,7 @@ struct BuiltinTypeMethodBuilder {
DeclarationNameInfo NameInfo;
QualType ReturnTy;
CXXMethodDecl *Method;
bool IsConst;
llvm::SmallVector<MethodParam> Params;
llvm::SmallVector<Stmt *> StmtsList;

Expand All @@ -508,11 +446,16 @@ struct BuiltinTypeMethodBuilder {
Expr *convertPlaceholder(Expr *E) { return E; }

public:
BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
QualType ReturnTy)
: DeclBuilder(DB), ReturnTy(ReturnTy), Method(nullptr) {
BuiltinTypeMethodBuilder(BuiltinTypeDeclBuilder &DB, DeclarationName &Name,
QualType ReturnTy, bool IsConst = false)
: DeclBuilder(DB), NameInfo(DeclarationNameInfo(Name, SourceLocation())),
ReturnTy(ReturnTy), Method(nullptr), IsConst(IsConst) {}

BuiltinTypeMethodBuilder(BuiltinTypeDeclBuilder &DB, StringRef Name,
QualType ReturnTy, bool IsConst = false)
: DeclBuilder(DB), ReturnTy(ReturnTy), Method(nullptr), IsConst(IsConst) {
const IdentifierInfo &II =
S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
DB.SemaRef.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
}

Expand All @@ -535,8 +478,12 @@ struct BuiltinTypeMethodBuilder {
SmallVector<QualType> ParamTypes;
for (MethodParam &MP : Params)
ParamTypes.emplace_back(MP.Ty);
QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
FunctionProtoType::ExtProtoInfo());

FunctionProtoType::ExtProtoInfo ExtInfo;
if (IsConst)
ExtInfo.TypeQuals.addConst();

QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes, ExtInfo);

// create method decl
auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
Expand Down Expand Up @@ -586,7 +533,8 @@ struct BuiltinTypeMethodBuilder {
}

template <typename... Ts>
BuiltinTypeMethodBuilder &callBuiltin(StringRef BuiltinName, Ts... ArgSpecs) {
BuiltinTypeMethodBuilder &callBuiltin(StringRef BuiltinName,
QualType ReturnType, Ts... ArgSpecs) {
std::array<Expr *, sizeof...(ArgSpecs)> Args{
convertPlaceholder(std::forward<Ts>(ArgSpecs))...};

Expand All @@ -599,15 +547,32 @@ struct BuiltinTypeMethodBuilder {
FunctionDecl *FD = lookupBuiltinFunction(DeclBuilder.SemaRef, BuiltinName);
DeclRefExpr *DRE = DeclRefExpr::Create(
AST, NestedNameSpecifierLoc(), SourceLocation(), FD, false,
FD->getNameInfo(), FD->getType(), VK_PRValue);
FD->getNameInfo(), AST.BuiltinFnTy, VK_PRValue);

if (ReturnType.isNull())
ReturnType = FD->getReturnType();

Expr *Call =
CallExpr::Create(AST, DRE, Args, FD->getReturnType(), VK_PRValue,
SourceLocation(), FPOptionsOverride());
Expr *Call = CallExpr::Create(AST, DRE, Args, ReturnType, VK_PRValue,
SourceLocation(), FPOptionsOverride());
StmtsList.push_back(Call);
return *this;
}

BuiltinTypeMethodBuilder &dereference() {
assert(!StmtsList.empty() && "Nothing to dereference");
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();

Expr *LastExpr = dyn_cast<Expr>(StmtsList.back());
assert(LastExpr && "No expression to dereference");
Expr *Deref = UnaryOperator::Create(
AST, LastExpr, UO_Deref, LastExpr->getType()->getPointeeType(),
VK_PRValue, OK_Ordinary, SourceLocation(),
/*CanOverflow=*/false, FPOptionsOverride());
StmtsList.pop_back();
StmtsList.push_back(Deref);
return *this;
}

BuiltinTypeDeclBuilder &finalizeMethod() {
assert(!DeclBuilder.Record->isCompleteDefinition() &&
"record is already complete");
Expand All @@ -621,11 +586,8 @@ struct BuiltinTypeMethodBuilder {
"nothing to return from non-void method");
if (ReturnTy != AST.VoidTy) {
if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
assert(AST.hasSameUnqualifiedType(
isa<CallExpr>(LastExpr)
? cast<CallExpr>(LastExpr)->getCallReturnType(AST)
: LastExpr->getType(),
ReturnTy) &&
assert(AST.hasSameUnqualifiedType(LastExpr->getType(),
ReturnTy.getNonReferenceType()) &&
"Return type of the last statement must match the return type "
"of the method");
if (!isa<ReturnStmt>(LastExpr)) {
Expand Down Expand Up @@ -672,19 +634,43 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(ArrayRef<StringRef> Names,

BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addIncrementCounterMethod() {
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
return BuiltinTypeMethodBuilder(SemaRef, *this, "IncrementCounter",
return BuiltinTypeMethodBuilder(*this, "IncrementCounter",
SemaRef.getASTContext().UnsignedIntTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", PH::Handle,
getConstantIntExpr(1))
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
PH::Handle, getConstantIntExpr(1))
.finalizeMethod();
}

BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDecrementCounterMethod() {
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
return BuiltinTypeMethodBuilder(SemaRef, *this, "DecrementCounter",
return BuiltinTypeMethodBuilder(*this, "DecrementCounter",
SemaRef.getASTContext().UnsignedIntTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", PH::Handle,
getConstantIntExpr(-1))
.callBuiltin("__builtin_hlsl_buffer_update_counter", QualType(),
PH::Handle, getConstantIntExpr(-1))
.finalizeMethod();
}

BuiltinTypeDeclBuilder &
BuiltinTypeDeclBuilder::addHandleAccessFunction(DeclarationName &Name,
bool IsConst, bool IsRef) {
assert(!Record->isCompleteDefinition() && "record is already complete");
ASTContext &AST = SemaRef.getASTContext();
using PH = BuiltinTypeMethodBuilder::PlaceHolder;

QualType ElemTy = getHandleElementType();
// TODO: Map to an hlsl_device address space.
QualType ElemPtrTy = AST.getPointerType(ElemTy);
QualType ReturnTy = ElemTy;
if (IsConst)
ReturnTy.addConst();
if (IsRef)
ReturnTy = AST.getLValueReferenceType(ReturnTy);

return BuiltinTypeMethodBuilder(*this, Name, ReturnTy, IsConst)
.addParam("Index", AST.UnsignedIntTy)
.callBuiltin("__builtin_hlsl_resource_getpointer", ElemPtrTy, PH::Handle,
PH::_0)
.dereference()
.finalizeMethod();
}

Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,9 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
if (Ty->isObjCObjectType())
return VAK_Invalid;

if (getLangOpts().HLSL && Ty->getAs<HLSLAttributedResourceType>())
return VAK_Valid;

if (getLangOpts().MSVCCompat)
return VAK_MSVCUndefined;

Expand Down
18 changes: 17 additions & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,7 @@ static bool CheckResourceHandle(
const HLSLAttributedResourceType *ResTy =
ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
if (!ResTy) {
S->Diag(TheCall->getArg(0)->getBeginLoc(),
S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
diag::err_typecheck_expect_hlsl_resource)
<< ArgType;
return true;
Expand All @@ -1926,6 +1926,22 @@ static bool CheckResourceHandle(
// returning an ExprError
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_resource_getpointer: {
if (SemaRef.checkArgCount(TheCall, 2) ||
CheckResourceHandle(&SemaRef, TheCall, 0) ||
CheckArgTypeMatches(&SemaRef, TheCall->getArg(1),
SemaRef.getASTContext().UnsignedIntTy))
return true;

auto *ResourceTy =
TheCall->getArg(0)->getType()->castAs<HLSLAttributedResourceType>();
QualType ContainedTy = ResourceTy->getContainedType();
// TODO: Map to an hlsl_device address space.
TheCall->setType(getASTContext().getPointerType(ContainedTy));
TheCall->setValueKind(VK_LValue);

break;
}
case Builtin::BI__builtin_hlsl_all:
case Builtin::BI__builtin_hlsl_any: {
if (SemaRef.checkArgCount(TheCall, 1))
Expand Down
Loading
Loading