1
1
/* ========================== begin_copyright_notice ============================
2
2
3
- Copyright (C) 2021 Intel Corporation
3
+ Copyright (C) 2024 Intel Corporation
4
4
5
5
SPDX-License-Identifier: MIT
6
6
@@ -23,6 +23,9 @@ SPDX-License-Identifier: MIT
23
23
#include < llvm/ADT/Sequence.h>
24
24
#include < llvm/ADT/STLExtras.h>
25
25
#include < llvm/ADT/PostOrderIterator.h>
26
+ #include " llvm/IR/DebugInfo.h"
27
+ #include " llvm/IR/DIBuilder.h"
28
+ #include " llvmWrapper/Transforms/Utils/Cloning.h"
26
29
#include < llvmWrapper/ADT/Optional.h>
27
30
#include " llvmWrapper/IR/Value.h"
28
31
#include < llvmWrapper/Analysis/ValueTracking.h>
@@ -79,7 +82,6 @@ static const char *CooperativeMatrixLengthPrefx = "CooperativeMatrixLengthKHR";
79
82
static const char *CooperativeMatrixGetElementCoordPrefx =" CooperativeMatrixGetElementCoordINTEL" ;
80
83
static const char *AccessChainPrefx = " __spirv_AccessChain" ;
81
84
82
-
83
85
// We need module pass, since:
84
86
// 1) we inspect multiple functions to find entry function to get sub group size
85
87
// 2) we maintain map of functions to entry functions across functions we process
@@ -94,6 +96,8 @@ bool JointMatrixFuncsResolutionPass::runOnModule(Module &M)
94
96
m_Ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext ();
95
97
m_mdUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils ();
96
98
FunctionsMap.clear ();
99
+ ResolvedFunctions.clear ();
100
+ ResolvedTypes.clear ();
97
101
Changed = false ;
98
102
99
103
for (auto &F : M) {
@@ -104,6 +108,28 @@ bool JointMatrixFuncsResolutionPass::runOnModule(Module &M)
104
108
preprocessAccessChain (&F);
105
109
}
106
110
111
+ for (auto & F : M) {
112
+ bool stop = false ;
113
+ for (auto & entry : ResolvedFunctions)
114
+ {
115
+ if (entry.second == &F)
116
+ {
117
+ stop = true ;
118
+ break ;
119
+ }
120
+ }
121
+
122
+ if (stop)
123
+ break ;
124
+
125
+ auto argsWithMatrixType = GetFunctionArgsWithMatrixType (&F);
126
+
127
+ if (argsWithMatrixType.size () > 0 ) {
128
+ ResolveSIMDSize (&F);
129
+ ResolveFunctionSignature (&F);
130
+ }
131
+ }
132
+
107
133
for (auto &F : M)
108
134
{
109
135
if (F.isDeclaration ())
@@ -297,12 +323,14 @@ bool JointMatrixFuncsResolutionPass::runOnFunction(Function& F)
297
323
{
298
324
PlaceholderInstructions.clear ();
299
325
ResolvedValues.clear ();
300
- ResolvedTypes.clear ();
301
326
InstsToErase.clear ();
302
327
MatrixAllocas.clear ();
303
328
m_SIMDSize = 0 ;
304
329
305
- // Use reverse post order traversal to reduce level or recursion
330
+ if (ResolvedFunctions.count (&F) > 0 )
331
+ return false ;
332
+
333
+ // Use reverse post order traversal to reduce level or recursion.
306
334
ReversePostOrderTraversal<Function *> RPOT (&F);
307
335
for (BasicBlock *BB : RPOT)
308
336
visit (BB);
@@ -2328,6 +2356,54 @@ Value *JointMatrixFuncsResolutionPass::Resolve(Value *v)
2328
2356
return nullptr ;
2329
2357
}
2330
2358
2359
+ Function* JointMatrixFuncsResolutionPass::CloneFunction (Function* pOriginalFunction)
2360
+ {
2361
+ if (pOriginalFunction == nullptr ) {
2362
+ return nullptr ;
2363
+ }
2364
+
2365
+ std::vector<Type*> params;
2366
+
2367
+ for (auto &arg : pOriginalFunction->args ())
2368
+ {
2369
+ auto type = isOrContainsMatrixType (arg.getType ()) ? ResolveTypes (arg.getType ()) : arg.getType ();
2370
+ params.push_back (type);
2371
+ }
2372
+
2373
+ auto newFunctionTy = FunctionType::get (ResolveTypes (pOriginalFunction->getReturnType ()), params, pOriginalFunction->isVarArg ());
2374
+
2375
+ Function* pNewFunction = Function::Create (
2376
+ newFunctionTy,
2377
+ pOriginalFunction->getLinkage (),
2378
+ pOriginalFunction->getAddressSpace (),
2379
+ pOriginalFunction->getName () + " _resolved" ,
2380
+ pOriginalFunction->getParent ());
2381
+
2382
+ pNewFunction->setCallingConv (pOriginalFunction->getCallingConv ());
2383
+ pNewFunction->setSubprogram (pOriginalFunction->getSubprogram ());
2384
+ pNewFunction->copyAttributesFrom (pOriginalFunction);
2385
+
2386
+ ValueToValueMapTy VMap;
2387
+
2388
+ auto originalFunctionArgIt = pOriginalFunction->arg_begin ();
2389
+ auto newFunctionArgIt = pNewFunction->arg_begin ();
2390
+
2391
+ while (originalFunctionArgIt != pOriginalFunction->arg_end ())
2392
+ {
2393
+ newFunctionArgIt->setName (originalFunctionArgIt->getName ());
2394
+ VMap[&(*originalFunctionArgIt++)] = newFunctionArgIt++;
2395
+ }
2396
+
2397
+ if (!pOriginalFunction->isDeclaration ())
2398
+ {
2399
+ SmallVector<ReturnInst*, 8 > Returns;
2400
+ IGCLLVM::CloneFunctionChangeType changeType = IGCLLVM::CloneFunctionChangeType::LocalChangesOnly;
2401
+ IGCLLVM::CloneFunctionInto (pNewFunction, pOriginalFunction, VMap, changeType, Returns);
2402
+ }
2403
+
2404
+ return pNewFunction;
2405
+ }
2406
+
2331
2407
void JointMatrixFuncsResolutionPass::visitCallInst (CallInst& CI)
2332
2408
{
2333
2409
Function* func = CI.getCalledFunction ();
@@ -2387,15 +2463,211 @@ void JointMatrixFuncsResolutionPass::visitCallInst(CallInst& CI)
2387
2463
}
2388
2464
}
2389
2465
}
2466
+
2467
+ auto argsWithMatrixType = GetFunctionArgsWithMatrixType (func);
2468
+
2469
+ if (argsWithMatrixType.size () > 0 ) {
2470
+ auto resolvedFunc = ResolvedFunctions.count (func) > 0 ? ResolvedFunctions[func] : ResolveFunctionSignature (func);
2471
+ UpdateCallInstAfterFunctionResolve (resolvedFunc, &CI);
2472
+ }
2473
+ }
2474
+
2475
+ std::vector<Argument*> JointMatrixFuncsResolutionPass::GetFunctionArgsWithMatrixType (Function* func)
2476
+ {
2477
+ if (func == nullptr )
2478
+ return std::vector<Argument*>();
2479
+
2480
+ std::vector<Argument*> argsWithMatrixType;
2481
+
2482
+ for (Argument &arg : func->args ()) {
2483
+ if (isOrContainsMatrixType (arg.getType ())) {
2484
+ argsWithMatrixType.push_back (&arg);
2485
+ }
2486
+ }
2487
+
2488
+ return argsWithMatrixType;
2489
+ }
2490
+
2491
+ bool JointMatrixFuncsResolutionPass::UpdateCallInstAfterFunctionResolve (Function* ResolvedFunction, CallInst* CI)
2492
+ {
2493
+ if (!CI || !ResolvedFunction)
2494
+ return false ;
2495
+
2496
+ std::vector<Value*> params;
2497
+
2498
+ for (auto & callArg : CI->args ())
2499
+ {
2500
+ auto callArgInst = callArg.get ();
2501
+ if (isOrContainsMatrixType (callArgInst->getType ()))
2502
+ {
2503
+ Value* resolvedArg = ResolvedValues.count (callArgInst) > 0 ?
2504
+ ResolvedValues[callArgInst] :
2505
+ Resolve (callArgInst);
2506
+ params.push_back (resolvedArg);
2507
+ }
2508
+ else
2509
+ {
2510
+ params.push_back (callArg.get ());
2511
+ }
2512
+ }
2513
+
2514
+ IRBuilder<> b (CI);
2515
+ auto newCall = b.CreateCall (ResolvedFunction, params);
2516
+ newCall->setDebugLoc (CI->getDebugLoc ());
2517
+ newCall->setCallingConv (CI->getCallingConv ());
2518
+ newCall->setAttributes (CI->getAttributes ());
2519
+
2520
+ if (CI->hasName ())
2521
+ {
2522
+ newCall->setName (CI->getName ());
2523
+ }
2524
+
2525
+ InstsToErase.insert (CI);
2526
+ return true ;
2527
+ }
2528
+
2529
+ Function* JointMatrixFuncsResolutionPass::ResolveFunctionSignature (Function* OriginalFunction)
2530
+ {
2531
+ if (ResolvedFunctions.count (OriginalFunction) > 0 && isa<Function>(ResolvedFunctions[OriginalFunction])) {
2532
+ Function* cachedFunction = dyn_cast<Function>(ResolvedFunctions[OriginalFunction]);
2533
+ return cachedFunction;
2534
+ }
2535
+
2536
+ Function* newFunction = CloneFunction (OriginalFunction);
2537
+
2538
+ CacheResolvedValue (OriginalFunction, newFunction);
2539
+ ResolvedFunctions[OriginalFunction] = newFunction;
2540
+ return newFunction;
2541
+ }
2542
+
2543
+ std::string getTypeName (Type* T)
2544
+ {
2545
+ std::string TypeName;
2546
+ raw_string_ostream TypeStream (TypeName);
2547
+ if (T)
2548
+ T->print (TypeStream);
2549
+ else
2550
+ TypeStream << " Printing <null> Type" ;
2551
+ TypeStream.flush ();
2552
+ return TypeName;
2553
+ }
2554
+
2555
+ DIType* getOrCreateType (Type* T, Module* M) {
2556
+ DIType* N = nullptr ;
2557
+ DIBuilder Builder (*M, true );
2558
+ DataLayout Layout (M);
2559
+
2560
+ if (T->isPointerTy ()) {
2561
+
2562
+ uint align = 0 ;
2563
+ #if LLVM_VERSION_MAJOR < 10
2564
+ align = IGCLLVM::getPrefTypeAlign (Layout, T);
2565
+ #else
2566
+ align = IGCLLVM::getPrefTypeAlign (Layout, T).value ();
2567
+ #endif
2568
+
2569
+ llvm::Optional<unsigned int > opt (llvm::None);
2570
+ N = Builder.createPointerType (
2571
+ nullptr , Layout.getPointerTypeSizeInBits (T),
2572
+ align * CHAR_BIT, /* DWARFAddressSpace=*/ opt,
2573
+ getTypeName (T));
2574
+ }
2575
+ else
2576
+ {
2577
+ int encoding = llvm::dwarf::DW_ATE_signed;
2578
+ if (T->isIntegerTy ())
2579
+ encoding = llvm::dwarf::DW_ATE_unsigned;
2580
+ else if (T->isFloatingPointTy ())
2581
+ encoding = llvm::dwarf::DW_ATE_float;
2582
+
2583
+ N = Builder.createBasicType (getTypeName (T), T->getPrimitiveSizeInBits (),
2584
+ encoding);
2585
+ }
2586
+
2587
+ return N;
2390
2588
}
2391
2589
2590
+
2392
2591
void JointMatrixFuncsResolutionPass::visitAllocaInst (AllocaInst &I)
2393
2592
{
2394
2593
if (ResolvedValues.count (&I) > 0 )
2395
2594
return ;
2396
2595
2397
2596
if (!isOrContainsMatrixType (I.getAllocatedType ()))
2398
2597
return ;
2598
+
2599
+ ResolveSIMDSize (I.getParent ()->getParent ());
2600
+
2601
+ Value *newInst = ResolveGeneric (&I);
2602
+
2603
+ if (newInst)
2604
+ {
2605
+ TinyPtrVector<DbgDeclareInst*> DDIs;
2606
+ for (DbgVariableIntrinsic* DVI : FindDbgAddrUses (&I))
2607
+ if (auto * DDI = dyn_cast<DbgDeclareInst>(DVI))
2608
+ DDIs.push_back (DDI);
2609
+
2610
+ for (DbgDeclareInst* ddi : DDIs) {
2611
+ auto loc = ddi->getDebugLoc ();
2612
+ auto var = ddi->getVariable ();
2613
+ auto file = var->getFile ();
2614
+ auto lineNo = var->getLine ();
2615
+ auto scope = var->getScope ();
2616
+
2617
+ auto type = getOrCreateType (newInst->getType (), I.getModule ());
2618
+
2619
+ llvm::DIBuilder builder (*(I.getModule ()));
2620
+ auto created = builder.createAutoVariable (scope, var->getName (), file, lineNo, type);
2621
+ builder.insertDbgValueIntrinsic (newInst, created, builder.createExpression (), loc, ddi);
2622
+ ddi->eraseFromParent ();
2623
+ }
2624
+ }
2625
+ }
2626
+
2627
+ void JointMatrixFuncsResolutionPass::visitAddrSpaceCastInst (llvm::AddrSpaceCastInst& I)
2628
+ {
2629
+ if (ResolvedValues.count (&I) > 0 )
2630
+ return ;
2631
+
2632
+ if (!isOrContainsMatrixType (I.getType ()))
2633
+ return ;
2634
+
2635
+ ResolveSIMDSize (I.getParent ()->getParent ());
2636
+ ResolveGeneric (&I);
2637
+ }
2638
+
2639
+ void JointMatrixFuncsResolutionPass::visitLoadInst (llvm::LoadInst& I)
2640
+ {
2641
+ if (ResolvedValues.count (&I) > 0 )
2642
+ return ;
2643
+
2644
+ if (!isOrContainsMatrixType (I.getType ()))
2645
+ return ;
2646
+
2647
+ ResolveSIMDSize (I.getParent ()->getParent ());
2648
+ ResolveGeneric (&I);
2649
+ }
2650
+
2651
+ void JointMatrixFuncsResolutionPass::visitPHINode (llvm::PHINode& I)
2652
+ {
2653
+ if (ResolvedValues.count (&I) > 0 )
2654
+ return ;
2655
+
2656
+ if (!isOrContainsMatrixType (I.getType ()))
2657
+ return ;
2658
+
2659
+ ResolveSIMDSize (I.getParent ()->getParent ());
2660
+ ResolveGeneric (&I);
2661
+ }
2662
+
2663
+ void JointMatrixFuncsResolutionPass::visitReturnInst (llvm::ReturnInst& I)
2664
+ {
2665
+ if (ResolvedValues.count (&I) > 0 )
2666
+ return ;
2667
+
2668
+ if (I.getReturnValue () == nullptr || !isOrContainsMatrixType (I.getReturnValue ()->getType ()))
2669
+ return ;
2670
+
2399
2671
ResolveSIMDSize (I.getParent ()->getParent ());
2400
2672
ResolveGeneric (&I);
2401
2673
}
0 commit comments