Skip to content

Commit e595ba7

Browse files
committed
[DirectX] Introduce the DXILResourceAccess pass
This pass transforms resource access via `llvm.dx.resource.getpointer` into buffer loads and stores. Fixes #114848.
1 parent a480d51 commit e595ba7

File tree

12 files changed

+394
-3
lines changed

12 files changed

+394
-3
lines changed

llvm/include/llvm/Analysis/DXILResource.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ class DXILResourceMap {
275275
DXILResourceMap(
276276
SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);
277277

278+
bool invalidate(Module &M, const PreservedAnalyses &PA,
279+
ModuleAnalysisManager::Invalidator &Inv);
280+
278281
iterator begin() { return Resources.begin(); }
279282
const_iterator begin() const { return Resources.begin(); }
280283
iterator end() { return Resources.end(); }

llvm/lib/Analysis/DXILResource.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,12 @@ DXILResourceMap::DXILResourceMap(
744744
}
745745
}
746746

747+
bool DXILResourceMap::invalidate(Module &M, const PreservedAnalyses &PA,
748+
ModuleAnalysisManager::Invalidator &Inv) {
749+
auto PAC = PA.getChecker<DXILResourceAnalysis>();
750+
return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>());
751+
}
752+
747753
void DXILResourceMap::print(raw_ostream &OS) const {
748754
for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
749755
OS << "Binding " << I << ":\n";

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
3030
DXILPrettyPrinter.cpp
3131
DXILResource.cpp
3232
DXILResourceAnalysis.cpp
33+
DXILResourceAccess.cpp
3334
DXILShaderFlags.cpp
3435
DXILTranslateMetadata.cpp
3536

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "DXILResourceAccess.h"
10+
#include "DirectX.h"
11+
#include "llvm/Analysis/DXILResource.h"
12+
#include "llvm/IR/Dominators.h"
13+
#include "llvm/IR/IRBuilder.h"
14+
#include "llvm/IR/Instructions.h"
15+
#include "llvm/IR/IntrinsicInst.h"
16+
#include "llvm/IR/Intrinsics.h"
17+
#include "llvm/IR/IntrinsicsDirectX.h"
18+
#include "llvm/InitializePasses.h"
19+
20+
#define DEBUG_TYPE "dxil-resource-access"
21+
22+
using namespace llvm;
23+
24+
static void replaceTypedBufferAccess(IntrinsicInst *II,
25+
dxil::ResourceInfo &RI) {
26+
const DataLayout &DL = II->getDataLayout();
27+
28+
auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
29+
assert(HandleType->getName() == "dx.TypedBuffer" &&
30+
"Unexpected typed buffer type");
31+
Type *ContainedType = HandleType->getTypeParameter(0);
32+
Type *ScalarType = ContainedType->getScalarType();
33+
uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;
34+
int NumElements = ContainedType->getNumContainedTypes();
35+
if (!NumElements)
36+
NumElements = 1;
37+
38+
// Process users keeping track of indexing accumulated from GEPs.
39+
struct AccessAndIndex {
40+
User *Access;
41+
Value *Index;
42+
};
43+
SmallVector<AccessAndIndex> Worklist;
44+
for (User *U : II->users())
45+
Worklist.push_back({U, nullptr});
46+
47+
SmallVector<Instruction *> DeadInsts;
48+
while (!Worklist.empty()) {
49+
AccessAndIndex Current = Worklist.back();
50+
Worklist.pop_back();
51+
52+
if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
53+
IRBuilder<> Builder(GEP);
54+
55+
Value *Index;
56+
APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
57+
if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
58+
APInt Scaled = ConstantOffset.udiv(ScalarSize);
59+
Index = ConstantInt::get(Builder.getInt32Ty(), Scaled);
60+
} else {
61+
auto IndexIt = GEP->idx_begin();
62+
assert(cast<ConstantInt>(IndexIt)->getZExtValue() == 0 &&
63+
"GEP is not indexing through pointer");
64+
++IndexIt;
65+
Index = *IndexIt;
66+
assert(++IndexIt == GEP->idx_end() && "Too many indices in GEP");
67+
}
68+
69+
for (User *U : GEP->users())
70+
Worklist.push_back({U, Index});
71+
DeadInsts.push_back(GEP);
72+
73+
} else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
74+
assert(SI->getValueOperand() != II && "Pointer escaped!");
75+
IRBuilder<> Builder(SI);
76+
77+
Value *V = SI->getValueOperand();
78+
if (V->getType() == ContainedType) {
79+
// V is already the right type.
80+
} else if (V->getType() == ScalarType) {
81+
// We're storing a scalar, so we need to load the current value and only
82+
// replace the relevant part.
83+
auto *Load = Builder.CreateIntrinsic(
84+
ContainedType, Intrinsic::dx_typedBufferLoad,
85+
{II->getOperand(0), II->getOperand(1)});
86+
// If we have an offset from seeing a GEP earlier, use it.
87+
Value *IndexOp = Current.Index
88+
? Current.Index
89+
: ConstantInt::get(Builder.getInt32Ty(), 0);
90+
V = Builder.CreateInsertElement(Load, V, IndexOp);
91+
} else {
92+
llvm_unreachable("Store to typed resource has invalid type");
93+
}
94+
95+
auto *Inst = Builder.CreateIntrinsic(
96+
Builder.getVoidTy(), Intrinsic::dx_typedBufferStore,
97+
{II->getOperand(0), II->getOperand(1), V});
98+
SI->replaceAllUsesWith(Inst);
99+
DeadInsts.push_back(SI);
100+
101+
} else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
102+
IRBuilder<> Builder(LI);
103+
Value *V =
104+
Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
105+
{II->getOperand(0), II->getOperand(1)});
106+
if (Current.Index)
107+
V = Builder.CreateExtractElement(V, Current.Index);
108+
109+
LI->replaceAllUsesWith(V);
110+
DeadInsts.push_back(LI);
111+
112+
} else
113+
llvm_unreachable("Unhandled instruction - pointer escaped?");
114+
}
115+
116+
// Traverse the now-dead instructions in RPO and remove them.
117+
for (Instruction *Dead : llvm::reverse(DeadInsts))
118+
Dead->eraseFromParent();
119+
II->eraseFromParent();
120+
}
121+
122+
static bool transformResourcePointers(Function &F, DXILResourceMap &DRM) {
123+
// TODO: Should we have a more efficient way to find resources used in a
124+
// particular function?
125+
SmallVector<std::pair<IntrinsicInst *, dxil::ResourceInfo &>> Resources;
126+
for (BasicBlock &BB : F)
127+
for (Instruction &I : BB)
128+
if (auto *CI = dyn_cast<CallInst>(&I)) {
129+
auto It = DRM.find(CI);
130+
if (It == DRM.end())
131+
continue;
132+
for (User *U : CI->users())
133+
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U))
134+
if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer)
135+
Resources.emplace_back(II, *It);
136+
}
137+
138+
for (const auto &[II, RI] : Resources) {
139+
if (RI.isTyped())
140+
replaceTypedBufferAccess(II, RI);
141+
142+
// TODO: handle other resource types. We should probably have an
143+
// `unreachable` here once we've added support for all of them.
144+
}
145+
146+
return false;
147+
}
148+
149+
PreservedAnalyses DXILResourceAccess::run(Function &F,
150+
FunctionAnalysisManager &FAM) {
151+
auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
152+
DXILResourceMap *DRM =
153+
MAMProxy.getCachedResult<DXILResourceAnalysis>(*F.getParent());
154+
assert(DRM && "DXILResourceAnalysis must be available");
155+
156+
bool MadeChanges = transformResourcePointers(F, *DRM);
157+
if (!MadeChanges)
158+
return PreservedAnalyses::all();
159+
160+
PreservedAnalyses PA;
161+
PA.preserve<DXILResourceAnalysis>();
162+
PA.preserve<DominatorTreeAnalysis>();
163+
return PA;
164+
}
165+
166+
namespace {
167+
class DXILResourceAccessLegacy : public FunctionPass {
168+
public:
169+
bool runOnFunction(Function &F) override {
170+
DXILResourceMap &DRM =
171+
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
172+
173+
return transformResourcePointers(F, DRM);
174+
}
175+
StringRef getPassName() const override { return "DXIL Resource Access"; }
176+
DXILResourceAccessLegacy() : FunctionPass(ID) {}
177+
178+
static char ID; // Pass identification.
179+
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
180+
AU.addRequired<DXILResourceWrapperPass>();
181+
AU.addPreserved<DXILResourceWrapperPass>();
182+
AU.addPreserved<DominatorTreeWrapperPass>();
183+
}
184+
};
185+
char DXILResourceAccessLegacy::ID = 0;
186+
} // end anonymous namespace
187+
188+
INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
189+
"DXIL Resource Access", false, false)
190+
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
191+
INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
192+
"DXIL Resource Access", false, false)
193+
194+
FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
195+
return new DXILResourceAccessLegacy();
196+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file Pass for replacing pointers to DXIL resources with load and store
10+
// operations.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
15+
#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
16+
17+
#include "llvm/IR/PassManager.h"
18+
19+
namespace llvm {
20+
21+
class DXILResourceAccess: public PassInfoMixin<DXILResourceAccess> {
22+
public:
23+
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
24+
};
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H

llvm/lib/Target/DirectX/DirectX.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
1313

1414
namespace llvm {
15+
class FunctionPass;
1516
class ModulePass;
1617
class PassRegistry;
1718
class raw_ostream;
@@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
5253
/// Pass to lowering LLVM intrinsic call to DXIL op function call.
5354
ModulePass *createDXILOpLoweringLegacyPass();
5455

56+
/// Initializer for DXILResourceAccess
57+
void initializeDXILResourceAccessLegacyPass(PassRegistry &);
58+
59+
/// Pass to update resource accesses to use load/store directly.
60+
FunctionPass *createDXILResourceAccessLegacyPass();
61+
5562
/// Initializer for DXILTranslateMetadata.
5663
void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);
5764

llvm/lib/Target/DirectX/DirectXPassRegistry.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
3232
// TODO: rename to print<foo> after NPM switch
3333
MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
3434
#undef MODULE_PASS
35+
36+
#ifndef FUNCTION_PASS
37+
#define FUNCTION_PASS(NAME, CREATE_PASS)
38+
#endif
39+
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
40+
#undef FUNCTION_PASS

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "DXILIntrinsicExpansion.h"
1818
#include "DXILOpLowering.h"
1919
#include "DXILPrettyPrinter.h"
20+
#include "DXILResourceAccess.h"
2021
#include "DXILResourceAnalysis.h"
2122
#include "DXILShaderFlags.h"
2223
#include "DXILTranslateMetadata.h"
@@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
5657
initializeWriteDXILPassPass(*PR);
5758
initializeDXContainerGlobalsPass(*PR);
5859
initializeDXILOpLoweringLegacyPass(*PR);
60+
initializeDXILResourceAccessLegacyPass(*PR);
5961
initializeDXILTranslateMetadataLegacyPass(*PR);
6062
initializeDXILResourceMDWrapperPass(*PR);
6163
initializeShaderFlagsAnalysisWrapperPass(*PR);
@@ -92,9 +94,10 @@ class DirectXPassConfig : public TargetPassConfig {
9294
addPass(createDXILFinalizeLinkageLegacyPass());
9395
addPass(createDXILIntrinsicExpansionLegacyPass());
9496
addPass(createDXILDataScalarizationLegacyPass());
97+
addPass(createDXILFlattenArraysLegacyPass());
98+
addPass(createDXILResourceAccessLegacyPass());
9599
ScalarizerPassOptions DxilScalarOptions;
96100
DxilScalarOptions.ScalarizeLoadStore = true;
97-
addPass(createDXILFlattenArraysLegacyPass());
98101
addPass(createScalarizerPass(DxilScalarOptions));
99102
addPass(createDXILOpLoweringLegacyPass());
100103
addPass(createDXILTranslateMetadataLegacyPass());

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/ADT/PostOrderIterator.h"
1919
#include "llvm/ADT/SmallVector.h"
2020
#include "llvm/ADT/Twine.h"
21+
#include "llvm/Analysis/DXILResource.h"
2122
#include "llvm/Analysis/TargetTransformInfo.h"
2223
#include "llvm/Analysis/VectorUtils.h"
2324
#include "llvm/IR/Argument.h"
@@ -351,6 +352,7 @@ void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
351352
AU.addRequired<DominatorTreeWrapperPass>();
352353
AU.addRequired<TargetTransformInfoWrapperPass>();
353354
AU.addPreserved<DominatorTreeWrapperPass>();
355+
AU.addPreserved<DXILResourceWrapperPass>();
354356
}
355357

356358
char ScalarizerLegacyPass::ID = 0;
@@ -1348,5 +1350,6 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM)
13481350
bool Changed = Impl.visit(F);
13491351
PreservedAnalyses PA;
13501352
PA.preserve<DominatorTreeAnalysis>();
1353+
PA.preserve<DXILResourceAnalysis>();
13511354
return Changed ? PA : PreservedAnalyses::all();
13521355
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; RUN: opt -S -dxil-resource-access %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.6-compute"
4+
5+
declare void @use_float4(<4 x float>)
6+
declare void @use_float(<4 x float>)
7+
8+
; CHECK-LABEL: define void @load_float4
9+
define void @load_float4(i32 %index, i32 %elemindex) {
10+
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
11+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
12+
i32 0, i32 0, i32 1, i32 0, i1 false)
13+
14+
; CHECK-NOT: @llvm.dx.resource.getpointer
15+
%ptr = call ptr @llvm.dx.resource.getpointer(
16+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
17+
18+
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
19+
%vec_data = load <4 x float>, ptr %ptr
20+
call void @use_float4(<4 x float> %vec_data)
21+
22+
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
23+
; CHECK: extractelement <4 x float> %[[VALUE]], i32 4
24+
%y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 4
25+
%y_data = load float, ptr %y_ptr
26+
call void @use_float(float %y_data)
27+
28+
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
29+
; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
30+
%dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
31+
%dyndata = load float, ptr %dynamic
32+
call void @use_float(float %dyndata)
33+
34+
ret void
35+
}

0 commit comments

Comments
 (0)