Skip to content

Commit b69a80e

Browse files
dlei6gsys_zuul
authored and committed
Enable recursion and convert all recursive calls to stack calls
Change-Id: Idd91b2406c5c3691280a3d4daf3d992cb9e59cff
1 parent f025b19 commit b69a80e

File tree

4 files changed

+49
-59
lines changed

4 files changed

+49
-59
lines changed

IGC/AdaptorCommon/AddImplicitArgs.cpp

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -517,59 +517,16 @@ BuiltinCallGraphAnalysis::BuiltinCallGraphAnalysis() : ModulePass(ID)
517517
initializeBuiltinCallGraphAnalysisPass(*PassRegistry::getPassRegistry());
518518
}
519519

520-
/// Check whether there are recursions.
521-
static bool hasRecursion(CallGraph &CG)
522-
{
523-
// Use Tarjan's algorithm to detect recursions.
524-
for (auto I = scc_begin(&CG), E = scc_end(&CG); I != E; ++I)
525-
{
526-
const std::vector<CallGraphNode *> &SCCNodes = *I;
527-
if (SCCNodes.size() >= 2)
528-
{
529-
return true;
530-
}
531-
532-
// Check self-recursion.
533-
auto Node = SCCNodes.back();
534-
for (auto Callee : *Node)
535-
{
536-
if (Callee.second == Node)
537-
{
538-
return true;
539-
}
540-
}
541-
}
542-
543-
// No recursion.
544-
return false;
545-
}
546-
547520
bool BuiltinCallGraphAnalysis::runOnModule(Module &M)
548521
{
549522
if (IGC_GET_FLAG_VALUE(FunctionControl) == FLAG_FCALL_FORCE_INLINE)
550523
{
551524
return false;
552525
}
553526

554-
CodeGenContext* ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
555527
m_pMdUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
556528
CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
557529

558-
if (IGC_IS_FLAG_DISABLED(EnableRecursionOpenCL) &&
559-
!ctx->m_DriverInfo.AllowRecursion() &&
560-
hasRecursion(CG))
561-
{
562-
IGC_ASSERT_MESSAGE(0, "Recursion detected!");
563-
ctx->EmitError(" undefined reference to `jmp()' ");
564-
return false;
565-
}
566-
567-
//Return if any error
568-
if (!(ctx->oclErrorMessage.empty()))
569-
{
570-
return false;
571-
}
572-
573530
for (auto I = scc_begin(&CG), IE = scc_end(&CG); I != IE; ++I)
574531
{
575532
const std::vector<CallGraphNode *> &SCCNodes = *I;
@@ -614,7 +571,7 @@ void BuiltinCallGraphAnalysis::traveseCallGraphSCC(const std::vector<CallGraphNo
614571
if (argMapIter != argMap.end())
615572
{
616573
IGC_ASSERT(nullptr != argMapIter->second);
617-
combineTwoArgDetail(*argData, *(argMapIter->second),
574+
combineTwoArgDetail(*argData, *(argMapIter->second),
618575
#if LLVM_VERSION_MAJOR <= 10
619576
N.first
620577
#else

IGC/AdaptorCommon/ProcessFuncAttributes.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
4545
#include <llvm/IR/Value.h>
4646
#include <llvm/IR/Attributes.h>
4747
#include <llvm/Support/raw_ostream.h>
48+
#include <llvm/Analysis/CallGraph.h>
49+
#include "llvm/ADT/SCCIterator.h"
4850
#include "common/LLVMWarningsPop.hpp"
4951
#include "common/igc_regkeys.hpp"
5052
#include <string>
@@ -66,6 +68,7 @@ class ProcessFuncAttributes : public ModulePass
6668
AU.addRequired<MetaDataUtilsWrapper>();
6769
AU.addRequired<CodeGenContextWrapper>();
6870
AU.addRequired<EstimateFunctionSize>();
71+
AU.addRequired<llvm::CallGraphWrapperPass>();
6972
}
7073

7174
ProcessFuncAttributes();
@@ -96,6 +99,7 @@ IGC_INITIALIZE_PASS_BEGIN(ProcessFuncAttributes, PASS_FLAG, PASS_DESCRIPTION, PA
9699
IGC_INITIALIZE_PASS_DEPENDENCY(MetaDataUtilsWrapper)
97100
IGC_INITIALIZE_PASS_DEPENDENCY(CodeGenContextWrapper)
98101
IGC_INITIALIZE_PASS_DEPENDENCY(EstimateFunctionSize)
102+
IGC_INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
99103
IGC_INITIALIZE_PASS_END(ProcessFuncAttributes, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
100104

101105
char ProcessFuncAttributes::ID = 0;
@@ -166,6 +170,37 @@ static bool containsOpaque(llvm::Type *T)
166170
return false;
167171
}
168172

173+
// Convert functions with recursion to stackcall
174+
static bool convertRecursionToStackCall(CallGraph& CG, CodeGenContext* pCtx, IGCMD::MetaDataUtils* pM)
175+
{
176+
bool hasRecursion = false;
177+
// Use Tarjan's algorithm to detect recursions.
178+
for (auto I = scc_begin(&CG), E = scc_end(&CG); I != E; ++I)
179+
{
180+
const std::vector<CallGraphNode*>& SCCNodes = *I;
181+
if (SCCNodes.size() >= 2)
182+
{
183+
hasRecursion = true;
184+
// Convert all functions in the recursion call graph to stackcall
185+
for (auto Node : SCCNodes)
186+
{
187+
Node->getFunction()->addFnAttr("visaStackCall");
188+
}
189+
}
190+
// Check self-recursion.
191+
auto Node = SCCNodes.back();
192+
for (auto Callee : *Node)
193+
{
194+
if (Callee.second == Node)
195+
{
196+
hasRecursion = true;
197+
Node->getFunction()->addFnAttr("visaStackCall");
198+
}
199+
}
200+
}
201+
return hasRecursion;
202+
}
203+
169204
// __builtin_spirv related OpGroup call implementations contain both
170205
// workgroup and subgroup code in them that is switched on based on the
171206
// 'Execution' and 'Operation' parameters and these will almost always
@@ -517,6 +552,18 @@ bool ProcessFuncAttributes::runOnModule(Module& M)
517552
}
518553
}
519554
}
555+
556+
// Detect recursive calls and convert them to stack calls, since subroutines do not support recursion
557+
CallGraph& CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
558+
if (convertRecursionToStackCall(CG, pCtx, pMdUtils))
559+
{
560+
if (IGC_IS_FLAG_DISABLED(EnableRecursionOpenCL) && !pCtx->m_DriverInfo.AllowRecursion())
561+
{
562+
IGC_ASSERT_MESSAGE(0, "Recursion detected!");
563+
}
564+
pCtx->m_enableStackCall = true;
565+
}
566+
520567
return Changed;
521568
}
522569

IGC/Compiler/CISACodeGen/GenCodeGenModule.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -194,20 +194,6 @@ inline Function* getCallerFunc(Value* user)
194194

195195
void GenXCodeGenModule::processFunction(Function& F)
196196
{
197-
// force stack-call for self-recursion
198-
for (auto U : F.users())
199-
{
200-
if (CallInst * CI = dyn_cast<CallInst>(U))
201-
{
202-
Function* Caller = CI->getParent()->getParent();
203-
if (Caller == &F)
204-
{
205-
F.addFnAttr("visaStackCall");
206-
break;
207-
}
208-
}
209-
}
210-
211197
// See what FunctionGroups this Function is called from.
212198
SetVector<std::pair<FunctionGroup*, Function*>> CallerFGs;
213199
for (auto U : F.users())

IGC/common/igc_flags.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ DECLARE_IGC_REGKEY(bool, Enable64BitEmulation, false, "Enable 64-bit em
300300
DECLARE_IGC_REGKEY(bool, Enable64BitEmulationOnSelectedPlatform, true, "Enable 64-bit emulation on selected platforms", false)
301301
DECLARE_IGC_REGKEY(DWORD, EnableConstIntDivReduction, 0x1, "Enables strength reduction on integer division/remainder with constant divisors/moduli", true)
302302
DECLARE_IGC_REGKEY(DWORD, EnableIntDivRemCombine, 0x0, "Given div/rem pairs with same operands merged; replace rem with mul+sub on quotient; 0x3 (set bit[1]) forces this on constant power of two divisors as well", true)
303-
DECLARE_IGC_REGKEY(bool, EnableRecursionOpenCL, false, "Enable recursion with OpenCL user functions", false)
303+
DECLARE_IGC_REGKEY(bool, EnableRecursionOpenCL, true, "Enable recursion with OpenCL user functions", false)
304304
DECLARE_IGC_REGKEY(bool, ForceDPEmulation, false, "Force double emulation for testing purpose", false)
305305
DECLARE_IGC_REGKEY(bool, EnableDPEmulation, false, "Enforce double precision floating point operations emulation on platforms that do not support it natively", true)
306306
DECLARE_IGC_REGKEY(bool, DPEmuNeedI64Emu, true, "Double Emulation needs I64 emulation. Unsetting it to disable I64 Emulation for testing.", false)

0 commit comments

Comments
 (0)