@@ -76,14 +76,6 @@ static const unsigned int MAX_RETVAL_SIZE_IN_BITS = 64;
76
76
static const unsigned int MAX_STRUCT_SIZE_IN_BITS = 128 ;
77
77
static const unsigned int MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS = 512 ;
78
78
79
- enum ReturnOpt
80
- {
81
- RETURN_DEFAULT = 0 ,
82
- RETURN_BY_REF,
83
- RETURN_STRUCT,
84
- RETURN_LEGAL_INT
85
- };
86
-
87
79
bool LegalizeFunctionSignatures::runOnModule (Module& M)
88
80
{
89
81
auto pMdUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils ();
@@ -275,7 +267,8 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
275
267
continue ;
276
268
}
277
269
278
- ReturnOpt retTypeOption = ReturnOpt::RETURN_DEFAULT;
270
+ bool legalizeReturnType = false ;
271
+ bool promoteSRetType = false ;
279
272
bool fixArgType = false ;
280
273
std::vector<Type*> argTypes;
281
274
@@ -287,18 +280,14 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
287
280
// Create the new function signature by replacing the illegal types
288
281
if (FunctionHasPromotableSRetArg (M, pFunc))
289
282
{
290
- retTypeOption = ReturnOpt::RETURN_STRUCT ;
283
+ promoteSRetType = true ;
291
284
ai++; // Skip adding the first arg
292
285
}
293
286
else if (!isLegalSignatureType (M, pFunc->getReturnType (), isStackCall))
294
287
{
295
- retTypeOption = ReturnOpt::RETURN_BY_REF ;
288
+ legalizeReturnType = true ;
296
289
argTypes.push_back (PointerType::get (pFunc->getReturnType (), 0 ));
297
290
}
298
- else if (!isLegalIntVectorType (M, pFunc->getReturnType ()))
299
- {
300
- retTypeOption = ReturnOpt::RETURN_LEGAL_INT;
301
- }
302
291
303
292
for (; ai != ei; ai++)
304
293
{
@@ -324,32 +313,33 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
324
313
}
325
314
}
326
315
327
- if (retTypeOption != ReturnOpt::RETURN_DEFAULT || fixArgType)
316
+ if (!legalizeReturnType && !promoteSRetType && ! fixArgType)
328
317
{
329
- // Clone function with new signature
330
- Type* returnType =
331
- retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (M.getContext ()) :
332
- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, pFunc->arg_begin ()->getType ()) :
333
- retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, pFunc->getReturnType ()) :
334
- pFunc->getReturnType ();
335
- FunctionType* signature = FunctionType::get (returnType, argTypes, false );
336
- Function* pNewFunc = Function::Create (signature, pFunc->getLinkage (), pFunc->getName (), pFunc->getParent ());
337
- pNewFunc->takeName (pFunc);
338
- pNewFunc->setCallingConv (pFunc->getCallingConv ());
339
- pNewFunc->setAttributes (pFunc->getAttributes ());
340
-
341
- // Since we need to pass in pointers to be dereferenced by the new function, remove the "readnone" attribute
342
- // Also we need to create allocas for these pointers, so set the flag to true
343
- if (retTypeOption == ReturnOpt::RETURN_BY_REF)
344
- {
345
- pNewFunc->removeFnAttr (llvm::Attribute::ReadNone);
346
- pNewFunc->removeFnAttr (llvm::Attribute::ReadOnly);
347
- pContext->m_instrTypes .hasNonPrimitiveAlloca = true ;
348
- }
318
+ // Nothing to fix
319
+ continue ;
320
+ }
349
321
350
- // Map the old function to the new
351
- oldToNewFuncMap[pFunc] = pNewFunc;
322
+ // Clone function with new signature
323
+ Type* returnType = legalizeReturnType ? Type::getVoidTy (M.getContext ()) :
324
+ promoteSRetType ? PromotedStructValueType (M, pFunc->arg_begin ()->getType ()) :
325
+ pFunc->getReturnType ();
326
+ FunctionType* signature = FunctionType::get (returnType, argTypes, false );
327
+ Function* pNewFunc = Function::Create (signature, pFunc->getLinkage (), pFunc->getName (), pFunc->getParent ());
328
+ pNewFunc->takeName (pFunc);
329
+ pNewFunc->setCallingConv (pFunc->getCallingConv ());
330
+ pNewFunc->setAttributes (pFunc->getAttributes ());
331
+
332
+ // Since we need to pass in pointers to be dereferenced by the new function, remove the "readnone" attribute
333
+ // Also we need to create allocas for these pointers, so set the flag to true
334
+ if (legalizeReturnType)
335
+ {
336
+ pNewFunc->removeFnAttr (llvm::Attribute::ReadNone);
337
+ pNewFunc->removeFnAttr (llvm::Attribute::ReadOnly);
338
+ pContext->m_instrTypes .hasNonPrimitiveAlloca = true ;
352
339
}
340
+
341
+ // Map the old function to the new
342
+ oldToNewFuncMap[pFunc] = pNewFunc;
353
343
}
354
344
}
355
345
@@ -367,28 +357,26 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
367
357
llvm::SmallVector<llvm::ReturnInst*, 8 > Returns;
368
358
auto OldArgIt = pFunc->arg_begin ();
369
359
auto NewArgIt = pNewFunc->arg_begin ();
370
- ReturnOpt retTypeOption = ReturnOpt::RETURN_DEFAULT;
360
+ bool legalizeReturnType = false ;
361
+ bool promoteSRetType = false ;
371
362
bool isStackCall = pFunc->hasFnAttribute (" visaStackCall" );
372
363
Value* tempAllocaForSRetPointer = nullptr ;
373
364
llvm::SmallVector<llvm::Argument*, 8 > ArgByVal;
374
365
375
366
if (FunctionHasPromotableSRetArg (M, pFunc)) {
376
- retTypeOption = ReturnOpt::RETURN_STRUCT ;
367
+ promoteSRetType = true ;
377
368
}
378
369
else if (!isLegalSignatureType (M, pFunc->getReturnType (), isStackCall)) {
379
- retTypeOption = ReturnOpt::RETURN_BY_REF ;
370
+ legalizeReturnType = true ;
380
371
++NewArgIt; // Skip first argument that we added.
381
372
}
382
- else if (!isLegalIntVectorType (M, pFunc->getReturnType ())) {
383
- retTypeOption = ReturnOpt::RETURN_LEGAL_INT;
384
- }
385
373
386
374
// Fix the usages of arguments that have changed
387
375
BasicBlock* EntryBB = BasicBlock::Create (M.getContext (), " " , pNewFunc);
388
376
IGCLLVM::IRBuilder<> builder (EntryBB);
389
377
for (; OldArgIt != pFunc->arg_end (); ++OldArgIt)
390
378
{
391
- if (OldArgIt == pFunc->arg_begin () && retTypeOption == ReturnOpt::RETURN_STRUCT )
379
+ if (OldArgIt == pFunc->arg_begin () && promoteSRetType )
392
380
{
393
381
// Create a temp alloca to map the old argument. This will be removed later by SROA.
394
382
tempAllocaForSRetPointer = builder.CreateAlloca (PromotedStructValueType (M, OldArgIt->getType ()));
@@ -449,7 +437,7 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
449
437
}
450
438
451
439
// Now fix the return values
452
- if (retTypeOption == ReturnOpt::RETURN_BY_REF )
440
+ if (legalizeReturnType )
453
441
{
454
442
// Add the 'noalias' and 'sret' attribute to arg0
455
443
auto retArg = pNewFunc->arg_begin ();
@@ -471,7 +459,7 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
471
459
RetInst->eraseFromParent ();
472
460
}
473
461
}
474
- else if (retTypeOption == ReturnOpt::RETURN_STRUCT )
462
+ else if (promoteSRetType )
475
463
{
476
464
// For "sret" returns, we load from the temp alloca created earlier and return the loaded value instead
477
465
for (auto RetInst : Returns)
@@ -482,19 +470,6 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
482
470
RetInst->eraseFromParent ();
483
471
}
484
472
}
485
- else if (retTypeOption == ReturnOpt::RETURN_LEGAL_INT)
486
- {
487
- // Extend illegal int returns to legal type
488
- for (auto RetInst : Returns)
489
- {
490
- IGCLLVM::IRBuilder<> builder (RetInst);
491
- Value* retVal = RetInst->getReturnValue ();
492
- Type* retTy = retVal->getType ();
493
- retVal = builder.CreateZExt (retVal, LegalizedIntVectorType (M, retTy));
494
- builder.CreateRet (retVal);
495
- RetInst->eraseFromParent ();
496
- }
497
- }
498
473
}
499
474
500
475
// Now that all instructions are transferred to the new func, delete the old func
@@ -558,7 +533,8 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
558
533
{
559
534
Function* calledFunc = callInst->getCalledFunction ();
560
535
SmallVector<Value*, 16 > callArgs;
561
- ReturnOpt retTypeOption = ReturnOpt::RETURN_DEFAULT;
536
+ bool legalizeReturnType = false ;
537
+ bool promoteSRetType = false ;
562
538
bool fixArgType = false ;
563
539
bool isStackCall = !calledFunc || calledFunc->hasFnAttribute (" visaStackCall" );
564
540
@@ -577,7 +553,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
577
553
isPromotableStructType (M, callInst->getArgOperand (0 )->getType (), isStackCall, true /* retval */ ))
578
554
{
579
555
opNum++; // Skip the first call operand
580
- retTypeOption = ReturnOpt::RETURN_STRUCT ;
556
+ promoteSRetType = true ;
581
557
}
582
558
else if (!isLegalSignatureType (M, callInst->getType (), isStackCall))
583
559
{
@@ -590,11 +566,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
590
566
ArgAttrs.addAttribute (llvm::Attribute::NoAlias);
591
567
ArgAttrs.addStructRetAttr (callInst->getType ());
592
568
ArgAttrVec.push_back (AttributeSet::get (M.getContext (), ArgAttrs));
593
- retTypeOption = ReturnOpt::RETURN_BY_REF;
594
- }
595
- else if (!isLegalIntVectorType (M, callInst->getType ()))
596
- {
597
- retTypeOption = ReturnOpt::RETURN_LEGAL_INT;
569
+ legalizeReturnType = true ;
598
570
}
599
571
600
572
// Check call operands if it needs to be replaced
@@ -640,7 +612,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
640
612
}
641
613
}
642
614
643
- if (retTypeOption != ReturnOpt::RETURN_DEFAULT || fixArgType)
615
+ if (legalizeReturnType || promoteSRetType || fixArgType)
644
616
{
645
617
IGCLLVM::IRBuilder<> builder (callInst);
646
618
Value* newCalledValue = nullptr ;
@@ -653,10 +625,8 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
653
625
{
654
626
argTypes.push_back (arg->getType ());
655
627
}
656
- Type* retType =
657
- retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy (callInst->getContext ()) :
658
- retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType (M, callInst->getArgOperand (0 )->getType ()) :
659
- retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType (M, callInst->getType ()) :
628
+ Type* retType = legalizeReturnType ? Type::getVoidTy (callInst->getContext ()) :
629
+ promoteSRetType ? PromotedStructValueType (M, callInst->getArgOperand (0 )->getType ()) :
660
630
callInst->getType ();
661
631
newFnTy = FunctionType::get (retType, argTypes, false );
662
632
Value* calledValue = IGCLLVM::getCalledValue (callInst);
@@ -676,24 +646,18 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
676
646
newCallInst->setAttributes (AttributeList::get (M.getContext (), IGCLLVM::getFnAttrs (PAL), IGCLLVM::getRetAttrs (PAL), ArgAttrVec));
677
647
newCallInst->setDebugLoc (callInst->getDebugLoc ());
678
648
679
- if (retTypeOption == ReturnOpt::RETURN_BY_REF )
649
+ if (legalizeReturnType )
680
650
{
681
651
// Load the return value from the arg pointer before using it
682
652
IGC_ASSERT (returnPtr);
683
653
Value* load = builder.CreateLoad (returnPtr);
684
654
callInst->replaceAllUsesWith (load);
685
655
}
686
- else if (retTypeOption == ReturnOpt::RETURN_STRUCT )
656
+ else if (promoteSRetType )
687
657
{
688
658
// Store the struct value into the orginal pointer operand
689
659
StoreToStruct (builder, newCallInst, callInst->getArgOperand (0 ));
690
660
}
691
- else if (retTypeOption == ReturnOpt::RETURN_LEGAL_INT)
692
- {
693
- // Truncate legal type back into original value
694
- Value* trunc = builder.CreateTrunc (newCallInst, callInst->getType ());
695
- callInst->replaceAllUsesWith (trunc);
696
- }
697
661
else
698
662
{
699
663
callInst->replaceAllUsesWith (newCallInst);
0 commit comments