Skip to content

Commit 773ccee

Browse files
bokrzesiigcbot
authored andcommitted
Adding function argument resolution support for JointMatrixFuncsResolutionPass
Adding function argument resolution support for JointMatrixFuncsResolutionPass
1 parent 86502fa commit 773ccee

File tree

4 files changed

+375
-6
lines changed

4 files changed

+375
-6
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 276 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*========================== begin_copyright_notice ============================
22
3-
Copyright (C) 2021 Intel Corporation
3+
Copyright (C) 2024 Intel Corporation
44
55
SPDX-License-Identifier: MIT
66
@@ -23,6 +23,9 @@ SPDX-License-Identifier: MIT
2323
#include <llvm/ADT/Sequence.h>
2424
#include <llvm/ADT/STLExtras.h>
2525
#include <llvm/ADT/PostOrderIterator.h>
26+
#include "llvm/IR/DebugInfo.h"
27+
#include "llvm/IR/DIBuilder.h"
28+
#include "llvmWrapper/Transforms/Utils/Cloning.h"
2629
#include <llvmWrapper/ADT/Optional.h>
2730
#include "llvmWrapper/IR/Value.h"
2831
#include <llvmWrapper/Analysis/ValueTracking.h>
@@ -79,7 +82,6 @@ static const char *CooperativeMatrixLengthPrefx = "CooperativeMatrixLengthKHR";
7982
static const char *CooperativeMatrixGetElementCoordPrefx ="CooperativeMatrixGetElementCoordINTEL";
8083
static const char *AccessChainPrefx = "__spirv_AccessChain";
8184

82-
8385
// We need module pass, since:
8486
// 1) we inspect multiple functions to find entry function to get sub group size
8587
// 2) we maintain map of functions to entry functions across functions we process
@@ -94,6 +96,8 @@ bool JointMatrixFuncsResolutionPass::runOnModule(Module &M)
9496
m_Ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
9597
m_mdUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
9698
FunctionsMap.clear();
99+
ResolvedFunctions.clear();
100+
ResolvedTypes.clear();
97101
Changed = false;
98102

99103
for (auto &F : M) {
@@ -104,6 +108,28 @@ bool JointMatrixFuncsResolutionPass::runOnModule(Module &M)
104108
preprocessAccessChain(&F);
105109
}
106110

111+
for (auto& F : M) {
112+
bool stop = false;
113+
for (auto& entry : ResolvedFunctions)
114+
{
115+
if (entry.second == &F)
116+
{
117+
stop = true;
118+
break;
119+
}
120+
}
121+
122+
if (stop)
123+
break;
124+
125+
auto argsWithMatrixType = GetFunctionArgsWithMatrixType(&F);
126+
127+
if (argsWithMatrixType.size() > 0) {
128+
ResolveSIMDSize(&F);
129+
ResolveFunctionSignature(&F);
130+
}
131+
}
132+
107133
for (auto &F : M)
108134
{
109135
if (F.isDeclaration())
@@ -297,12 +323,14 @@ bool JointMatrixFuncsResolutionPass::runOnFunction(Function& F)
297323
{
298324
PlaceholderInstructions.clear();
299325
ResolvedValues.clear();
300-
ResolvedTypes.clear();
301326
InstsToErase.clear();
302327
MatrixAllocas.clear();
303328
m_SIMDSize = 0;
304329

305-
// Use reverse post order traversal to reduce level or recursion
330+
if (ResolvedFunctions.count(&F) > 0)
331+
return false;
332+
333+
// Use reverse post order traversal to reduce level or recursion.
306334
ReversePostOrderTraversal<Function *> RPOT(&F);
307335
for (BasicBlock *BB : RPOT)
308336
visit(BB);
@@ -2328,6 +2356,54 @@ Value *JointMatrixFuncsResolutionPass::Resolve(Value *v)
23282356
return nullptr;
23292357
}
23302358

2359+
Function* JointMatrixFuncsResolutionPass::CloneFunction(Function* pOriginalFunction)
2360+
{
2361+
if (pOriginalFunction == nullptr) {
2362+
return nullptr;
2363+
}
2364+
2365+
std::vector<Type*> params;
2366+
2367+
for (auto &arg : pOriginalFunction->args())
2368+
{
2369+
auto type = isOrContainsMatrixType(arg.getType()) ? ResolveTypes(arg.getType()) : arg.getType();
2370+
params.push_back(type);
2371+
}
2372+
2373+
auto newFunctionTy = FunctionType::get(ResolveTypes(pOriginalFunction->getReturnType()), params, pOriginalFunction->isVarArg());
2374+
2375+
Function* pNewFunction = Function::Create(
2376+
newFunctionTy,
2377+
pOriginalFunction->getLinkage(),
2378+
pOriginalFunction->getAddressSpace(),
2379+
pOriginalFunction->getName() + "_resolved",
2380+
pOriginalFunction->getParent());
2381+
2382+
pNewFunction->setCallingConv(pOriginalFunction->getCallingConv());
2383+
pNewFunction->setSubprogram(pOriginalFunction->getSubprogram());
2384+
pNewFunction->copyAttributesFrom(pOriginalFunction);
2385+
2386+
ValueToValueMapTy VMap;
2387+
2388+
auto originalFunctionArgIt = pOriginalFunction->arg_begin();
2389+
auto newFunctionArgIt = pNewFunction->arg_begin();
2390+
2391+
while (originalFunctionArgIt != pOriginalFunction->arg_end())
2392+
{
2393+
newFunctionArgIt->setName(originalFunctionArgIt->getName());
2394+
VMap[&(*originalFunctionArgIt++)] = newFunctionArgIt++;
2395+
}
2396+
2397+
if (!pOriginalFunction->isDeclaration())
2398+
{
2399+
SmallVector<ReturnInst*, 8> Returns;
2400+
IGCLLVM::CloneFunctionChangeType changeType = IGCLLVM::CloneFunctionChangeType::LocalChangesOnly;
2401+
IGCLLVM::CloneFunctionInto(pNewFunction, pOriginalFunction, VMap, changeType, Returns);
2402+
}
2403+
2404+
return pNewFunction;
2405+
}
2406+
23312407
void JointMatrixFuncsResolutionPass::visitCallInst(CallInst& CI)
23322408
{
23332409
Function* func = CI.getCalledFunction();
@@ -2387,15 +2463,211 @@ void JointMatrixFuncsResolutionPass::visitCallInst(CallInst& CI)
23872463
}
23882464
}
23892465
}
2466+
2467+
auto argsWithMatrixType = GetFunctionArgsWithMatrixType(func);
2468+
2469+
if (argsWithMatrixType.size() > 0) {
2470+
auto resolvedFunc = ResolvedFunctions.count(func) > 0 ? ResolvedFunctions[func] : ResolveFunctionSignature(func);
2471+
UpdateCallInstAfterFunctionResolve(resolvedFunc, &CI);
2472+
}
2473+
}
2474+
2475+
std::vector<Argument*> JointMatrixFuncsResolutionPass::GetFunctionArgsWithMatrixType(Function* func)
2476+
{
2477+
if (func == nullptr)
2478+
return std::vector<Argument*>();
2479+
2480+
std::vector<Argument*> argsWithMatrixType;
2481+
2482+
for (Argument &arg : func->args()) {
2483+
if (isOrContainsMatrixType(arg.getType())) {
2484+
argsWithMatrixType.push_back(&arg);
2485+
}
2486+
}
2487+
2488+
return argsWithMatrixType;
2489+
}
2490+
2491+
bool JointMatrixFuncsResolutionPass::UpdateCallInstAfterFunctionResolve(Function* ResolvedFunction, CallInst* CI)
2492+
{
2493+
if (!CI || !ResolvedFunction)
2494+
return false;
2495+
2496+
std::vector<Value*> params;
2497+
2498+
for (auto& callArg : CI->args())
2499+
{
2500+
auto callArgInst = callArg.get();
2501+
if (isOrContainsMatrixType(callArgInst->getType()))
2502+
{
2503+
Value* resolvedArg = ResolvedValues.count(callArgInst) > 0 ?
2504+
ResolvedValues[callArgInst] :
2505+
Resolve(callArgInst);
2506+
params.push_back(resolvedArg);
2507+
}
2508+
else
2509+
{
2510+
params.push_back(callArg.get());
2511+
}
2512+
}
2513+
2514+
IRBuilder<> b(CI);
2515+
auto newCall = b.CreateCall(ResolvedFunction, params);
2516+
newCall->setDebugLoc(CI->getDebugLoc());
2517+
newCall->setCallingConv(CI->getCallingConv());
2518+
newCall->setAttributes(CI->getAttributes());
2519+
2520+
if (CI->hasName())
2521+
{
2522+
newCall->setName(CI->getName());
2523+
}
2524+
2525+
InstsToErase.insert(CI);
2526+
return true;
2527+
}
2528+
2529+
Function* JointMatrixFuncsResolutionPass::ResolveFunctionSignature(Function* OriginalFunction)
2530+
{
2531+
if (ResolvedFunctions.count(OriginalFunction) > 0 && isa<Function>(ResolvedFunctions[OriginalFunction])) {
2532+
Function* cachedFunction = dyn_cast<Function>(ResolvedFunctions[OriginalFunction]);
2533+
return cachedFunction;
2534+
}
2535+
2536+
Function* newFunction = CloneFunction(OriginalFunction);
2537+
2538+
CacheResolvedValue(OriginalFunction, newFunction);
2539+
ResolvedFunctions[OriginalFunction] = newFunction;
2540+
return newFunction;
2541+
}
2542+
2543+
std::string getTypeName(Type* T)
2544+
{
2545+
std::string TypeName;
2546+
raw_string_ostream TypeStream(TypeName);
2547+
if (T)
2548+
T->print(TypeStream);
2549+
else
2550+
TypeStream << "Printing <null> Type";
2551+
TypeStream.flush();
2552+
return TypeName;
2553+
}
2554+
2555+
DIType* getOrCreateType(Type* T, Module* M) {
2556+
DIType* N = nullptr;
2557+
DIBuilder Builder(*M, true);
2558+
DataLayout Layout(M);
2559+
2560+
if (T->isPointerTy()) {
2561+
2562+
uint align = 0;
2563+
#if LLVM_VERSION_MAJOR < 10
2564+
align = IGCLLVM::getPrefTypeAlign(Layout, T);
2565+
#else
2566+
align = IGCLLVM::getPrefTypeAlign(Layout, T).value();
2567+
#endif
2568+
2569+
llvm::Optional<unsigned int> opt(llvm::None);
2570+
N = Builder.createPointerType(
2571+
nullptr, Layout.getPointerTypeSizeInBits(T),
2572+
align * CHAR_BIT, /*DWARFAddressSpace=*/opt,
2573+
getTypeName(T));
2574+
}
2575+
else
2576+
{
2577+
int encoding = llvm::dwarf::DW_ATE_signed;
2578+
if (T->isIntegerTy())
2579+
encoding = llvm::dwarf::DW_ATE_unsigned;
2580+
else if (T->isFloatingPointTy())
2581+
encoding = llvm::dwarf::DW_ATE_float;
2582+
2583+
N = Builder.createBasicType(getTypeName(T), T->getPrimitiveSizeInBits(),
2584+
encoding);
2585+
}
2586+
2587+
return N;
23902588
}
23912589

2590+
23922591
void JointMatrixFuncsResolutionPass::visitAllocaInst(AllocaInst &I)
23932592
{
23942593
if (ResolvedValues.count(&I) > 0)
23952594
return;
23962595

23972596
if (!isOrContainsMatrixType(I.getAllocatedType()))
23982597
return;
2598+
2599+
ResolveSIMDSize(I.getParent()->getParent());
2600+
2601+
Value *newInst = ResolveGeneric(&I);
2602+
2603+
if (newInst)
2604+
{
2605+
TinyPtrVector<DbgDeclareInst*> DDIs;
2606+
for (DbgVariableIntrinsic* DVI : FindDbgAddrUses(&I))
2607+
if (auto* DDI = dyn_cast<DbgDeclareInst>(DVI))
2608+
DDIs.push_back(DDI);
2609+
2610+
for (DbgDeclareInst* ddi : DDIs) {
2611+
auto loc = ddi->getDebugLoc();
2612+
auto var = ddi->getVariable();
2613+
auto file = var->getFile();
2614+
auto lineNo = var->getLine();
2615+
auto scope = var->getScope();
2616+
2617+
auto type = getOrCreateType(newInst->getType(), I.getModule());
2618+
2619+
llvm::DIBuilder builder(*(I.getModule()));
2620+
auto created = builder.createAutoVariable(scope, var->getName(), file, lineNo, type);
2621+
builder.insertDbgValueIntrinsic(newInst, created, builder.createExpression(), loc, ddi);
2622+
ddi->eraseFromParent();
2623+
}
2624+
}
2625+
}
2626+
2627+
void JointMatrixFuncsResolutionPass::visitAddrSpaceCastInst(llvm::AddrSpaceCastInst& I)
2628+
{
2629+
if (ResolvedValues.count(&I) > 0)
2630+
return;
2631+
2632+
if (!isOrContainsMatrixType(I.getType()))
2633+
return;
2634+
2635+
ResolveSIMDSize(I.getParent()->getParent());
2636+
ResolveGeneric(&I);
2637+
}
2638+
2639+
void JointMatrixFuncsResolutionPass::visitLoadInst(llvm::LoadInst& I)
2640+
{
2641+
if (ResolvedValues.count(&I) > 0)
2642+
return;
2643+
2644+
if (!isOrContainsMatrixType(I.getType()))
2645+
return;
2646+
2647+
ResolveSIMDSize(I.getParent()->getParent());
2648+
ResolveGeneric(&I);
2649+
}
2650+
2651+
void JointMatrixFuncsResolutionPass::visitPHINode(llvm::PHINode& I)
2652+
{
2653+
if (ResolvedValues.count(&I) > 0)
2654+
return;
2655+
2656+
if (!isOrContainsMatrixType(I.getType()))
2657+
return;
2658+
2659+
ResolveSIMDSize(I.getParent()->getParent());
2660+
ResolveGeneric(&I);
2661+
}
2662+
2663+
void JointMatrixFuncsResolutionPass::visitReturnInst(llvm::ReturnInst& I)
2664+
{
2665+
if (ResolvedValues.count(&I) > 0)
2666+
return;
2667+
2668+
if (I.getReturnValue() == nullptr || !isOrContainsMatrixType(I.getReturnValue()->getType()))
2669+
return;
2670+
23992671
ResolveSIMDSize(I.getParent()->getParent());
24002672
ResolveGeneric(&I);
24012673
}

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*========================== begin_copyright_notice ============================
22
3-
Copyright (C) 2021 Intel Corporation
3+
Copyright (C) 2024 Intel Corporation
44
55
SPDX-License-Identifier: MIT
66
@@ -50,8 +50,15 @@ namespace IGC
5050
void visitPtrToIntInst(llvm::PtrToIntInst &I);
5151
void visitStoreInst(llvm::StoreInst &I);
5252
void visitBitCastInst(llvm::BitCastInst &I);
53+
void visitAddrSpaceCastInst(llvm::AddrSpaceCastInst& I);
54+
void visitLoadInst(llvm::LoadInst& I);
55+
void visitPHINode(llvm::PHINode& I);
56+
void visitReturnInst(llvm::ReturnInst& I);
5357

5458
private:
59+
std::vector<llvm::Argument*> GetFunctionArgsWithMatrixType(llvm::Function* func);
60+
llvm::Function* ResolveFunctionSignature(llvm::Function* OriginalFunction);
61+
bool UpdateCallInstAfterFunctionResolve(llvm::Function* ResolvedFunction, llvm::CallInst* OptionalCallInst);
5562
llvm::Instruction *ResolvePrefetch(llvm::CallInst *CI);
5663
template <bool IsJointMatrix, bool isChecked>
5764
llvm::Instruction *ResolveLoad(llvm::CallInst *CI);
@@ -83,6 +90,7 @@ namespace IGC
8390
void CacheResolvedValue(llvm::Value *oldValue, llvm::Value *newValue);
8491
void CacheResolvedTypes(llvm::Type *oldType, llvm::Type *newType);
8592
void InsertPlaceholder(llvm::Value *v);
93+
llvm::Function* CloneFunction(llvm::Function* pOriginalFunction);
8694

8795
enum GetMatrixFuncNameOperation {
8896
GetCoord,
@@ -119,6 +127,7 @@ namespace IGC
119127
llvm::SmallPtrSet<llvm::Instruction *, 8> InstsToErase;
120128
// Maps function to it's kernel entry function
121129
std::unordered_map<llvm::Function *, llvm::Function *> FunctionsMap;
130+
std::unordered_map<llvm::Function *, llvm::Function *> ResolvedFunctions;
122131

123132
ModuleMetaData* MMD = nullptr;
124133
CodeGenContext* m_Ctx = nullptr;

0 commit comments

Comments
 (0)