Skip to content

[DirectX] Create symbols for resource handles #119775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions llvm/include/llvm/Analysis/DXILResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ class ResourceTypeInfo {
GloballyCoherent, HasCounter) {}

TargetExtType *getHandleTy() const { return HandleTy; }
StructType *createElementStruct();

// Conditions to check before accessing specific views.
bool isUAV() const;
Expand Down Expand Up @@ -329,28 +330,31 @@ class ResourceBindingInfo {
private:
ResourceBinding Binding;
TargetExtType *HandleTy;
GlobalVariable *Symbol = nullptr;

public:
ResourceBindingInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
uint32_t Size, TargetExtType *HandleTy)
: Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy) {}
uint32_t Size, TargetExtType *HandleTy,
GlobalVariable *Symbol = nullptr)
: Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy),
Symbol(Symbol) {}

void setBindingID(unsigned ID) { Binding.RecordID = ID; }

const ResourceBinding &getBinding() const { return Binding; }
TargetExtType *getHandleTy() const { return HandleTy; }
const StringRef getName() const {
// TODO: Get the name from the symbol once we include one here.
return "";
}
const StringRef getName() const { return Symbol ? Symbol->getName() : ""; }

bool hasSymbol() const { return Symbol; }
GlobalVariable *createSymbol(Module &M, StructType *Ty, StringRef Name = "");
MDTuple *getAsMetadata(Module &M, dxil::ResourceTypeInfo &RTI) const;

std::pair<uint32_t, uint32_t>
getAnnotateProps(Module &M, dxil::ResourceTypeInfo &RTI) const;

bool operator==(const ResourceBindingInfo &RHS) const {
return std::tie(Binding, HandleTy) == std::tie(RHS.Binding, RHS.HandleTy);
return std::tie(Binding, HandleTy, Symbol) ==
std::tie(RHS.Binding, RHS.HandleTy, RHS.Symbol);
}
bool operator!=(const ResourceBindingInfo &RHS) const {
return !(*this == RHS);
Expand Down
100 changes: 93 additions & 7 deletions llvm/lib/Analysis/DXILResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,81 @@ ResourceTypeInfo::ResourceTypeInfo(TargetExtType *HandleTy,
llvm_unreachable("Unknown handle type");
}

static void formatTypeName(SmallString<64> &Dest, StringRef Name,
bool isWriteable, bool isROV) {
Dest = isWriteable ? (isROV ? "RasterizerOrdered" : "RW") : "";
Dest += Name;
}

StructType *ResourceTypeInfo::createElementStruct() {
SmallString<64> TypeName;

switch (Kind) {
case ResourceKind::Texture1D:
case ResourceKind::Texture2D:
case ResourceKind::Texture3D:
case ResourceKind::TextureCube:
case ResourceKind::Texture1DArray:
case ResourceKind::Texture2DArray:
case ResourceKind::TextureCubeArray: {
auto *RTy = cast<TextureExtType>(HandleTy);
formatTypeName(TypeName, getResourceKindName(Kind), RTy->isWriteable(),
RTy->isROV());
return StructType::create(RTy->getResourceType(), TypeName);
}
case ResourceKind::Texture2DMS:
case ResourceKind::Texture2DMSArray: {
auto *RTy = cast<MSTextureExtType>(HandleTy);
formatTypeName(TypeName, getResourceKindName(Kind), RTy->isWriteable(),
/*IsROV=*/false);
return StructType::create(RTy->getResourceType(), TypeName);
}
case ResourceKind::TypedBuffer: {
auto *RTy = cast<TypedBufferExtType>(HandleTy);
formatTypeName(TypeName, getResourceKindName(Kind), RTy->isWriteable(),
RTy->isROV());
return StructType::create(RTy->getResourceType(), TypeName);
}
case ResourceKind::RawBuffer: {
auto *RTy = cast<RawBufferExtType>(HandleTy);
formatTypeName(TypeName, "ByteAddressBuffer", RTy->isWriteable(),
RTy->isROV());
return StructType::create(Type::getInt32Ty(HandleTy->getContext()),
TypeName);
}
case ResourceKind::StructuredBuffer: {
auto *RTy = cast<RawBufferExtType>(HandleTy);
formatTypeName(TypeName, "StructuredBuffer", RTy->isWriteable(),
RTy->isROV());
return StructType::create(RTy->getResourceType(), TypeName);
}
case ResourceKind::FeedbackTexture2D:
case ResourceKind::FeedbackTexture2DArray: {
auto *RTy = cast<FeedbackTextureExtType>(HandleTy);
TypeName = formatv("{0}<{1}>", getResourceKindName(Kind),
llvm::to_underlying(RTy->getFeedbackType()));
return StructType::create(Type::getInt32Ty(HandleTy->getContext()),
TypeName);
}
case ResourceKind::CBuffer:
return StructType::create(HandleTy->getContext(), "cbuffer");
case ResourceKind::Sampler: {
auto *RTy = cast<SamplerExtType>(HandleTy);
TypeName = formatv("SamplerState<{0}>",
llvm::to_underlying(RTy->getSamplerType()));
return StructType::create(Type::getInt32Ty(HandleTy->getContext()),
TypeName);
}
case ResourceKind::TBuffer:
case ResourceKind::RTAccelerationStructure:
llvm_unreachable("Unhandled resource kind");
case ResourceKind::Invalid:
case ResourceKind::NumEntries:
llvm_unreachable("Invalid resource kind");
}
llvm_unreachable("Unhandled ResourceKind enum");
}

bool ResourceTypeInfo::isUAV() const { return RC == ResourceClass::UAV; }

bool ResourceTypeInfo::isCBuffer() const {
Expand Down Expand Up @@ -449,6 +524,15 @@ void ResourceTypeInfo::print(raw_ostream &OS, const DataLayout &DL) const {
}
}

GlobalVariable *ResourceBindingInfo::createSymbol(Module &M, StructType *Ty,
StringRef Name) {
assert(!Symbol && "Symbol has already been created");
Symbol = new GlobalVariable(M, Ty, /*isConstant=*/true,
GlobalValue::ExternalLinkage,
/*Initializer=*/nullptr, Name);
return Symbol;
}

MDTuple *ResourceBindingInfo::getAsMetadata(Module &M,
dxil::ResourceTypeInfo &RTI) const {
LLVMContext &Ctx = M.getContext();
Expand All @@ -468,13 +552,9 @@ MDTuple *ResourceBindingInfo::getAsMetadata(Module &M,
};

MDVals.push_back(getIntMD(Binding.RecordID));

// TODO: We need API to create a symbol of the appropriate type to emit here.
// See https://github.com/llvm/llvm-project/issues/116849
MDVals.push_back(
ValueAsMetadata::get(UndefValue::get(PointerType::getUnqual(Ctx))));
MDVals.push_back(MDString::get(Ctx, ""));

assert(Symbol && "Cannot yet create useful resource metadata without symbol");
MDVals.push_back(ValueAsMetadata::get(Symbol));
MDVals.push_back(MDString::get(Ctx, Symbol->getName()));
MDVals.push_back(getIntMD(Binding.Space));
MDVals.push_back(getIntMD(Binding.LowerBound));
MDVals.push_back(getIntMD(Binding.Size));
Expand Down Expand Up @@ -573,6 +653,12 @@ ResourceBindingInfo::getAnnotateProps(Module &M,

void ResourceBindingInfo::print(raw_ostream &OS, dxil::ResourceTypeInfo &RTI,
const DataLayout &DL) const {
if (Symbol) {
OS << " Symbol: ";
Symbol->printAsOperand(OS);
OS << "\n";
}

OS << " Binding:\n"
<< " Record ID: " << Binding.RecordID << "\n"
<< " Space: " << Binding.Space << "\n"
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ static NamedMDNode *emitResourceMetadata(Module &M, DXILBindingMap &DBM,
const dxil::Resources &MDResources) {
LLVMContext &Context = M.getContext();

for (ResourceBindingInfo &RI : DBM)
if (!RI.hasSymbol())
RI.createSymbol(M, DRTM[RI.getHandleTy()].createElementStruct());

SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps;
for (const ResourceBindingInfo &RI : DBM.srvs())
SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
Expand Down
48 changes: 48 additions & 0 deletions llvm/test/CodeGen/DirectX/Metadata/resource-symbols.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
; RUN: opt -S -passes=dxil-translate-metadata %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.6-compute"

%struct.S = type { <4 x float>, <4 x i32> }

define void @test() {
; Buffer<float4>
%float4 = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
@llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false)
; CHECK: %TypedBuffer = type { <4 x float> }

; Buffer<int>
%int = call target("dx.TypedBuffer", i32, 0, 0, 1)
@llvm.dx.handle.fromBinding(i32 0, i32 1, i32 1, i32 0, i1 false)
; CHECK: %TypedBuffer.0 = type { i32 }

; Buffer<uint3>
%uint3 = call target("dx.TypedBuffer", <3 x i32>, 0, 0, 0)
@llvm.dx.handle.fromBinding(i32 0, i32 2, i32 1, i32 0, i1 false)
; CHECK: %TypedBuffer.1 = type { <3 x i32> }

; StructuredBuffer<S>
%struct0 = call target("dx.RawBuffer", %struct.S, 0, 0)
@llvm.dx.handle.fromBinding(i32 0, i32 10, i32 1, i32 0, i1 true)
; CHECK: %StructuredBuffer = type { %struct.S }

; ByteAddressBuffer
%byteaddr = call target("dx.RawBuffer", i8, 0, 0)
@llvm.dx.handle.fromBinding(i32 0, i32 20, i32 1, i32 0, i1 false)
; CHECK: %ByteAddressBuffer = type { i32 }

ret void
}

; CHECK: @[[T0:.*]] = external constant %TypedBuffer
; CHECK-NEXT: @[[T1:.*]] = external constant %TypedBuffer.0
; CHECK-NEXT: @[[T2:.*]] = external constant %TypedBuffer.1
; CHECK-NEXT: @[[S0:.*]] = external constant %StructuredBuffer
; CHECK-NEXT: @[[B0:.*]] = external constant %ByteAddressBuffer

; CHECK: !{i32 0, ptr @[[T0]], !""
; CHECK: !{i32 1, ptr @[[T1]], !""
; CHECK: !{i32 2, ptr @[[T2]], !""
; CHECK: !{i32 3, ptr @[[S0]], !""
; CHECK: !{i32 4, ptr @[[B0]], !""

attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) }
Loading
Loading