Skip to content

Commit 782bc4f

Browse files
authored
[DXIL][Analysis] Uniquify duplicate resources in DXILResourceAnalysis
If a resources is used multiple times, we should only have one resource record for it. This comes up most prominantly with arrays of resources like so: ```hlsl RWBuffer<float4> BufferArray[10] : register(u0, space4); RWBuffer<float4> B1 = BufferArray[0]; RWBuffer<float4> B2 = BufferArray[SomeIndex]; RWBuffer<float4> B3 = BufferArray[3]; ``` In this case, there's only one resource, but we'll generate 3 different `dx.handle.fromBinding` calls to access different slices. Note that this adds some API that won't be used until #104447 later in the stack. Trying to avoid that results in unnecessary churn. Fixes #105143 Pull Request: #105602
1 parent 00620ab commit 782bc4f

File tree

3 files changed

+251
-147
lines changed

3 files changed

+251
-147
lines changed

llvm/include/llvm/Analysis/DXILResource.h

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Value;
2525
namespace dxil {
2626

2727
class ResourceInfo {
28+
public:
2829
struct ResourceBinding {
2930
uint32_t RecordID;
3031
uint32_t Space;
@@ -38,6 +39,10 @@ class ResourceInfo {
3839
bool operator!=(const ResourceBinding &RHS) const {
3940
return !(*this == RHS);
4041
}
42+
bool operator<(const ResourceBinding &RHS) const {
43+
return std::tie(RecordID, Space, LowerBound, Size) <
44+
std::tie(RHS.RecordID, RHS.Space, RHS.LowerBound, RHS.Size);
45+
}
4146
};
4247

4348
struct UAVInfo {
@@ -50,6 +55,10 @@ class ResourceInfo {
5055
std::tie(RHS.GloballyCoherent, RHS.HasCounter, RHS.IsROV);
5156
}
5257
bool operator!=(const UAVInfo &RHS) const { return !(*this == RHS); }
58+
bool operator<(const UAVInfo &RHS) const {
59+
return std::tie(GloballyCoherent, HasCounter, IsROV) <
60+
std::tie(RHS.GloballyCoherent, RHS.HasCounter, RHS.IsROV);
61+
}
5362
};
5463

5564
struct StructInfo {
@@ -64,6 +73,9 @@ class ResourceInfo {
6473
return std::tie(Stride, AlignLog2) == std::tie(RHS.Stride, RHS.AlignLog2);
6574
}
6675
bool operator!=(const StructInfo &RHS) const { return !(*this == RHS); }
76+
bool operator<(const StructInfo &RHS) const {
77+
return std::tie(Stride, AlignLog2) < std::tie(RHS.Stride, RHS.AlignLog2);
78+
}
6779
};
6880

6981
struct TypedInfo {
@@ -75,22 +87,29 @@ class ResourceInfo {
7587
std::tie(RHS.ElementTy, RHS.ElementCount);
7688
}
7789
bool operator!=(const TypedInfo &RHS) const { return !(*this == RHS); }
90+
bool operator<(const TypedInfo &RHS) const {
91+
return std::tie(ElementTy, ElementCount) <
92+
std::tie(RHS.ElementTy, RHS.ElementCount);
93+
}
7894
};
7995

8096
struct MSInfo {
8197
uint32_t Count;
8298

8399
bool operator==(const MSInfo &RHS) const { return Count == RHS.Count; }
84100
bool operator!=(const MSInfo &RHS) const { return !(*this == RHS); }
101+
bool operator<(const MSInfo &RHS) const { return Count < RHS.Count; }
85102
};
86103

87104
struct FeedbackInfo {
88105
dxil::SamplerFeedbackType Type;
89106

90107
bool operator==(const FeedbackInfo &RHS) const { return Type == RHS.Type; }
91108
bool operator!=(const FeedbackInfo &RHS) const { return !(*this == RHS); }
109+
bool operator<(const FeedbackInfo &RHS) const { return Type < RHS.Type; }
92110
};
93111

112+
private:
94113
// Universal properties.
95114
Value *Symbol;
96115
StringRef Name;
@@ -138,6 +157,7 @@ class ResourceInfo {
138157
Binding.LowerBound = LowerBound;
139158
Binding.Size = Size;
140159
}
160+
const ResourceBinding &getBinding() const { return Binding; }
141161
void setUAV(bool GloballyCoherent, bool HasCounter, bool IsROV) {
142162
assert(isUAV() && "Not a UAV");
143163
UAVFlags.GloballyCoherent = GloballyCoherent;
@@ -168,7 +188,11 @@ class ResourceInfo {
168188
MultiSample.Count = Count;
169189
}
170190

191+
dxil::ResourceClass getResourceClass() const { return RC; }
192+
171193
bool operator==(const ResourceInfo &RHS) const;
194+
bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); }
195+
bool operator<(const ResourceInfo &RHS) const;
172196

173197
static ResourceInfo SRV(Value *Symbol, StringRef Name,
174198
dxil::ElementType ElementTy, uint32_t ElementCount,
@@ -216,15 +240,48 @@ class ResourceInfo {
216240

217241
MDTuple *getAsMetadata(LLVMContext &Ctx) const;
218242

219-
ResourceBinding getBinding() const { return Binding; }
220243
std::pair<uint32_t, uint32_t> getAnnotateProps() const;
221244

222245
void print(raw_ostream &OS) const;
223246
};
224247

225248
} // namespace dxil
226249

227-
using DXILResourceMap = MapVector<CallInst *, dxil::ResourceInfo>;
250+
class DXILResourceMap {
251+
SmallVector<dxil::ResourceInfo> Resources;
252+
DenseMap<CallInst *, unsigned> CallMap;
253+
unsigned FirstUAV = 0;
254+
unsigned FirstCBuffer = 0;
255+
unsigned FirstSampler = 0;
256+
257+
public:
258+
using iterator = SmallVector<dxil::ResourceInfo>::iterator;
259+
using const_iterator = SmallVector<dxil::ResourceInfo>::const_iterator;
260+
261+
DXILResourceMap(
262+
SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);
263+
264+
iterator begin() { return Resources.begin(); }
265+
const_iterator begin() const { return Resources.begin(); }
266+
iterator end() { return Resources.end(); }
267+
const_iterator end() const { return Resources.end(); }
268+
269+
bool empty() const { return Resources.empty(); }
270+
271+
iterator find(const CallInst *Key) {
272+
auto Pos = CallMap.find(Key);
273+
return Pos == CallMap.end() ? Resources.end()
274+
: (Resources.begin() + Pos->second);
275+
}
276+
277+
const_iterator find(const CallInst *Key) const {
278+
auto Pos = CallMap.find(Key);
279+
return Pos == CallMap.end() ? Resources.end()
280+
: (Resources.begin() + Pos->second);
281+
}
282+
283+
void print(raw_ostream &OS) const;
284+
};
228285

229286
class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
230287
friend AnalysisInfoMixin<DXILResourceAnalysis>;

llvm/lib/Analysis/DXILResource.cpp

Lines changed: 106 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -335,27 +335,45 @@ bool ResourceInfo::operator==(const ResourceInfo &RHS) const {
335335
if (std::tie(Symbol, Name, Binding, RC, Kind) !=
336336
std::tie(RHS.Symbol, RHS.Name, RHS.Binding, RHS.RC, RHS.Kind))
337337
return false;
338-
if (isCBuffer())
339-
return CBufferSize == RHS.CBufferSize;
340-
if (isSampler())
341-
return SamplerTy == RHS.SamplerTy;
342-
if (isUAV() && UAVFlags != RHS.UAVFlags)
338+
if (isCBuffer() && RHS.isCBuffer() && CBufferSize != RHS.CBufferSize)
343339
return false;
344-
345-
if (isStruct())
346-
return Struct == RHS.Struct;
347-
if (isFeedback())
348-
return Feedback == RHS.Feedback;
349-
if (isTyped() && Typed != RHS.Typed)
340+
if (isSampler() && RHS.isSampler() && SamplerTy != RHS.SamplerTy)
341+
return false;
342+
if (isUAV() && RHS.isUAV() && UAVFlags != RHS.UAVFlags)
343+
return false;
344+
if (isStruct() && RHS.isStruct() && Struct != RHS.Struct)
345+
return false;
346+
if (isFeedback() && RHS.isFeedback() && Feedback != RHS.Feedback)
347+
return false;
348+
if (isTyped() && RHS.isTyped() && Typed != RHS.Typed)
349+
return false;
350+
if (isMultiSample() && RHS.isMultiSample() && MultiSample != RHS.MultiSample)
350351
return false;
351-
352-
if (isMultiSample())
353-
return MultiSample == RHS.MultiSample;
354-
355-
assert((Kind == ResourceKind::RawBuffer) && "Unhandled resource kind");
356352
return true;
357353
}
358354

355+
bool ResourceInfo::operator<(const ResourceInfo &RHS) const {
356+
// Skip the symbol to avoid non-determinism, and the name to keep a consistent
357+
// ordering even when we strip reflection data.
358+
if (std::tie(Binding, RC, Kind) < std::tie(RHS.Binding, RHS.RC, RHS.Kind))
359+
return true;
360+
if (isCBuffer() && RHS.isCBuffer() && CBufferSize < RHS.CBufferSize)
361+
return true;
362+
if (isSampler() && RHS.isSampler() && SamplerTy < RHS.SamplerTy)
363+
return true;
364+
if (isUAV() && RHS.isUAV() && UAVFlags < RHS.UAVFlags)
365+
return true;
366+
if (isStruct() && RHS.isStruct() && Struct < RHS.Struct)
367+
return true;
368+
if (isFeedback() && RHS.isFeedback() && Feedback < RHS.Feedback)
369+
return true;
370+
if (isTyped() && RHS.isTyped() && Typed < RHS.Typed)
371+
return true;
372+
if (isMultiSample() && RHS.isMultiSample() && MultiSample < RHS.MultiSample)
373+
return true;
374+
return false;
375+
}
376+
359377
MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const {
360378
SmallVector<Metadata *, 11> MDVals;
361379

@@ -534,18 +552,10 @@ namespace {
534552
class ResourceMapper {
535553
Module &M;
536554
LLVMContext &Context;
537-
DXILResourceMap &Resources;
538-
539-
// In DXC, Record ID is unique per resource type. Match that.
540-
uint32_t NextUAV = 0;
541-
uint32_t NextSRV = 0;
542-
uint32_t NextCBuf = 0;
543-
uint32_t NextSmp = 0;
555+
SmallVector<std::pair<CallInst *, dxil::ResourceInfo>> Resources;
544556

545557
public:
546-
ResourceMapper(Module &M,
547-
MapVector<CallInst *, dxil::ResourceInfo> &Resources)
548-
: M(M), Context(M.getContext()), Resources(Resources) {}
558+
ResourceMapper(Module &M) : M(M), Context(M.getContext()) {}
549559

550560
void diagnoseHandle(CallInst *CI, const Twine &Msg,
551561
DiagnosticSeverity Severity = DS_Error) {
@@ -585,13 +595,11 @@ class ResourceMapper {
585595
// TODO: We don't actually keep track of the name right now...
586596
StringRef Name = "";
587597

588-
auto [It, Success] = Resources.try_emplace(CI, RC, Kind, Symbol, Name);
589-
assert(Success && "Mapping the same CallInst again?");
590-
(void)Success;
591-
// We grab a pointer into the map's storage, which isn't generally safe.
592-
// Since we're just using this to fill in the info the map won't mutate and
593-
// the pointer stays valid for as long as we need it to.
594-
ResourceInfo *RI = &(It->second);
598+
// Note that we return a pointer into the vector's storage. This is okay as
599+
// long as we don't add more elements until we're done with the pointer.
600+
auto &Pair =
601+
Resources.emplace_back(CI, ResourceInfo{RC, Kind, Symbol, Name});
602+
ResourceInfo *RI = &Pair.second;
595603

596604
if (RI->isUAV())
597605
// TODO: We need analysis for GloballyCoherent and HasCounter
@@ -658,27 +666,18 @@ class ResourceMapper {
658666
if (!RI)
659667
return nullptr;
660668

661-
uint32_t NextID;
662-
if (RI->isCBuffer())
663-
NextID = NextCBuf++;
664-
else if (RI->isSampler())
665-
NextID = NextSmp++;
666-
else if (RI->isUAV())
667-
NextID = NextUAV++;
668-
else
669-
NextID = NextSRV++;
670-
671669
uint32_t Space = cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue();
672670
uint32_t LowerBound =
673671
cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
674672
uint32_t Size = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
675673

676-
RI->bind(NextID, Space, LowerBound, Size);
674+
// We use a binding ID of zero for now - these will be filled in later.
675+
RI->bind(0U, Space, LowerBound, Size);
677676

678677
return RI;
679678
}
680679

681-
void mapResources() {
680+
DXILResourceMap mapResources() {
682681
for (Function &F : M.functions()) {
683682
if (!F.isDeclaration())
684683
continue;
@@ -697,11 +696,68 @@ class ResourceMapper {
697696
break;
698697
}
699698
}
699+
700+
return DXILResourceMap(std::move(Resources));
700701
}
701702
};
702703

703704
} // namespace
704705

706+
DXILResourceMap::DXILResourceMap(
707+
SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI) {
708+
if (CIToRI.empty())
709+
return;
710+
711+
llvm::stable_sort(CIToRI, [](auto &LHS, auto &RHS) {
712+
// Sort by resource class first for grouping purposes, and then by the rest
713+
// of the fields so that we can remove duplicates.
714+
ResourceClass LRC = LHS.second.getResourceClass();
715+
ResourceClass RRC = RHS.second.getResourceClass();
716+
return std::tie(LRC, LHS.second) < std::tie(RRC, RHS.second);
717+
});
718+
for (auto [CI, RI] : CIToRI) {
719+
if (Resources.empty() || RI != Resources.back())
720+
Resources.push_back(RI);
721+
CallMap[CI] = Resources.size() - 1;
722+
}
723+
724+
unsigned Size = Resources.size();
725+
// In DXC, Record ID is unique per resource type. Match that.
726+
FirstUAV = FirstCBuffer = FirstSampler = Size;
727+
uint32_t NextID = 0;
728+
for (unsigned I = 0, E = Size; I != E; ++I) {
729+
ResourceInfo &RI = Resources[I];
730+
if (RI.isUAV() && FirstUAV == Size) {
731+
FirstUAV = I;
732+
NextID = 0;
733+
} else if (RI.isCBuffer() && FirstCBuffer == Size) {
734+
FirstCBuffer = I;
735+
NextID = 0;
736+
} else if (RI.isSampler() && FirstSampler == Size) {
737+
FirstSampler = I;
738+
NextID = 0;
739+
}
740+
741+
// Adjust the resource binding to use the next ID.
742+
const ResourceInfo::ResourceBinding &Binding = RI.getBinding();
743+
RI.bind(NextID++, Binding.Space, Binding.LowerBound, Binding.Size);
744+
}
745+
}
746+
747+
void DXILResourceMap::print(raw_ostream &OS) const {
748+
for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
749+
OS << "Binding " << I << ":\n";
750+
Resources[I].print(OS);
751+
OS << "\n";
752+
}
753+
754+
for (const auto &[CI, Index] : CallMap) {
755+
OS << "Call bound to " << Index << ":";
756+
CI->print(OS);
757+
OS << "\n";
758+
}
759+
}
760+
705761
//===----------------------------------------------------------------------===//
706762
// DXILResourceAnalysis and DXILResourcePrinterPass
707763

@@ -710,24 +766,14 @@ AnalysisKey DXILResourceAnalysis::Key;
710766

711767
DXILResourceMap DXILResourceAnalysis::run(Module &M,
712768
ModuleAnalysisManager &AM) {
713-
DXILResourceMap Data;
714-
ResourceMapper(M, Data).mapResources();
769+
DXILResourceMap Data = ResourceMapper(M).mapResources();
715770
return Data;
716771
}
717772

718773
PreservedAnalyses DXILResourcePrinterPass::run(Module &M,
719774
ModuleAnalysisManager &AM) {
720-
DXILResourceMap &Data =
721-
AM.getResult<DXILResourceAnalysis>(M);
722-
723-
for (const auto &[Handle, Info] : Data) {
724-
OS << "Binding for ";
725-
Handle->print(OS);
726-
OS << "\n";
727-
Info.print(OS);
728-
OS << "\n";
729-
}
730-
775+
DXILResourceMap &DRM = AM.getResult<DXILResourceAnalysis>(M);
776+
DRM.print(OS);
731777
return PreservedAnalyses::all();
732778
}
733779

@@ -745,8 +791,7 @@ void DXILResourceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
745791
}
746792

747793
bool DXILResourceWrapperPass::runOnModule(Module &M) {
748-
ResourceMap.reset(new DXILResourceMap());
749-
ResourceMapper(M, *ResourceMap).mapResources();
794+
ResourceMap.reset(new DXILResourceMap(ResourceMapper(M).mapResources()));
750795
return false;
751796
}
752797

@@ -757,13 +802,7 @@ void DXILResourceWrapperPass::print(raw_ostream &OS, const Module *) const {
757802
OS << "No resource map has been built!\n";
758803
return;
759804
}
760-
for (const auto &[Handle, Info] : *ResourceMap) {
761-
OS << "Binding for ";
762-
Handle->print(OS);
763-
OS << "\n";
764-
Info.print(OS);
765-
OS << "\n";
766-
}
805+
ResourceMap->print(OS);
767806
}
768807

769808
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

0 commit comments

Comments
 (0)