Skip to content

Commit c0d8b67

Browse files
committed
[DirectX] Lower @llvm.dx.handle.fromBinding to DXIL ops
The `@llvm.dx.handle.fromBinding` intrinsic is lowered either to the `CreateHandle` op or a pair of `CreateHandleFromBinding` and `AnnotateHandle` ops, depending on the DXIL version. Regardless of the DXIL version we need to emit metadata about the binding, but that's left to a separate change. These DXIL ops all need to return the `%dx.types.Handle` type, but the llvm intrinsic returns a target extension type. To facilitate changing the type of the operation and all of its users, we introduce `%llvm.dx.cast.handle`, which can cast between the two handle representations. Pull Request: llvm#104251
1 parent c76bc28 commit c0d8b67

File tree

8 files changed

+351
-7
lines changed

8 files changed

+351
-7
lines changed

llvm/include/llvm/Analysis/DXILResource.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class TargetExtType;
2323
namespace dxil {
2424

2525
class ResourceInfo {
26+
public:
2627
struct ResourceBinding {
2728
uint32_t RecordID;
2829
uint32_t Space;
@@ -89,6 +90,7 @@ class ResourceInfo {
8990
bool operator!=(const FeedbackInfo &RHS) const { return !(*this == RHS); }
9091
};
9192

93+
private:
9294
// Universal properties.
9395
Value *Symbol;
9496
StringRef Name;
@@ -115,6 +117,10 @@ class ResourceInfo {
115117

116118
MSInfo MultiSample;
117119

120+
// We need a default constructor if we want to insert this in a MapVector.
121+
ResourceInfo() {}
122+
friend class MapVector<CallInst *, ResourceInfo>;
123+
118124
public:
119125
ResourceInfo(dxil::ResourceClass RC, dxil::ResourceKind Kind, Value *Symbol,
120126
StringRef Name)
@@ -166,6 +172,8 @@ class ResourceInfo {
166172
MultiSample.Count = Count;
167173
}
168174

175+
dxil::ResourceClass getResourceClass() const { return RC; }
176+
169177
bool operator==(const ResourceInfo &RHS) const;
170178

171179
static ResourceInfo SRV(Value *Symbol, StringRef Name,

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def int_dx_handle_fromBinding
3030
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
3131
[IntrNoMem]>;
3232

33+
// Cast between target extension handle types and dxil-style opaque handles
34+
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
35+
3336
def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
3437
def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
3538
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def FloatTy : DXILOpParamType;
4242
def DoubleTy : DXILOpParamType;
4343
def ResRetTy : DXILOpParamType;
4444
def HandleTy : DXILOpParamType;
45+
def ResBindTy : DXILOpParamType;
46+
def ResPropsTy : DXILOpParamType;
4547

4648
class DXILOpClass;
4749

@@ -673,6 +675,14 @@ def Dot4 : DXILOp<56, dot4> {
673675
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
674676
}
675677

678+
def CreateHandle : DXILOp<57, createHandle> {
679+
let Doc = "creates the handle to a resource";
680+
// ResourceClass, RangeID, Index, NonUniform
681+
let arguments = [Int8Ty, Int32Ty, Int32Ty, Int1Ty];
682+
let result = HandleTy;
683+
let stages = [Stages<DXIL1_0, [all_stages]>];
684+
}
685+
676686
def ThreadId : DXILOp<93, threadId> {
677687
let Doc = "Reads the thread ID";
678688
let LLVMIntrinsic = int_dx_thread_id;
@@ -712,3 +722,17 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
712722
let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
713723
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
714724
}
725+
726+
def AnnotateHandle : DXILOp<217, annotateHandle> {
727+
let Doc = "annotate handle with resource properties";
728+
let arguments = [HandleTy, ResPropsTy];
729+
let result = HandleTy;
730+
let stages = [Stages<DXIL1_6, [all_stages]>];
731+
}
732+
733+
def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
734+
let Doc = "create resource handle from binding";
735+
let arguments = [ResBindTy, Int32Ty, Int1Ty];
736+
let result = HandleTy;
737+
let stages = [Stages<DXIL1_6, [all_stages]>];
738+
}

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,23 @@ static StructType *getHandleType(LLVMContext &Ctx) {
208208
Ctx);
209209
}
210210

211+
static StructType *getResBindType(LLVMContext &Context) {
212+
if (auto *ST = StructType::getTypeByName(Context, "dx.types.ResBind"))
213+
return ST;
214+
Type *Int32Ty = Type::getInt32Ty(Context);
215+
Type *Int8Ty = Type::getInt8Ty(Context);
216+
return StructType::create({Int32Ty, Int32Ty, Int32Ty, Int8Ty},
217+
"dx.types.ResBind");
218+
}
219+
220+
static StructType *getResPropsType(LLVMContext &Context) {
221+
if (auto *ST =
222+
StructType::getTypeByName(Context, "dx.types.ResourceProperties"))
223+
return ST;
224+
Type *Int32Ty = Type::getInt32Ty(Context);
225+
return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
226+
}
227+
211228
static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
212229
Type *OverloadTy) {
213230
switch (Kind) {
@@ -235,6 +252,10 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
235252
return getResRetType(OverloadTy, Ctx);
236253
case OpParamType::HandleTy:
237254
return getHandleType(Ctx);
255+
case OpParamType::ResBindTy:
256+
return getResBindType(Ctx);
257+
case OpParamType::ResPropsTy:
258+
return getResPropsType(Ctx);
238259
}
239260
llvm_unreachable("Invalid parameter kind");
240261
return nullptr;
@@ -430,6 +451,29 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
430451
return *Result;
431452
}
432453

454+
StructType *DXILOpBuilder::getHandleType() {
455+
return ::getHandleType(IRB.getContext());
456+
}
457+
458+
Constant *DXILOpBuilder::getResBind(uint32_t LowerBound, uint32_t UpperBound,
459+
uint32_t SpaceID, dxil::ResourceClass RC) {
460+
Type *Int32Ty = IRB.getInt32Ty();
461+
Type *Int8Ty = IRB.getInt8Ty();
462+
return ConstantStruct::get(
463+
getResBindType(IRB.getContext()),
464+
{ConstantInt::get(Int32Ty, LowerBound),
465+
ConstantInt::get(Int32Ty, UpperBound),
466+
ConstantInt::get(Int32Ty, SpaceID),
467+
ConstantInt::get(Int8Ty, llvm::to_underlying(RC))});
468+
}
469+
470+
Constant *DXILOpBuilder::getResProps(uint32_t Word0, uint32_t Word1) {
471+
Type *Int32Ty = IRB.getInt32Ty();
472+
return ConstantStruct::get(
473+
getResPropsType(IRB.getContext()),
474+
{ConstantInt::get(Int32Ty, Word0), ConstantInt::get(Int32Ty, Word1)});
475+
}
476+
433477
const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
434478
return ::getOpCodeName(DXILOp);
435479
}

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
#include "DXILConstants.h"
1616
#include "llvm/ADT/SmallVector.h"
1717
#include "llvm/IR/IRBuilder.h"
18+
#include "llvm/Support/DXILABI.h"
1819
#include "llvm/Support/Error.h"
1920
#include "llvm/TargetParser/Triple.h"
2021

2122
namespace llvm {
2223
class Module;
2324
class IRBuilderBase;
2425
class CallInst;
26+
class Constant;
2527
class Value;
2628
class Type;
2729
class FunctionType;
@@ -44,6 +46,15 @@ class DXILOpBuilder {
4446
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
4547
Type *RetTy = nullptr);
4648

49+
/// Get the `%dx.types.Handle` type.
50+
StructType *getHandleType();
51+
52+
/// Get a constant `%dx.types.ResBind` value.
53+
Constant *getResBind(uint32_t LowerBound, uint32_t UpperBound,
54+
uint32_t SpaceID, dxil::ResourceClass RC);
55+
/// Get a constant `%dx.types.ResourceProperties` value.
56+
Constant *getResProps(uint32_t Word0, uint32_t Word1);
57+
4758
/// Return the name of the given opcode.
4859
static const char *getOpCodeName(dxil::OpCode DXILOp);
4960

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 139 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "DXILOpBuilder.h"
1313
#include "DirectX.h"
1414
#include "llvm/ADT/SmallVector.h"
15+
#include "llvm/Analysis/DXILResource.h"
1516
#include "llvm/CodeGen/Passes.h"
1617
#include "llvm/IR/DiagnosticInfo.h"
1718
#include "llvm/IR/IRBuilder.h"
@@ -20,6 +21,7 @@
2021
#include "llvm/IR/IntrinsicsDirectX.h"
2122
#include "llvm/IR/Module.h"
2223
#include "llvm/IR/PassManager.h"
24+
#include "llvm/InitializePasses.h"
2325
#include "llvm/Pass.h"
2426
#include "llvm/Support/ErrorHandling.h"
2527

@@ -74,9 +76,11 @@ namespace {
7476
class OpLowerer {
7577
Module &M;
7678
DXILOpBuilder OpBuilder;
79+
DXILResourceMap &DRM;
80+
SmallVector<CallInst *> CleanupCasts;
7781

7882
public:
79-
OpLowerer(Module &M) : M(M), OpBuilder(M) {}
83+
OpLowerer(Module &M, DXILResourceMap &DRM) : M(M), OpBuilder(M), DRM(DRM) {}
8084

8185
void replaceFunction(Function &F,
8286
llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
@@ -119,6 +123,119 @@ class OpLowerer {
119123
});
120124
}
121125

126+
Value *createTmpHandleCast(Value *V, Type *Ty) {
127+
Function *CastFn = Intrinsic::getDeclaration(&M, Intrinsic::dx_cast_handle,
128+
{Ty, V->getType()});
129+
CallInst *Cast = OpBuilder.getIRB().CreateCall(CastFn, {V});
130+
CleanupCasts.push_back(Cast);
131+
return Cast;
132+
}
133+
134+
void cleanupHandleCasts() {
135+
SmallVector<CallInst *> ToRemove;
136+
SmallVector<Function *> CastFns;
137+
138+
for (CallInst *Cast : CleanupCasts) {
139+
CastFns.push_back(Cast->getCalledFunction());
140+
// All of the ops should be using `dx.types.Handle` at this point, so if
141+
// we're not producing that we should be part of a pair. Track this so we
142+
// can remove it at the end.
143+
if (Cast->getType() != OpBuilder.getHandleType()) {
144+
ToRemove.push_back(Cast);
145+
continue;
146+
}
147+
// Otherwise, we're the second handle in a pair. Forward the arguments and
148+
// remove the (second) cast.
149+
CallInst *Def = cast<CallInst>(Cast->getOperand(0));
150+
assert(Def->getIntrinsicID() == Intrinsic::dx_cast_handle &&
151+
"Unbalanced pair of temporary handle casts");
152+
Cast->replaceAllUsesWith(Def->getOperand(0));
153+
Cast->eraseFromParent();
154+
}
155+
for (CallInst *Cast : ToRemove) {
156+
assert(Cast->user_empty() && "Temporary handle cast still has users");
157+
Cast->eraseFromParent();
158+
}
159+
llvm::sort(CastFns);
160+
CastFns.erase(llvm::unique(CastFns), CastFns.end());
161+
for (Function *F : CastFns)
162+
F->eraseFromParent();
163+
164+
CleanupCasts.clear();
165+
}
166+
167+
void lowerToCreateHandle(Function &F) {
168+
IRBuilder<> &IRB = OpBuilder.getIRB();
169+
Type *Int8Ty = IRB.getInt8Ty();
170+
Type *Int32Ty = IRB.getInt32Ty();
171+
172+
replaceFunction(F, [&](CallInst *CI) -> Error {
173+
IRB.SetInsertPoint(CI);
174+
175+
dxil::ResourceInfo &RI = DRM[CI];
176+
dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding();
177+
178+
std::array<Value *, 4> Args{
179+
ConstantInt::get(Int8Ty, llvm::to_underlying(RI.getResourceClass())),
180+
ConstantInt::get(Int32Ty, Binding.RecordID), CI->getArgOperand(3),
181+
CI->getArgOperand(4)};
182+
Expected<CallInst *> OpCall =
183+
OpBuilder.tryCreateOp(OpCode::CreateHandle, Args);
184+
if (Error E = OpCall.takeError())
185+
return E;
186+
187+
Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
188+
189+
CI->replaceAllUsesWith(Cast);
190+
CI->eraseFromParent();
191+
return Error::success();
192+
});
193+
}
194+
195+
void lowerToBindAndAnnotateHandle(Function &F) {
196+
IRBuilder<> &IRB = OpBuilder.getIRB();
197+
198+
replaceFunction(F, [&](CallInst *CI) -> Error {
199+
IRB.SetInsertPoint(CI);
200+
201+
dxil::ResourceInfo &RI = DRM[CI];
202+
dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding();
203+
std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps();
204+
205+
Constant *ResBind = OpBuilder.getResBind(
206+
Binding.LowerBound, Binding.LowerBound + Binding.Size - 1,
207+
Binding.Space, RI.getResourceClass());
208+
std::array<Value *, 3> BindArgs{ResBind, CI->getArgOperand(3),
209+
CI->getArgOperand(4)};
210+
Expected<CallInst *> OpBind =
211+
OpBuilder.tryCreateOp(OpCode::CreateHandleFromBinding, BindArgs);
212+
if (Error E = OpBind.takeError())
213+
return E;
214+
215+
std::array<Value *, 2> AnnotateArgs{
216+
*OpBind, OpBuilder.getResProps(Props.first, Props.second)};
217+
Expected<CallInst *> OpAnnotate =
218+
OpBuilder.tryCreateOp(OpCode::AnnotateHandle, AnnotateArgs);
219+
if (Error E = OpAnnotate.takeError())
220+
return E;
221+
222+
Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
223+
224+
CI->replaceAllUsesWith(Cast);
225+
CI->eraseFromParent();
226+
227+
return Error::success();
228+
});
229+
}
230+
231+
void lowerHandleFromBinding(Function &F) {
232+
Triple TT(Triple(M.getTargetTriple()));
233+
if (TT.getDXILVersion() < VersionTuple(1, 6))
234+
lowerToCreateHandle(F);
235+
else
236+
lowerToBindAndAnnotateHandle(F);
237+
}
238+
122239
bool lowerIntrinsics() {
123240
bool Updated = false;
124241

@@ -134,40 +251,55 @@ class OpLowerer {
134251
replaceFunctionWithOp(F, OpCode); \
135252
break;
136253
#include "DXILOperation.inc"
254+
case Intrinsic::dx_handle_fromBinding:
255+
lowerHandleFromBinding(F);
137256
}
138257
Updated = true;
139258
}
259+
if (Updated)
260+
cleanupHandleCasts();
261+
140262
return Updated;
141263
}
142264
};
143265
} // namespace
144266

145-
PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &) {
146-
if (OpLowerer(M).lowerIntrinsics())
147-
return PreservedAnalyses::none();
148-
return PreservedAnalyses::all();
267+
PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
268+
DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
269+
270+
bool MadeChanges = OpLowerer(M, DRM).lowerIntrinsics();
271+
if (!MadeChanges)
272+
return PreservedAnalyses::all();
273+
PreservedAnalyses PA;
274+
PA.preserve<DXILResourceAnalysis>();
275+
return PA;
149276
}
150277

151278
namespace {
152279
class DXILOpLoweringLegacy : public ModulePass {
153280
public:
154281
bool runOnModule(Module &M) override {
155-
return OpLowerer(M).lowerIntrinsics();
282+
DXILResourceMap &DRM =
283+
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
284+
285+
return OpLowerer(M, DRM).lowerIntrinsics();
156286
}
157287
StringRef getPassName() const override { return "DXIL Op Lowering"; }
158288
DXILOpLoweringLegacy() : ModulePass(ID) {}
159289

160290
static char ID; // Pass identification.
161291
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
162-
// Specify the passes that your pass depends on
163292
AU.addRequired<DXILIntrinsicExpansionLegacy>();
293+
AU.addRequired<DXILResourceWrapperPass>();
294+
AU.addPreserved<DXILResourceWrapperPass>();
164295
}
165296
};
166297
char DXILOpLoweringLegacy::ID = 0;
167298
} // end anonymous namespace
168299

169300
INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
170301
false, false)
302+
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
171303
INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
172304
false)
173305

0 commit comments

Comments
 (0)