12
12
#include " DXILOpBuilder.h"
13
13
#include " DirectX.h"
14
14
#include " llvm/ADT/SmallVector.h"
15
+ #include " llvm/Analysis/DXILResource.h"
15
16
#include " llvm/CodeGen/Passes.h"
16
17
#include " llvm/IR/DiagnosticInfo.h"
17
18
#include " llvm/IR/IRBuilder.h"
20
21
#include " llvm/IR/IntrinsicsDirectX.h"
21
22
#include " llvm/IR/Module.h"
22
23
#include " llvm/IR/PassManager.h"
24
+ #include " llvm/InitializePasses.h"
23
25
#include " llvm/Pass.h"
24
26
#include " llvm/Support/ErrorHandling.h"
25
27
@@ -74,9 +76,11 @@ namespace {
74
76
class OpLowerer {
75
77
Module &M;
76
78
DXILOpBuilder OpBuilder;
79
+ DXILResourceMap &DRM;
80
+ SmallVector<CallInst *> CleanupCasts;
77
81
78
82
public:
79
- OpLowerer (Module &M) : M(M), OpBuilder(M) {}
83
+ OpLowerer (Module &M, DXILResourceMap &DRM ) : M(M), OpBuilder(M), DRM(DRM ) {}
80
84
81
85
void replaceFunction (Function &F,
82
86
llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
@@ -119,6 +123,119 @@ class OpLowerer {
119
123
});
120
124
}
121
125
126
+ Value *createTmpHandleCast (Value *V, Type *Ty) {
127
+ Function *CastFn = Intrinsic::getDeclaration (&M, Intrinsic::dx_cast_handle,
128
+ {Ty, V->getType ()});
129
+ CallInst *Cast = OpBuilder.getIRB ().CreateCall (CastFn, {V});
130
+ CleanupCasts.push_back (Cast);
131
+ return Cast;
132
+ }
133
+
134
+ void cleanupHandleCasts () {
135
+ SmallVector<CallInst *> ToRemove;
136
+ SmallVector<Function *> CastFns;
137
+
138
+ for (CallInst *Cast : CleanupCasts) {
139
+ CastFns.push_back (Cast->getCalledFunction ());
140
+ // All of the ops should be using `dx.types.Handle` at this point, so if
141
+ // we're not producing that we should be part of a pair. Track this so we
142
+ // can remove it at the end.
143
+ if (Cast->getType () != OpBuilder.getHandleType ()) {
144
+ ToRemove.push_back (Cast);
145
+ continue ;
146
+ }
147
+ // Otherwise, we're the second handle in a pair. Forward the arguments and
148
+ // remove the (second) cast.
149
+ CallInst *Def = cast<CallInst>(Cast->getOperand (0 ));
150
+ assert (Def->getIntrinsicID () == Intrinsic::dx_cast_handle &&
151
+ " Unbalanced pair of temporary handle casts" );
152
+ Cast->replaceAllUsesWith (Def->getOperand (0 ));
153
+ Cast->eraseFromParent ();
154
+ }
155
+ for (CallInst *Cast : ToRemove) {
156
+ assert (Cast->user_empty () && " Temporary handle cast still has users" );
157
+ Cast->eraseFromParent ();
158
+ }
159
+ llvm::sort (CastFns);
160
+ CastFns.erase (llvm::unique (CastFns), CastFns.end ());
161
+ for (Function *F : CastFns)
162
+ F->eraseFromParent ();
163
+
164
+ CleanupCasts.clear ();
165
+ }
166
+
167
+ void lowerToCreateHandle (Function &F) {
168
+ IRBuilder<> &IRB = OpBuilder.getIRB ();
169
+ Type *Int8Ty = IRB.getInt8Ty ();
170
+ Type *Int32Ty = IRB.getInt32Ty ();
171
+
172
+ replaceFunction (F, [&](CallInst *CI) -> Error {
173
+ IRB.SetInsertPoint (CI);
174
+
175
+ dxil::ResourceInfo &RI = DRM[CI];
176
+ dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding ();
177
+
178
+ std::array<Value *, 4 > Args{
179
+ ConstantInt::get (Int8Ty, llvm::to_underlying (RI.getResourceClass ())),
180
+ ConstantInt::get (Int32Ty, Binding.RecordID ), CI->getArgOperand (3 ),
181
+ CI->getArgOperand (4 )};
182
+ Expected<CallInst *> OpCall =
183
+ OpBuilder.tryCreateOp (OpCode::CreateHandle, Args);
184
+ if (Error E = OpCall.takeError ())
185
+ return E;
186
+
187
+ Value *Cast = createTmpHandleCast (*OpCall, CI->getType ());
188
+
189
+ CI->replaceAllUsesWith (Cast);
190
+ CI->eraseFromParent ();
191
+ return Error::success ();
192
+ });
193
+ }
194
+
195
+ void lowerToBindAndAnnotateHandle (Function &F) {
196
+ IRBuilder<> &IRB = OpBuilder.getIRB ();
197
+
198
+ replaceFunction (F, [&](CallInst *CI) -> Error {
199
+ IRB.SetInsertPoint (CI);
200
+
201
+ dxil::ResourceInfo &RI = DRM[CI];
202
+ dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding ();
203
+ std::pair<uint32_t , uint32_t > Props = RI.getAnnotateProps ();
204
+
205
+ Constant *ResBind = OpBuilder.getResBind (
206
+ Binding.LowerBound , Binding.LowerBound + Binding.Size - 1 ,
207
+ Binding.Space , RI.getResourceClass ());
208
+ std::array<Value *, 3 > BindArgs{ResBind, CI->getArgOperand (3 ),
209
+ CI->getArgOperand (4 )};
210
+ Expected<CallInst *> OpBind =
211
+ OpBuilder.tryCreateOp (OpCode::CreateHandleFromBinding, BindArgs);
212
+ if (Error E = OpBind.takeError ())
213
+ return E;
214
+
215
+ std::array<Value *, 2 > AnnotateArgs{
216
+ *OpBind, OpBuilder.getResProps (Props.first , Props.second )};
217
+ Expected<CallInst *> OpAnnotate =
218
+ OpBuilder.tryCreateOp (OpCode::AnnotateHandle, AnnotateArgs);
219
+ if (Error E = OpAnnotate.takeError ())
220
+ return E;
221
+
222
+ Value *Cast = createTmpHandleCast (*OpAnnotate, CI->getType ());
223
+
224
+ CI->replaceAllUsesWith (Cast);
225
+ CI->eraseFromParent ();
226
+
227
+ return Error::success ();
228
+ });
229
+ }
230
+
231
+ void lowerHandleFromBinding (Function &F) {
232
+ Triple TT (Triple (M.getTargetTriple ()));
233
+ if (TT.getDXILVersion () < VersionTuple (1 , 6 ))
234
+ lowerToCreateHandle (F);
235
+ else
236
+ lowerToBindAndAnnotateHandle (F);
237
+ }
238
+
122
239
bool lowerIntrinsics () {
123
240
bool Updated = false ;
124
241
@@ -134,40 +251,55 @@ class OpLowerer {
134
251
replaceFunctionWithOp (F, OpCode); \
135
252
break ;
136
253
#include " DXILOperation.inc"
254
+ case Intrinsic::dx_handle_fromBinding:
255
+ lowerHandleFromBinding (F);
137
256
}
138
257
Updated = true ;
139
258
}
259
+ if (Updated)
260
+ cleanupHandleCasts ();
261
+
140
262
return Updated;
141
263
}
142
264
};
143
265
} // namespace
144
266
145
- PreservedAnalyses DXILOpLowering::run (Module &M, ModuleAnalysisManager &) {
146
- if (OpLowerer (M).lowerIntrinsics ())
147
- return PreservedAnalyses::none ();
148
- return PreservedAnalyses::all ();
267
+ PreservedAnalyses DXILOpLowering::run (Module &M, ModuleAnalysisManager &MAM) {
268
+ DXILResourceMap &DRM = MAM.getResult <DXILResourceAnalysis>(M);
269
+
270
+ bool MadeChanges = OpLowerer (M, DRM).lowerIntrinsics ();
271
+ if (!MadeChanges)
272
+ return PreservedAnalyses::all ();
273
+ PreservedAnalyses PA;
274
+ PA.preserve <DXILResourceAnalysis>();
275
+ return PA;
149
276
}
150
277
151
278
namespace {
152
279
class DXILOpLoweringLegacy : public ModulePass {
153
280
public:
154
281
bool runOnModule (Module &M) override {
155
- return OpLowerer (M).lowerIntrinsics ();
282
+ DXILResourceMap &DRM =
283
+ getAnalysis<DXILResourceWrapperPass>().getResourceMap ();
284
+
285
+ return OpLowerer (M, DRM).lowerIntrinsics ();
156
286
}
157
287
StringRef getPassName () const override { return " DXIL Op Lowering" ; }
158
288
DXILOpLoweringLegacy () : ModulePass(ID) {}
159
289
160
290
static char ID; // Pass identification.
161
291
void getAnalysisUsage (llvm::AnalysisUsage &AU) const override {
162
- // Specify the passes that your pass depends on
163
292
AU.addRequired <DXILIntrinsicExpansionLegacy>();
293
+ AU.addRequired <DXILResourceWrapperPass>();
294
+ AU.addPreserved <DXILResourceWrapperPass>();
164
295
}
165
296
};
166
297
char DXILOpLoweringLegacy::ID = 0 ;
167
298
} // end anonymous namespace
168
299
169
300
INITIALIZE_PASS_BEGIN (DXILOpLoweringLegacy, DEBUG_TYPE, " DXIL Op Lowering" ,
170
301
false , false )
302
+ INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
171
303
INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, " DXIL Op Lowering" , false ,
172
304
false )
173
305
0 commit comments