@@ -211,6 +211,8 @@ static bool
211
211
collectRegionsConstants (OutlinableRegion &Region,
212
212
DenseMap<unsigned , Constant *> &GVNToConstant,
213
213
DenseSet<unsigned > &NotSame) {
214
+ bool ConstantsTheSame = true ;
215
+
214
216
IRSimilarityCandidate &C = *Region.Candidate ;
215
217
for (IRInstructionData &ID : C) {
216
218
@@ -222,11 +224,10 @@ collectRegionsConstants(OutlinableRegion &Region,
222
224
assert (GVNOpt.hasValue () && " Expected a GVN for operand?" );
223
225
unsigned GVN = GVNOpt.getValue ();
224
226
225
- // If this global value has been found to not be the same, it could have
226
- // just been a register, check that it is not a constant value.
227
+ // Check if this global value has been found to not be the same already.
227
228
if (NotSame.find (GVN) != NotSame.end ()) {
228
229
if (isa<Constant>(V))
229
- return false ;
230
+ ConstantsTheSame = false ;
230
231
continue ;
231
232
}
232
233
@@ -239,30 +240,27 @@ collectRegionsConstants(OutlinableRegion &Region,
239
240
if (ConstantMatches.getValue ())
240
241
continue ;
241
242
else
242
- return false ;
243
+ ConstantsTheSame = false ;
243
244
}
244
245
245
246
// While this value is a register, it might not have been previously,
246
247
// make sure we don't already have a constant mapped to this global value
247
248
// number.
248
249
if (GVNToConstant.find (GVN) != GVNToConstant.end ())
249
- return false ;
250
+ ConstantsTheSame = false ;
250
251
251
252
NotSame.insert (GVN);
252
253
}
253
254
}
254
255
255
- return true ;
256
+ return ConstantsTheSame ;
256
257
}
257
258
258
259
void OutlinableGroup::findSameConstants (DenseSet<unsigned > &NotSame) {
259
260
DenseMap<unsigned , Constant *> GVNToConstant;
260
261
261
262
for (OutlinableRegion *Region : Regions)
262
- if (!collectRegionsConstants (*Region, GVNToConstant, NotSame)) {
263
- IgnoreGroup = true ;
264
- return ;
265
- }
263
+ collectRegionsConstants (*Region, GVNToConstant, NotSame);
266
264
}
267
265
268
266
Function *IROutliner::createFunction (Module &M, OutlinableGroup &Group,
@@ -307,16 +305,44 @@ static BasicBlock *moveFunctionData(Function &Old, Function &New) {
307
305
return NewEnd;
308
306
}
309
307
310
- // / Find the GVN for the inputs that have been found by the CodeExtractor,
311
- // / excluding the ones that will be removed by llvm.assumes as these will be
312
- // / removed by the CodeExtractor.
308
+ // / Find the the constants that will need to be lifted into arguments
309
+ // / as they are not the same in each instance of the region.
310
+ // /
311
+ // / \param [in] C - The IRSimilarityCandidate containing the region we are
312
+ // / analyzing.
313
+ // / \param [in] NotSame - The set of global value numbers that do not have a
314
+ // / single Constant across all OutlinableRegions similar to \p C.
315
+ // / \param [out] Inputs - The list containing the global value numbers of the
316
+ // / arguments needed for the region of code.
317
+ static void findConstants (IRSimilarityCandidate &C, DenseSet<unsigned > &NotSame,
318
+ std::vector<unsigned > &Inputs) {
319
+ DenseSet<unsigned > Seen;
320
+ // Iterate over the instructions, and find what constants will need to be
321
+ // extracted into arguments.
322
+ for (IRInstructionDataList::iterator IDIt = C.begin (), EndIDIt = C.end ();
323
+ IDIt != EndIDIt; IDIt++) {
324
+ for (Value *V : (*IDIt).OperVals ) {
325
+ // Since these are stored before any outlining, they will be in the
326
+ // global value numbering.
327
+ unsigned GVN = C.getGVN (V).getValue ();
328
+ if (Constant *CST = dyn_cast<Constant>(V))
329
+ if (NotSame.find (GVN) != NotSame.end () &&
330
+ Seen.find (GVN) == Seen.end ()) {
331
+ Inputs.push_back (GVN);
332
+ Seen.insert (GVN);
333
+ }
334
+ }
335
+ }
336
+ }
337
+
338
+ // / Find the GVN for the inputs that have been found by the CodeExtractor.
313
339
// /
314
340
// / \param [in] C - The IRSimilarityCandidate containing the region we are
315
341
// / analyzing.
316
342
// / \param [in] CurrentInputs - The set of inputs found by the
317
343
// / CodeExtractor.
318
- // / \param [out] CurrentInputNumbers - The global value numbers for the
319
- // / extracted arguments.
344
+ // / \param [out] EndInputNumbers - The global value numbers for the extracted
345
+ // / arguments.
320
346
static void mapInputsToGVNs (IRSimilarityCandidate &C,
321
347
SetVector<Value *> &CurrentInputs,
322
348
std::vector<unsigned > &EndInputNumbers) {
@@ -332,16 +358,20 @@ static void mapInputsToGVNs(IRSimilarityCandidate &C,
332
358
// / Find the input GVNs and the output values for a region of Instructions.
333
359
// / Using the code extractor, we collect the inputs to the extracted function.
334
360
// /
335
- // / The \p Region can be identifed as needing to be ignored in this function.
361
+ // / The \p Region can be identified as needing to be ignored in this function.
336
362
// / It should be checked whether it should be ignored after a call to this
337
363
// / function.
338
364
// /
339
365
// / \param [in,out] Region - The region of code to be analyzed.
340
366
// / \param [out] InputGVNs - The global value numbers for the extracted
341
367
// / arguments.
368
+ // / \param [in] NotSame - The global value numbers in the region that do not
369
+ // / have the same constant value in the regions structurally similar to
370
+ // / \p Region.
342
371
// / \param [out] ArgInputs - The values of the inputs to the extracted function.
343
372
static void getCodeExtractorArguments (OutlinableRegion &Region,
344
373
std::vector<unsigned > &InputGVNs,
374
+ DenseSet<unsigned > &NotSame,
345
375
SetVector<Value *> &ArgInputs) {
346
376
IRSimilarityCandidate &C = *Region.Candidate ;
347
377
@@ -389,13 +419,18 @@ static void getCodeExtractorArguments(OutlinableRegion &Region,
389
419
return ;
390
420
}
391
421
422
+ findConstants (C, NotSame, InputGVNs);
392
423
mapInputsToGVNs (C, OverallInputs, InputGVNs);
424
+
425
+ // Sort the GVNs, since we now have constants included in the \ref InputGVNs
426
+ // we need to make sure they are in a deterministic order.
427
+ stable_sort (InputGVNs.begin (), InputGVNs.end ());
393
428
}
394
429
395
430
// / Look over the inputs and map each input argument to an argument in the
396
- // / overall function for the regions . This creates a way to replace the
397
- // / arguments of the extracted function, with the arguments of the new overall
398
- // / function.
431
+ // / overall function for the OutlinableRegions . This creates a way to replace
432
+ // / the arguments of the extracted function with the arguments of the new
433
+ // / overall function.
399
434
// /
400
435
// / \param [in,out] Region - The region of code to be analyzed.
401
436
// / \param [in] InputsGVNs - The global value numbering of the input values
@@ -417,7 +452,10 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
417
452
unsigned OriginalIndex = 0 ;
418
453
419
454
// Find the mapping of the extracted arguments to the arguments for the
420
- // overall function.
455
+ // overall function. Since there may be extra arguments in the overall
456
+ // function to account for the extracted constants, we have two different
457
+ // counters as we find extracted arguments, and as we come across overall
458
+ // arguments.
421
459
for (unsigned InputVal : InputGVNs) {
422
460
Optional<Value *> InputOpt = C.fromGVN (InputVal);
423
461
assert (InputOpt.hasValue () && " Global value number not found?" );
@@ -426,9 +464,16 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
426
464
if (!Group.InputTypesSet )
427
465
Group.ArgumentTypes .push_back (Input->getType ());
428
466
429
- // It is not a constant, check if it is a sunken alloca. If it is not,
430
- // create the mapping from extracted to overall. If it is, create the
431
- // mapping of the index to the value.
467
+ // Check if we have a constant. If we do add it to the overall argument
468
+ // number to Constant map for the region, and continue to the next input.
469
+ if (Constant *CST = dyn_cast<Constant>(Input)) {
470
+ Region.AggArgToConstant .insert (std::make_pair (TypeIndex, CST));
471
+ TypeIndex++;
472
+ continue ;
473
+ }
474
+
475
+ // It is not a constant, we create the mapping from extracted argument list
476
+ // to the overall argument list.
432
477
assert (ArgInputs.count (Input) && " Input cannot be found!" );
433
478
434
479
Region.ExtractedArgToAgg .insert (std::make_pair (OriginalIndex, TypeIndex));
@@ -437,10 +482,10 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
437
482
TypeIndex++;
438
483
}
439
484
440
- // If we do not have definitions for the OutlinableGroup holding the region,
441
- // set the length of the inputs here. We should have the same inputs for
442
- // all of the different regions contained in the OutlinableGroup since they
443
- // are all structurally similar to one another
485
+ // If the function type definitions for the OutlinableGroup holding the region
486
+ // have not been set, set the length of the inputs here. We should have the
487
+ // same inputs for all of the different regions contained in the
488
+ // OutlinableGroup since they are all structurally similar to one another.
444
489
if (!Group.InputTypesSet ) {
445
490
Group.NumAggregateInputs = TypeIndex;
446
491
Group.InputTypesSet = true ;
@@ -449,11 +494,12 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
449
494
Region.NumExtractedInputs = OriginalIndex;
450
495
}
451
496
452
- void IROutliner::findAddInputsOutputs (Module &M, OutlinableRegion &Region) {
497
+ void IROutliner::findAddInputsOutputs (Module &M, OutlinableRegion &Region,
498
+ DenseSet<unsigned > &NotSame) {
453
499
std::vector<unsigned > Inputs;
454
500
SetVector<Value *> ArgInputs;
455
501
456
- getCodeExtractorArguments (Region, Inputs, ArgInputs);
502
+ getCodeExtractorArguments (Region, Inputs, NotSame, ArgInputs);
457
503
458
504
if (Region.IgnoreRegion )
459
505
return ;
@@ -474,6 +520,7 @@ void IROutliner::findAddInputsOutputs(Module &M, OutlinableRegion &Region) {
474
520
// / \returns a call instruction with the replaced function.
475
521
CallInst *replaceCalledFunction (Module &M, OutlinableRegion &Region) {
476
522
std::vector<Value *> NewCallArgs;
523
+ DenseMap<unsigned , unsigned >::iterator ArgPair;
477
524
478
525
OutlinableGroup &Group = *Region.Parent ;
479
526
CallInst *Call = Region.Call ;
@@ -484,12 +531,72 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) {
484
531
// If the arguments are the same size, there are not values that need to be
485
532
// made argument, or different output registers to handle. We can simply
486
533
// replace the called function in this case.
487
- assert (AggFunc->arg_size () == Call->arg_size () &&
488
- " Can only replace calls with the same number of arguments!" );
534
+ if (AggFunc->arg_size () == Call->arg_size ()) {
535
+ LLVM_DEBUG (dbgs () << " Replace call to " << *Call << " with call to "
536
+ << *AggFunc << " with same number of arguments\n " );
537
+ Call->setCalledFunction (AggFunc);
538
+ return Call;
539
+ }
540
+
541
+ // We have a different number of arguments than the new function, so
542
+ // we need to use our previously mappings off extracted argument to overall
543
+ // function argument, and constants to overall function argument to create the
544
+ // new argument list.
545
+ for (unsigned AggArgIdx = 0 ; AggArgIdx < AggFunc->arg_size (); AggArgIdx++) {
546
+
547
+ ArgPair = Region.AggArgToExtracted .find (AggArgIdx);
548
+ if (ArgPair != Region.AggArgToExtracted .end ()) {
549
+ Value *ArgumentValue = Call->getArgOperand (ArgPair->second );
550
+ // If we found the mapping from the extracted function to the overall
551
+ // function, we simply add it to the argument list. We use the same
552
+ // value, it just needs to honor the new order of arguments.
553
+ LLVM_DEBUG (dbgs () << " Setting argument " << AggArgIdx << " to value "
554
+ << *ArgumentValue << " \n " );
555
+ NewCallArgs.push_back (ArgumentValue);
556
+ continue ;
557
+ }
558
+
559
+ // If it is a constant, we simply add it to the argument list as a value.
560
+ if (Region.AggArgToConstant .find (AggArgIdx) !=
561
+ Region.AggArgToConstant .end ()) {
562
+ Constant *CST = Region.AggArgToConstant .find (AggArgIdx)->second ;
563
+ LLVM_DEBUG (dbgs () << " Setting argument " << AggArgIdx << " to value "
564
+ << *CST << " \n " );
565
+ NewCallArgs.push_back (CST);
566
+ continue ;
567
+ }
568
+
569
+ // Add a nullptr value if the argument is not found in the extracted
570
+ // function. If we cannot find a value, it means it is not in use
571
+ // for the region, so we should not pass anything to it.
572
+ LLVM_DEBUG (dbgs () << " Setting argument " << AggArgIdx << " to nullptr\n " );
573
+ NewCallArgs.push_back (ConstantPointerNull::get (
574
+ static_cast <PointerType *>(AggFunc->getArg (AggArgIdx)->getType ())));
575
+ }
489
576
490
577
LLVM_DEBUG (dbgs () << " Replace call to " << *Call << " with call to "
491
- << *AggFunc << " with same number of arguments\n " );
492
- Call->setCalledFunction (AggFunc);
578
+ << *AggFunc << " with new set of arguments\n " );
579
+ // Create the new call instruction and erase the old one.
580
+ Call = CallInst::Create (AggFunc->getFunctionType (), AggFunc, NewCallArgs, " " ,
581
+ Call);
582
+
583
+ // It is possible that the call to the outlined function is either the first
584
+ // instruction in the new block, the last instruction, or both. If either of
585
+ // these is the case, we need to make sure that we replace the instruction in
586
+ // the IRInstructionData struct with the new call.
587
+ CallInst *OldCall = Region.Call ;
588
+ if (Region.NewFront ->Inst == OldCall)
589
+ Region.NewFront ->Inst = Call;
590
+ if (Region.NewBack ->Inst == OldCall)
591
+ Region.NewBack ->Inst = Call;
592
+
593
+ // Transfer any debug information.
594
+ Call->setDebugLoc (Region.Call ->getDebugLoc ());
595
+
596
+ // Remove the old instruction.
597
+ OldCall->eraseFromParent ();
598
+ Region.Call = Call;
599
+
493
600
return Call;
494
601
}
495
602
@@ -518,6 +625,37 @@ static void replaceArgumentUses(OutlinableRegion &Region) {
518
625
}
519
626
}
520
627
628
+ // / Within an extracted function, replace the constants that need to be lifted
629
+ // / into arguments with the actual argument.
630
+ // /
631
+ // / \param Region [in] - The region of extracted code to be changed.
632
+ void replaceConstants (OutlinableRegion &Region) {
633
+ OutlinableGroup &Group = *Region.Parent ;
634
+ // Iterate over the constants that need to be elevated into arguments
635
+ for (std::pair<unsigned , Constant *> &Const : Region.AggArgToConstant ) {
636
+ unsigned AggArgIdx = Const.first ;
637
+ Function *OutlinedFunction = Group.OutlinedFunction ;
638
+ assert (OutlinedFunction && " Overall Function is not defined?" );
639
+ Constant *CST = Const.second ;
640
+ Argument *Arg = Group.OutlinedFunction ->getArg (AggArgIdx);
641
+ // Identify the argument it will be elevated to, and replace instances of
642
+ // that constant in the function.
643
+
644
+ // TODO: If in the future constants do not have one global value number,
645
+ // i.e. a constant 1 could be mapped to several values, this check will
646
+ // have to be more strict. It cannot be using only replaceUsesWithIf.
647
+
648
+ LLVM_DEBUG (dbgs () << " Replacing uses of constant " << *CST
649
+ << " in function " << *OutlinedFunction << " with "
650
+ << *Arg << " \n " );
651
+ CST->replaceUsesWithIf (Arg, [OutlinedFunction](Use &U) {
652
+ if (Instruction *I = dyn_cast<Instruction>(U.getUser ()))
653
+ return I->getFunction () == OutlinedFunction;
654
+ return false ;
655
+ });
656
+ }
657
+ }
658
+
521
659
// / Fill the new function that will serve as the replacement function for all of
522
660
// / the extracted regions of a certain structure from the first region in the
523
661
// / list of regions. Replace this first region's extracted function with the
@@ -544,6 +682,7 @@ static void fillOverallFunction(Module &M, OutlinableGroup &CurrentGroup,
544
682
CurrentGroup.OutlinedFunction ->addFnAttr (A);
545
683
546
684
replaceArgumentUses (*CurrentOS);
685
+ replaceConstants (*CurrentOS);
547
686
548
687
// Replace the call to the extracted function with the outlined function.
549
688
CurrentOS->Call = replaceCalledFunction (M, *CurrentOS);
@@ -738,7 +877,7 @@ unsigned IROutliner::doOutline(Module &M) {
738
877
OS->CE = new (ExtractorAllocator.Allocate ())
739
878
CodeExtractor (BE, nullptr , false , nullptr , nullptr , nullptr , false ,
740
879
false , " outlined" );
741
- findAddInputsOutputs (M, *OS);
880
+ findAddInputsOutputs (M, *OS, NotSame );
742
881
if (!OS->IgnoreRegion )
743
882
OutlinedRegions.push_back (OS);
744
883
else
0 commit comments