Skip to content

[DXIL][Analysis] Uniquify duplicate resources in DXILResourceAnalysis #105602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions llvm/include/llvm/Analysis/DXILResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Value;
namespace dxil {

class ResourceInfo {
public:
struct ResourceBinding {
uint32_t RecordID;
uint32_t Space;
Expand All @@ -38,6 +39,10 @@ class ResourceInfo {
bool operator!=(const ResourceBinding &RHS) const {
return !(*this == RHS);
}
bool operator<(const ResourceBinding &RHS) const {
return std::tie(RecordID, Space, LowerBound, Size) <
std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
}
};

struct UAVInfo {
Expand All @@ -50,6 +55,10 @@ class ResourceInfo {
std::tie(RHS.GloballyCoherent, RHS.HasCounter, RHS.IsROV);
}
bool operator!=(const UAVInfo &RHS) const { return !(*this == RHS); }
bool operator<(const UAVInfo &RHS) const {
return std::tie(GloballyCoherent, HasCounter, IsROV) <
std::tie(RHS.GloballyCoherent, RHS.HasCounter, RHS.IsROV);
}
};

struct StructInfo {
Expand All @@ -64,6 +73,9 @@ class ResourceInfo {
return std::tie(Stride, AlignLog2) == std::tie(RHS.Stride, RHS.AlignLog2);
}
bool operator!=(const StructInfo &RHS) const { return !(*this == RHS); }
bool operator<(const StructInfo &RHS) const {
return std::tie(Stride, AlignLog2) < std::tie(RHS.Stride, RHS.AlignLog2);
}
};

struct TypedInfo {
Expand All @@ -75,22 +87,29 @@ class ResourceInfo {
std::tie(RHS.ElementTy, RHS.ElementCount);
}
bool operator!=(const TypedInfo &RHS) const { return !(*this == RHS); }
bool operator<(const TypedInfo &RHS) const {
return std::tie(ElementTy, ElementCount) <
std::tie(RHS.ElementTy, RHS.ElementCount);
}
};

struct MSInfo {
uint32_t Count;

bool operator==(const MSInfo &RHS) const { return Count == RHS.Count; }
bool operator!=(const MSInfo &RHS) const { return !(*this == RHS); }
bool operator<(const MSInfo &RHS) const { return Count < RHS.Count; }
};

struct FeedbackInfo {
dxil::SamplerFeedbackType Type;

bool operator==(const FeedbackInfo &RHS) const { return Type == RHS.Type; }
bool operator!=(const FeedbackInfo &RHS) const { return !(*this == RHS); }
bool operator<(const FeedbackInfo &RHS) const { return Type < RHS.Type; }
};

private:
// Universal properties.
Value *Symbol;
StringRef Name;
Expand Down Expand Up @@ -138,6 +157,7 @@ class ResourceInfo {
Binding.LowerBound = LowerBound;
Binding.Size = Size;
}
const ResourceBinding &getBinding() const { return Binding; }
void setUAV(bool GloballyCoherent, bool HasCounter, bool IsROV) {
assert(isUAV() && "Not a UAV");
UAVFlags.GloballyCoherent = GloballyCoherent;
Expand Down Expand Up @@ -168,7 +188,11 @@ class ResourceInfo {
MultiSample.Count = Count;
}

dxil::ResourceClass getResourceClass() const { return RC; }

bool operator==(const ResourceInfo &RHS) const;
bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); }
bool operator<(const ResourceInfo &RHS) const;

static ResourceInfo SRV(Value *Symbol, StringRef Name,
dxil::ElementType ElementTy, uint32_t ElementCount,
Expand Down Expand Up @@ -216,15 +240,48 @@ class ResourceInfo {

MDTuple *getAsMetadata(LLVMContext &Ctx) const;

ResourceBinding getBinding() const { return Binding; }
std::pair<uint32_t, uint32_t> getAnnotateProps() const;

void print(raw_ostream &OS) const;
};

} // namespace dxil

using DXILResourceMap = MapVector<CallInst *, dxil::ResourceInfo>;
class DXILResourceMap {
SmallVector<dxil::ResourceInfo> Resources;
DenseMap<CallInst *, unsigned> CallMap;
unsigned FirstUAV = 0;
unsigned FirstCBuffer = 0;
unsigned FirstSampler = 0;

public:
using iterator = SmallVector<dxil::ResourceInfo>::iterator;
using const_iterator = SmallVector<dxil::ResourceInfo>::const_iterator;

DXILResourceMap(
SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);

iterator begin() { return Resources.begin(); }
const_iterator begin() const { return Resources.begin(); }
iterator end() { return Resources.end(); }
const_iterator end() const { return Resources.end(); }

bool empty() const { return Resources.empty(); }

iterator find(const CallInst *Key) {
auto Pos = CallMap.find(Key);
return Pos == CallMap.end() ? Resources.end()
: (Resources.begin() + Pos->second);
}

const_iterator find(const CallInst *Key) const {
auto Pos = CallMap.find(Key);
return Pos == CallMap.end() ? Resources.end()
: (Resources.begin() + Pos->second);
}

void print(raw_ostream &OS) const;
};

class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
friend AnalysisInfoMixin<DXILResourceAnalysis>;
Expand Down
173 changes: 106 additions & 67 deletions llvm/lib/Analysis/DXILResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,27 +335,45 @@ bool ResourceInfo::operator==(const ResourceInfo &RHS) const {
if (std::tie(Symbol, Name, Binding, RC, Kind) !=
std::tie(RHS.Symbol, RHS.Name, RHS.Binding, RHS.RC, RHS.Kind))
return false;
if (isCBuffer())
return CBufferSize == RHS.CBufferSize;
if (isSampler())
return SamplerTy == RHS.SamplerTy;
if (isUAV() && UAVFlags != RHS.UAVFlags)
if (isCBuffer() && RHS.isCBuffer() && CBufferSize != RHS.CBufferSize)
return false;

if (isStruct())
return Struct == RHS.Struct;
if (isFeedback())
return Feedback == RHS.Feedback;
if (isTyped() && Typed != RHS.Typed)
if (isSampler() && RHS.isSampler() && SamplerTy != RHS.SamplerTy)
return false;
if (isUAV() && RHS.isUAV() && UAVFlags != RHS.UAVFlags)
return false;
if (isStruct() && RHS.isStruct() && Struct != RHS.Struct)
return false;
if (isFeedback() && RHS.isFeedback() && Feedback != RHS.Feedback)
return false;
if (isTyped() && RHS.isTyped() && Typed != RHS.Typed)
return false;
if (isMultiSample() && RHS.isMultiSample() && MultiSample != RHS.MultiSample)
return false;

if (isMultiSample())
return MultiSample == RHS.MultiSample;

assert((Kind == ResourceKind::RawBuffer) && "Unhandled resource kind");
return true;
}

bool ResourceInfo::operator<(const ResourceInfo &RHS) const {
// Skip the symbol to avoid non-determinism, and the name to keep a consistent
// ordering even when we strip reflection data.
if (std::tie(Binding, RC, Kind) < std::tie(RHS.Binding, RHS.RC, RHS.Kind))
return true;
if (isCBuffer() && RHS.isCBuffer() && CBufferSize < RHS.CBufferSize)
return true;
if (isSampler() && RHS.isSampler() && SamplerTy < RHS.SamplerTy)
return true;
if (isUAV() && RHS.isUAV() && UAVFlags < RHS.UAVFlags)
return true;
if (isStruct() && RHS.isStruct() && Struct < RHS.Struct)
return true;
if (isFeedback() && RHS.isFeedback() && Feedback < RHS.Feedback)
return true;
if (isTyped() && RHS.isTyped() && Typed < RHS.Typed)
return true;
if (isMultiSample() && RHS.isMultiSample() && MultiSample < RHS.MultiSample)
return true;
return false;
}

MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const {
SmallVector<Metadata *, 11> MDVals;

Expand Down Expand Up @@ -534,18 +552,10 @@ namespace {
class ResourceMapper {
Module &M;
LLVMContext &Context;
DXILResourceMap &Resources;

// In DXC, Record ID is unique per resource type. Match that.
uint32_t NextUAV = 0;
uint32_t NextSRV = 0;
uint32_t NextCBuf = 0;
uint32_t NextSmp = 0;
SmallVector<std::pair<CallInst *, dxil::ResourceInfo>> Resources;

public:
ResourceMapper(Module &M,
MapVector<CallInst *, dxil::ResourceInfo> &Resources)
: M(M), Context(M.getContext()), Resources(Resources) {}
ResourceMapper(Module &M) : M(M), Context(M.getContext()) {}

void diagnoseHandle(CallInst *CI, const Twine &Msg,
DiagnosticSeverity Severity = DS_Error) {
Expand Down Expand Up @@ -585,13 +595,11 @@ class ResourceMapper {
// TODO: We don't actually keep track of the name right now...
StringRef Name = "";

auto [It, Success] = Resources.try_emplace(CI, RC, Kind, Symbol, Name);
assert(Success && "Mapping the same CallInst again?");
(void)Success;
// We grab a pointer into the map's storage, which isn't generally safe.
// Since we're just using this to fill in the info the map won't mutate and
// the pointer stays valid for as long as we need it to.
ResourceInfo *RI = &(It->second);
// Note that we return a pointer into the vector's storage. This is okay as
// long as we don't add more elements until we're done with the pointer.
auto &Pair =
Resources.emplace_back(CI, ResourceInfo{RC, Kind, Symbol, Name});
ResourceInfo *RI = &Pair.second;

if (RI->isUAV())
// TODO: We need analysis for GloballyCoherent and HasCounter
Expand Down Expand Up @@ -658,27 +666,18 @@ class ResourceMapper {
if (!RI)
return nullptr;

uint32_t NextID;
if (RI->isCBuffer())
NextID = NextCBuf++;
else if (RI->isSampler())
NextID = NextSmp++;
else if (RI->isUAV())
NextID = NextUAV++;
else
NextID = NextSRV++;

uint32_t Space = cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue();
uint32_t LowerBound =
cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
uint32_t Size = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();

RI->bind(NextID, Space, LowerBound, Size);
// We use a binding ID of zero for now - these will be filled in later.
RI->bind(0U, Space, LowerBound, Size);

return RI;
}

void mapResources() {
DXILResourceMap mapResources() {
for (Function &F : M.functions()) {
if (!F.isDeclaration())
continue;
Expand All @@ -697,11 +696,68 @@ class ResourceMapper {
break;
}
}

return DXILResourceMap(std::move(Resources));
}
};

} // namespace

DXILResourceMap::DXILResourceMap(
SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI) {
if (CIToRI.empty())
return;

llvm::stable_sort(CIToRI, [](auto &LHS, auto &RHS) {
// Sort by resource class first for grouping purposes, and then by the rest
// of the fields so that we can remove duplicates.
ResourceClass LRC = LHS.second.getResourceClass();
ResourceClass RRC = RHS.second.getResourceClass();
return std::tie(LRC, LHS.second) < std::tie(RRC, RHS.second);
});
for (auto [CI, RI] : CIToRI) {
if (Resources.empty() || RI != Resources.back())
Resources.push_back(RI);
CallMap[CI] = Resources.size() - 1;
}

unsigned Size = Resources.size();
// In DXC, Record ID is unique per resource type. Match that.
FirstUAV = FirstCBuffer = FirstSampler = Size;
uint32_t NextID = 0;
for (unsigned I = 0, E = Size; I != E; ++I) {
ResourceInfo &RI = Resources[I];
if (RI.isUAV() && FirstUAV == Size) {
FirstUAV = I;
NextID = 0;
} else if (RI.isCBuffer() && FirstCBuffer == Size) {
FirstCBuffer = I;
NextID = 0;
} else if (RI.isSampler() && FirstSampler == Size) {
FirstSampler = I;
NextID = 0;
}

// Adjust the resource binding to use the next ID.
const ResourceInfo::ResourceBinding &Binding = RI.getBinding();
RI.bind(NextID++, Binding.Space, Binding.LowerBound, Binding.Size);
}
}

void DXILResourceMap::print(raw_ostream &OS) const {
for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
OS << "Binding " << I << ":\n";
Resources[I].print(OS);
OS << "\n";
}

for (const auto &[CI, Index] : CallMap) {
OS << "Call bound to " << Index << ":";
CI->print(OS);
OS << "\n";
}
}

//===----------------------------------------------------------------------===//
// DXILResourceAnalysis and DXILResourcePrinterPass

Expand All @@ -710,24 +766,14 @@ AnalysisKey DXILResourceAnalysis::Key;

DXILResourceMap DXILResourceAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
DXILResourceMap Data;
ResourceMapper(M, Data).mapResources();
DXILResourceMap Data = ResourceMapper(M).mapResources();
return Data;
}

PreservedAnalyses DXILResourcePrinterPass::run(Module &M,
ModuleAnalysisManager &AM) {
DXILResourceMap &Data =
AM.getResult<DXILResourceAnalysis>(M);

for (const auto &[Handle, Info] : Data) {
OS << "Binding for ";
Handle->print(OS);
OS << "\n";
Info.print(OS);
OS << "\n";
}

DXILResourceMap &DRM = AM.getResult<DXILResourceAnalysis>(M);
DRM.print(OS);
return PreservedAnalyses::all();
}

Expand All @@ -745,8 +791,7 @@ void DXILResourceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
}

bool DXILResourceWrapperPass::runOnModule(Module &M) {
ResourceMap.reset(new DXILResourceMap());
ResourceMapper(M, *ResourceMap).mapResources();
ResourceMap.reset(new DXILResourceMap(ResourceMapper(M).mapResources()));
return false;
}

Expand All @@ -757,13 +802,7 @@ void DXILResourceWrapperPass::print(raw_ostream &OS, const Module *) const {
OS << "No resource map has been built!\n";
return;
}
for (const auto &[Handle, Info] : *ResourceMap) {
OS << "Binding for ";
Handle->print(OS);
OS << "\n";
Info.print(OS);
OS << "\n";
}
ResourceMap->print(OS);
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
Expand Down
Loading
Loading