@@ -54,7 +54,6 @@ SPDX-License-Identifier: MIT
54
54
// / write if there is a StackIDRelease after it.
55
55
// ===----------------------------------------------------------------------===//
56
56
57
-
58
57
#include " RTBuilder.h"
59
58
#include " Compiler/IGCPassSupport.h"
60
59
#include " iStdLib/utility.h"
@@ -73,193 +72,6 @@ using namespace llvm;
73
72
using namespace IGC ;
74
73
using namespace ShaderProperties ;
75
74
76
-
77
- class PayloadSinkingAnalysisPass : public FunctionPass
78
- {
79
- public:
80
- PayloadSinkingAnalysisPass () : FunctionPass(ID)
81
- {
82
- initializePayloadSinkingAnalysisPassPass (*PassRegistry::getPassRegistry ());
83
- }
84
-
85
- bool runOnFunction (Function& F) override ;
86
- StringRef getPassName () const override
87
- {
88
- return " PayloadSinkingAnalysisPass" ;
89
- }
90
-
91
- void getAnalysisUsage (llvm::AnalysisUsage& AU) const override
92
- {
93
- AU.setPreservesCFG ();
94
- AU.addRequired <CodeGenContextWrapper>();
95
- }
96
-
97
- static char ID;
98
- private:
99
- std::vector<llvm::CallShaderHLIntrinsic*> m_CallShaders;
100
- std::vector<llvm::TraceRayAsyncHLIntrinsic*> m_TraceRays;
101
- std::vector<llvm::SwitchInst*> m_Switches;
102
- std::vector<llvm::BranchInst*> m_ContidionalBranches;
103
- };
104
-
105
- char PayloadSinkingAnalysisPass::ID = 0 ;
106
-
107
-
108
- // Register pass to igc-opt
109
- #define PASS_FLAG " payload-sinking-analysis"
110
- #define PASS_DESCRIPTION " Perform analysis on whether Payload Sinking optimization should be applied or not"
111
- #define PASS_CFG_ONLY false
112
- #define PASS_ANALYSIS true
113
- IGC_INITIALIZE_PASS_BEGIN (PayloadSinkingAnalysisPass, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
114
- IGC_INITIALIZE_PASS_DEPENDENCY(CodeGenContextWrapper)
115
- IGC_INITIALIZE_PASS_END(PayloadSinkingAnalysisPass, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
116
- #undef PASS_FLAG
117
- #undef PASS_DESCRIPTION
118
- #undef PASS_CFG_ONLY
119
- #undef PASS_ANALYSIS
120
-
121
-
122
- bool PayloadSinkingAnalysisPass::runOnFunction (Function& F)
123
- {
124
- RayDispatchShaderContext* CGCtx = (RayDispatchShaderContext*)getAnalysis<CodeGenContextWrapper>().getCodeGenContext ();
125
- // early return if we already don't want payload sinking
126
- if (CGCtx->hasUnsupportedPayloadSinkingCase )
127
- {
128
- return false ;
129
- }
130
-
131
- // collect callable and switch instructions
132
- for (auto BI = F.begin (); BI != F.end (); BI++)
133
- {
134
- for (auto II = BI->begin (); II != BI->end (); II++)
135
- {
136
- if (llvm::CallShaderHLIntrinsic* inst = llvm::dyn_cast<llvm::CallShaderHLIntrinsic>(II))
137
- {
138
- m_CallShaders.push_back (inst);
139
- }
140
- else if (llvm::TraceRayAsyncHLIntrinsic* inst = llvm::dyn_cast<llvm::TraceRayAsyncHLIntrinsic>(II))
141
- {
142
- m_TraceRays.push_back (inst);
143
- }
144
- else if (llvm::SwitchInst* inst = llvm::dyn_cast<llvm::SwitchInst>(II))
145
- {
146
- m_Switches.push_back (inst);
147
- }
148
- else if (llvm::BranchInst* inst = llvm::dyn_cast<llvm::BranchInst>(II))
149
- {
150
- if (inst->isConditional ())
151
- {
152
- m_ContidionalBranches.push_back (inst);
153
- }
154
- }
155
- }
156
- }
157
-
158
- // early return if shader doesn't have switches and if-else or callables and tracerays
159
- if ((m_CallShaders.size () == 0 && m_TraceRays.size () == 0 ) ||
160
- (m_Switches.size () == 0 && m_ContidionalBranches.size () == 0 ))
161
- {
162
- return false ;
163
- }
164
-
165
- for (auto s : m_Switches)
166
- {
167
- // This map<ShaderIndex,Parameter> stores param of the first call of
168
- // given shader index callable shader. All calls of the same call shader
169
- // must have the same parameter in all switch-cases
170
- std::unordered_map<llvm::Value*, llvm::Value*> callableAndAllowedParam;
171
- // Same purpose as above, but for trace rays
172
- llvm::Value* allowedRayPayload = nullptr ;
173
- for (auto c : s->cases ())
174
- {
175
- for (auto callShader : m_CallShaders)
176
- {
177
- // check if call shader is under switch-case label
178
- // TODO: what if there is additional control flow under case label?
179
- if (callShader->getParent () == c.getCaseSuccessor ())
180
- {
181
- auto firstCall = callableAndAllowedParam.find (callShader->getShaderIndex ());
182
- if (firstCall == callableAndAllowedParam.end ())
183
- {
184
- // if it is the first call shader with this shader index
185
- // under this switch, remember it
186
- callableAndAllowedParam[callShader->getShaderIndex ()] = callShader->getParameter ();
187
- }
188
- else if (firstCall->second != callShader->getParameter ())
189
- {
190
- // if its not the first call shader with this shader index
191
- // ant it doesn't have the same param as the first,
192
- // then we cannot do payload sinking
193
- CGCtx->hasUnsupportedPayloadSinkingCase = true ;
194
- return false ;
195
- }
196
- }
197
- }
198
-
199
- for (auto rayTrace : m_TraceRays)
200
- {
201
- if (rayTrace->getParent () == c.getCaseSuccessor ())
202
- {
203
- if (!allowedRayPayload)
204
- {
205
- allowedRayPayload = rayTrace->getPayload ();
206
- }
207
- else if (allowedRayPayload != rayTrace->getPayload ())
208
- {
209
- CGCtx->hasUnsupportedPayloadSinkingCase = true ;
210
- return false ;
211
- }
212
- }
213
- }
214
- }
215
- }
216
-
217
- for (auto cb : m_ContidionalBranches)
218
- {
219
- // see above comments for switches
220
- std::unordered_map<llvm::Value*, llvm::Value*> callableAndAllowedParam;
221
- llvm::Value* allowedRayPayload = nullptr ;
222
- for (auto s : cb->successors ())
223
- {
224
- for (auto callShader : m_CallShaders)
225
- {
226
- if (callShader->getParent () == s)
227
- {
228
- auto firstCall = callableAndAllowedParam.find (callShader->getShaderIndex ());
229
- if (firstCall == callableAndAllowedParam.end ())
230
- {
231
- callableAndAllowedParam[callShader->getShaderIndex ()] = callShader->getParameter ();
232
- }
233
- else if (firstCall->second != callShader->getParameter ())
234
- {
235
- CGCtx->hasUnsupportedPayloadSinkingCase = true ;
236
- return false ;
237
- }
238
- }
239
- }
240
-
241
- for (auto rayTrace : m_TraceRays)
242
- {
243
- if (rayTrace->getParent () == s)
244
- {
245
- if (!allowedRayPayload)
246
- {
247
- allowedRayPayload = rayTrace->getPayload ();
248
- }
249
- else if (allowedRayPayload != rayTrace->getPayload ())
250
- {
251
- CGCtx->hasUnsupportedPayloadSinkingCase = true ;
252
- return false ;
253
- }
254
- }
255
- }
256
- }
257
- }
258
-
259
- return false ;
260
- }
261
-
262
-
263
75
class PayloadSinkingPass : public FunctionPass
264
76
{
265
77
public:
@@ -394,13 +206,11 @@ bool PayloadSinkingPass::canSink(
394
206
// If this shader returns to a continuation, this guarantees that all the
395
207
// inlined continuations collectively post dominate all payload writes
396
208
// in the current shader.
397
- const RayDispatchShaderContext& rdsC = (const RayDispatchShaderContext&)Ctx;
398
209
return (shaderReturnsToContinuation (ShaderTy) || ShaderTy == AnyHit) &&
399
- !rtInfo.isContinuation &&
400
- // Don't sink in callable since we don't know what the recursion
401
- // limit is. If there is 1, that is safe.
402
- (ShaderTy != Callable || modMD->rtInfo .NumContinuations == 1 ) &&
403
- !rdsC.hasUnsupportedPayloadSinkingCase ;
210
+ !rtInfo.isContinuation &&
211
+ // Don't sink in callable since we don't know what the recursion
212
+ // limit is. If there is 1, that is safe.
213
+ (ShaderTy != Callable || modMD->rtInfo .NumContinuations == 1 );
404
214
}
405
215
406
216
bool PayloadSinkingPass::runOnFunction (Function &F)
@@ -486,11 +296,6 @@ bool PayloadSinkingPass::runOnFunction(Function &F)
486
296
namespace IGC
487
297
{
488
298
489
- Pass* createPayloadSinkingAnalysisPass (void )
490
- {
491
- return new PayloadSinkingAnalysisPass ();
492
- }
493
-
494
299
Pass* createPayloadSinkingPass (void )
495
300
{
496
301
return new PayloadSinkingPass ();
0 commit comments