Skip to content

Commit 0fca76d

Browse files
authored
[DirectX] Introduce the DXILResourceAccess pass (#116726)
This pass transforms resource access via `llvm.dx.resource.getpointer` into buffer loads and stores. Fixes #114848.
1 parent 21de514 commit 0fca76d

File tree

10 files changed

+389
-2
lines changed

10 files changed

+389
-2
lines changed

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
3030
DXILPrettyPrinter.cpp
3131
DXILResource.cpp
3232
DXILResourceAnalysis.cpp
33+
DXILResourceAccess.cpp
3334
DXILShaderFlags.cpp
3435
DXILTranslateMetadata.cpp
3536

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,14 @@ class OpLowerer {
554554
});
555555
}
556556

557+
/// Removes the declaration of `llvm.dx.resource.getpointer`.
///
/// Every call is expected to have been rewritten by DXILResourceAccess before
/// op lowering runs, so all that is left to do here is drop the unused
/// prototype.
///
/// \returns false; the result is OR'd into the caller's error flag, and this
///          lowering does not report an error.
[[nodiscard]] bool lowerGetPointer(Function &F) {
  assert(F.user_empty() && "getpointer operations should have been removed");
  F.eraseFromParent();
  return false;
}
564+
557565
[[nodiscard]] bool lowerTypedBufferStore(Function &F) {
558566
IRBuilder<> &IRB = OpBuilder.getIRB();
559567
Type *Int8Ty = IRB.getInt8Ty();
@@ -707,6 +715,9 @@ class OpLowerer {
707715
case Intrinsic::dx_handle_fromBinding:
708716
HasErrors |= lowerHandleFromBinding(F);
709717
break;
718+
case Intrinsic::dx_resource_getpointer:
719+
HasErrors |= lowerGetPointer(F);
720+
break;
710721
case Intrinsic::dx_typedBufferLoad:
711722
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
712723
break;
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "DXILResourceAccess.h"
10+
#include "DirectX.h"
11+
#include "llvm/Analysis/DXILResource.h"
12+
#include "llvm/IR/Dominators.h"
13+
#include "llvm/IR/IRBuilder.h"
14+
#include "llvm/IR/Instructions.h"
15+
#include "llvm/IR/IntrinsicInst.h"
16+
#include "llvm/IR/Intrinsics.h"
17+
#include "llvm/IR/IntrinsicsDirectX.h"
18+
#include "llvm/InitializePasses.h"
19+
20+
#define DEBUG_TYPE "dxil-resource-access"
21+
22+
using namespace llvm;
23+
24+
/// Lowers every access made through one `llvm.dx.resource.getpointer` call on
/// a typed buffer.
///
/// Walks the users of \p II, looking through a single level of GEP, and
/// replaces each load with `llvm.dx.typedBufferLoad` and each store with
/// `llvm.dx.typedBufferStore`, both applied to the handle and index that were
/// passed to \p II. An access to a single element becomes an extractelement
/// (load) or a load / insertelement / store sequence (store). The replaced
/// instructions and \p II itself are erased.
///
/// Unsupported shapes (GEP-of-GEP, a GEP that is not `0, <element>`, any user
/// other than a GEP, load or store) are programming errors and are diagnosed
/// with assertions / llvm_unreachable only.
///
/// \param II  Call to `llvm.dx.resource.getpointer`; operand 0 is the resource
///            handle and operand 1 the index into the resource.
/// \param RTI Type information for the resource. Not consulted here.
static void replaceTypedBufferAccess(IntrinsicInst *II,
                                     dxil::ResourceTypeInfo &RTI) {
  const DataLayout &DL = II->getDataLayout();

  auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
  assert(HandleType->getName() == "dx.TypedBuffer" &&
         "Unexpected typed buffer type");
  Type *ContainedType = HandleType->getTypeParameter(0);

  // We need the size of an element in bytes so that we can calculate the offset
  // in elements given a total offset in bytes later.
  Type *ScalarType = ContainedType->getScalarType();
  uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;
  // ScalarSize is used as a divisor below, so a sub-byte scalar would divide
  // by zero.
  assert(ScalarSize != 0 && "Typed buffer element is smaller than a byte");

  // Process users keeping track of the element index contributed by a GEP.
  struct AccessAndIndex {
    User *Access;
    Value *Index; // nullptr when the access goes through the pointer directly.
  };
  SmallVector<AccessAndIndex> Worklist;
  for (User *U : II->users())
    Worklist.push_back({U, nullptr});

  SmallVector<Instruction *> DeadInsts;
  while (!Worklist.empty()) {
    AccessAndIndex Current = Worklist.back();
    Worklist.pop_back();

    if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
      // The index computed below replaces, rather than adds to, any index
      // from an enclosing GEP, so a chain of GEPs would silently lose part of
      // the offset. Reject that shape instead of miscompiling it.
      assert(!Current.Index &&
             "Chained GEPs on a resource pointer are not handled");
      IRBuilder<> Builder(GEP);

      Value *Index;
      APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
      if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
        // Convert the byte offset into an element index. The offset has the
        // data layout's index width, which need not be 32 bits, so resize it
        // explicitly before building the i32 constant.
        APInt Scaled = ConstantOffset.udiv(ScalarSize);
        Index = ConstantInt::get(Builder.getInt32Ty(), Scaled.zextOrTrunc(32));
      } else {
        // Only `gep <ty>, ptr, 0, <element>` is understood. Check the index
        // count before touching the second index, and keep the check free of
        // side effects.
        assert(GEP->getNumIndices() == 2 &&
               "GEP must have exactly two indices");
        auto IndexIt = GEP->idx_begin();
        assert(cast<ConstantInt>(IndexIt->get())->getZExtValue() == 0 &&
               "GEP is not indexing through pointer");
        ++IndexIt;
        Index = *IndexIt;
      }

      for (User *U : GEP->users())
        Worklist.push_back({U, Index});
      DeadInsts.push_back(GEP);

    } else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
      assert(SI->getValueOperand() != II && "Pointer escaped!");
      IRBuilder<> Builder(SI);

      Value *V = SI->getValueOperand();
      if (V->getType() == ContainedType) {
        // V is already the right type.
      } else if (V->getType() == ScalarType) {
        // We're storing a scalar, so we need to load the current value and only
        // replace the relevant part.
        auto *Load = Builder.CreateIntrinsic(
            ContainedType, Intrinsic::dx_typedBufferLoad,
            {II->getOperand(0), II->getOperand(1)});
        // If we have an offset from seeing a GEP earlier, use it.
        Value *IndexOp = Current.Index
                             ? Current.Index
                             : ConstantInt::get(Builder.getInt32Ty(), 0);
        V = Builder.CreateInsertElement(Load, V, IndexOp);
      } else {
        llvm_unreachable("Store to typed resource has invalid type");
      }

      // A store produces no value, so there are no uses to redirect; emitting
      // the replacement and queueing the store for removal is sufficient.
      Builder.CreateIntrinsic(Builder.getVoidTy(),
                              Intrinsic::dx_typedBufferStore,
                              {II->getOperand(0), II->getOperand(1), V});
      DeadInsts.push_back(SI);

    } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
      IRBuilder<> Builder(LI);
      Value *V =
          Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
                                  {II->getOperand(0), II->getOperand(1)});
      // A GEP selected a single element: pull it out of the loaded value.
      if (Current.Index)
        V = Builder.CreateExtractElement(V, Current.Index);

      LI->replaceAllUsesWith(V);
      DeadInsts.push_back(LI);

    } else
      llvm_unreachable("Unhandled instruction - pointer escaped?");
  }

  // Erase in reverse discovery order: a GEP is recorded before the loads and
  // stores that use it, so walking backwards removes users first.
  for (Instruction *Dead : llvm::reverse(DeadInsts))
    Dead->eraseFromParent();
  II->eraseFromParent();
}
121+
122+
/// Finds every `llvm.dx.resource.getpointer` call in \p F and rewrites the
/// accesses made through it.
///
/// \param F    Function to transform.
/// \param DRTM Map from resource handle types to their type information.
/// \returns true if at least one getpointer call was rewritten.
static bool transformResourcePointers(Function &F, DXILResourceTypeMap &DRTM) {
  // Collect the calls up front: the rewrite erases instructions, which must
  // not happen while the blocks are still being iterated.
  SmallVector<std::pair<IntrinsicInst *, dxil::ResourceTypeInfo>> Resources;
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      auto *Call = dyn_cast<IntrinsicInst>(&I);
      if (!Call || Call->getIntrinsicID() != Intrinsic::dx_resource_getpointer)
        continue;
      auto *HandleTy = cast<TargetExtType>(Call->getArgOperand(0)->getType());
      Resources.emplace_back(Call, DRTM[HandleTy]);
    }
  }

  bool Changed = false;
  for (auto &[Call, TypeInfo] : Resources) {
    // TODO: handle other resource types. We should probably have an
    // `unreachable` here once we've added support for all of them.
    if (!TypeInfo.isTyped())
      continue;

    replaceTypedBufferAccess(Call, TypeInfo);
    Changed = true;
  }

  return Changed;
}
145+
146+
/// New pass manager entry point.
///
/// Requires a cached DXILResourceTypeAnalysis result for the enclosing
/// module. When the function is modified, only DXILResourceTypeAnalysis and
/// DominatorTreeAnalysis are reported as preserved.
PreservedAnalyses DXILResourceAccess::run(Function &F,
                                          FunctionAnalysisManager &FAM) {
  auto &ModuleProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
  DXILResourceTypeMap *TypeMap =
      ModuleProxy.getCachedResult<DXILResourceTypeAnalysis>(*F.getParent());
  assert(TypeMap && "DXILResourceTypeAnalysis must be available");

  // Nothing rewritten means nothing invalidated.
  if (!transformResourcePointers(F, *TypeMap))
    return PreservedAnalyses::all();

  PreservedAnalyses Preserved;
  Preserved.preserve<DXILResourceTypeAnalysis>();
  Preserved.preserve<DominatorTreeAnalysis>();
  return Preserved;
}
162+
163+
namespace {
164+
class DXILResourceAccessLegacy : public FunctionPass {
165+
public:
166+
bool runOnFunction(Function &F) override {
167+
DXILResourceTypeMap &DRTM =
168+
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
169+
170+
return transformResourcePointers(F, DRTM);
171+
}
172+
StringRef getPassName() const override { return "DXIL Resource Access"; }
173+
DXILResourceAccessLegacy() : FunctionPass(ID) {}
174+
175+
static char ID; // Pass identification.
176+
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
177+
AU.addRequired<DXILResourceTypeWrapperPass>();
178+
AU.addPreserved<DominatorTreeWrapperPass>();
179+
}
180+
};
181+
char DXILResourceAccessLegacy::ID = 0;
182+
} // end anonymous namespace
183+
184+
// Legacy pass registration. The pass is registered under DEBUG_TYPE and
// depends on DXILResourceTypeWrapperPass.
INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
                      "DXIL Resource Access", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
                    "DXIL Resource Access", false, false)
189+
190+
/// Factory for the legacy pass manager version of the pass. The returned
/// object is heap-allocated with `new`.
FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
  return new DXILResourceAccessLegacy();
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file Pass for replacing pointers to DXIL resources with load and store
10+
// operations.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
15+
#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
16+
17+
#include "llvm/IR/PassManager.h"
18+
19+
namespace llvm {
20+
21+
/// New pass manager function pass that replaces pointers to DXIL resources
/// with load and store operations.
class DXILResourceAccess : public PassInfoMixin<DXILResourceAccess> {
public:
  /// Runs the transformation on \p F and returns the set of analyses that
  /// remain valid afterwards.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H

llvm/lib/Target/DirectX/DirectX.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
1313

1414
namespace llvm {
15+
class FunctionPass;
1516
class ModulePass;
1617
class PassRegistry;
1718
class raw_ostream;
@@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
5253
/// Pass to lower LLVM intrinsic calls to DXIL op function calls.
5354
ModulePass *createDXILOpLoweringLegacyPass();
5455

56+
/// Initializer for DXILResourceAccess
57+
void initializeDXILResourceAccessLegacyPass(PassRegistry &);
58+
59+
/// Pass to update resource accesses to use load/store directly.
60+
FunctionPass *createDXILResourceAccessLegacyPass();
61+
5562
/// Initializer for DXILTranslateMetadata.
5663
void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);
5764

llvm/lib/Target/DirectX/DirectXPassRegistry.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
3232
// TODO: rename to print<foo> after NPM switch
3333
MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
3434
#undef MODULE_PASS
35+
36+
#ifndef FUNCTION_PASS
37+
#define FUNCTION_PASS(NAME, CREATE_PASS)
38+
#endif
39+
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
40+
#undef FUNCTION_PASS

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "DXILIntrinsicExpansion.h"
1818
#include "DXILOpLowering.h"
1919
#include "DXILPrettyPrinter.h"
20+
#include "DXILResourceAccess.h"
2021
#include "DXILResourceAnalysis.h"
2122
#include "DXILShaderFlags.h"
2223
#include "DXILTranslateMetadata.h"
@@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
5657
initializeWriteDXILPassPass(*PR);
5758
initializeDXContainerGlobalsPass(*PR);
5859
initializeDXILOpLoweringLegacyPass(*PR);
60+
initializeDXILResourceAccessLegacyPass(*PR);
5961
initializeDXILTranslateMetadataLegacyPass(*PR);
6062
initializeDXILResourceMDWrapperPass(*PR);
6163
initializeShaderFlagsAnalysisWrapperPass(*PR);
@@ -92,9 +94,10 @@ class DirectXPassConfig : public TargetPassConfig {
9294
addPass(createDXILFinalizeLinkageLegacyPass());
9395
addPass(createDXILIntrinsicExpansionLegacyPass());
9496
addPass(createDXILDataScalarizationLegacyPass());
97+
addPass(createDXILFlattenArraysLegacyPass());
98+
addPass(createDXILResourceAccessLegacyPass());
9599
ScalarizerPassOptions DxilScalarOptions;
96100
DxilScalarOptions.ScalarizeLoadStore = true;
97-
addPass(createDXILFlattenArraysLegacyPass());
98101
addPass(createScalarizerPass(DxilScalarOptions));
99102
addPass(createDXILOpLoweringLegacyPass());
100103
addPass(createDXILTranslateMetadataLegacyPass());
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; RUN: opt -S -dxil-resource-access %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.6-compute"
4+
5+
declare void @use_float4(<4 x float>)
6+
declare void @use_float(float)
7+
8+
; CHECK-LABEL: define void @load_float4
9+
define void @load_float4(i32 %index, i32 %elemindex) {
10+
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
11+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
12+
i32 0, i32 0, i32 1, i32 0, i1 false)
13+
14+
; CHECK-NOT: @llvm.dx.resource.getpointer
15+
%ptr = call ptr @llvm.dx.resource.getpointer(
16+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
17+
18+
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
19+
%vec_data = load <4 x float>, ptr %ptr
20+
call void @use_float4(<4 x float> %vec_data)
21+
22+
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
23+
; CHECK: extractelement <4 x float> %[[VALUE]], i32 1
24+
%y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 1
25+
%y_data = load float, ptr %y_ptr
26+
call void @use_float(float %y_data)
27+
28+
; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
29+
; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
30+
%dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
31+
%dyndata = load float, ptr %dynamic
32+
call void @use_float(float %dyndata)
33+
34+
ret void
35+
}

0 commit comments

Comments
 (0)