Skip to content

Commit b1191c8

Browse files
[IROutliner] Adding support for elevating constants that are not the same in each region to arguments
When there are constants that have the same structural location, but not the same value, between different regions, we cannot simply outline the region. Instead, we find the constants that are not the same in each location, and promote them to arguments to be passed into the respective functions. At each call site, we pass the constant in as an argument regardless of type. Added/Edited Tests: llvm/test/Transforms/IROutliner/outlining-constants-vs-registers.ll llvm/test/Transforms/IROutliner/outlining-different-constants.ll llvm/test/Transforms/IROutliner/outlining-different-globals.ll Reviewers: paquette, jroelofs Differential Revision: https://reviews.llvm.org/D87294
1 parent 3b3a9d2 commit b1191c8

File tree

5 files changed

+225
-75
lines changed

5 files changed

+225
-75
lines changed

llvm/include/llvm/Transforms/IPO/IROutliner.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ struct OutlinableRegion {
8181
DenseMap<unsigned, unsigned> ExtractedArgToAgg;
8282
DenseMap<unsigned, unsigned> AggArgToExtracted;
8383

84+
/// Mapping of the argument number in the deduplicated function
85+
/// to a given constant, which is used when creating the arguments to the call
86+
/// to the newly created deduplicated function. This is handled separately
87+
/// since the CodeExtractor does not recognize constants.
88+
DenseMap<unsigned, Constant *> AggArgToConstant;
89+
8490
/// Used to create an outlined function.
8591
CodeExtractor *CE = nullptr;
8692

@@ -180,8 +186,11 @@ class IROutliner {
180186
/// function if needed.
181187
///
182188
/// \param [in] M - The module to outline from.
183-
/// \param [in,out] Region - The region to be extracted
184-
void findAddInputsOutputs(Module &M, OutlinableRegion &Region);
189+
/// \param [in,out] Region - The region to be extracted.
190+
/// \param [in] NotSame - The global value numbers of the Values in the region
191+
/// that do not have the same Constant in each structurally similar region.
192+
void findAddInputsOutputs(Module &M, OutlinableRegion &Region,
193+
DenseSet<unsigned> &NotSame);
185194

186195
/// Extract \p Region into its own function.
187196
///

llvm/lib/Transforms/IPO/IROutliner.cpp

Lines changed: 173 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ static bool
211211
collectRegionsConstants(OutlinableRegion &Region,
212212
DenseMap<unsigned, Constant *> &GVNToConstant,
213213
DenseSet<unsigned> &NotSame) {
214+
bool ConstantsTheSame = true;
215+
214216
IRSimilarityCandidate &C = *Region.Candidate;
215217
for (IRInstructionData &ID : C) {
216218

@@ -222,11 +224,10 @@ collectRegionsConstants(OutlinableRegion &Region,
222224
assert(GVNOpt.hasValue() && "Expected a GVN for operand?");
223225
unsigned GVN = GVNOpt.getValue();
224226

225-
// If this global value has been found to not be the same, it could have
226-
// just been a register, check that it is not a constant value.
227+
// Check if this global value has been found to not be the same already.
227228
if (NotSame.find(GVN) != NotSame.end()) {
228229
if (isa<Constant>(V))
229-
return false;
230+
ConstantsTheSame = false;
230231
continue;
231232
}
232233

@@ -239,30 +240,27 @@ collectRegionsConstants(OutlinableRegion &Region,
239240
if (ConstantMatches.getValue())
240241
continue;
241242
else
242-
return false;
243+
ConstantsTheSame = false;
243244
}
244245

245246
// While this value is a register, it might not have been previously,
246247
// make sure we don't already have a constant mapped to this global value
247248
// number.
248249
if (GVNToConstant.find(GVN) != GVNToConstant.end())
249-
return false;
250+
ConstantsTheSame = false;
250251

251252
NotSame.insert(GVN);
252253
}
253254
}
254255

255-
return true;
256+
return ConstantsTheSame;
256257
}
257258

258259
void OutlinableGroup::findSameConstants(DenseSet<unsigned> &NotSame) {
259260
DenseMap<unsigned, Constant *> GVNToConstant;
260261

261262
for (OutlinableRegion *Region : Regions)
262-
if (!collectRegionsConstants(*Region, GVNToConstant, NotSame)) {
263-
IgnoreGroup = true;
264-
return;
265-
}
263+
collectRegionsConstants(*Region, GVNToConstant, NotSame);
266264
}
267265

268266
Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group,
@@ -307,16 +305,44 @@ static BasicBlock *moveFunctionData(Function &Old, Function &New) {
307305
return NewEnd;
308306
}
309307

310-
/// Find the GVN for the inputs that have been found by the CodeExtractor,
311-
/// excluding the ones that will be removed by llvm.assumes as these will be
312-
/// removed by the CodeExtractor.
308+
/// Find the constants that will need to be lifted into arguments
309+
/// as they are not the same in each instance of the region.
310+
///
311+
/// \param [in] C - The IRSimilarityCandidate containing the region we are
312+
/// analyzing.
313+
/// \param [in] NotSame - The set of global value numbers that do not have a
314+
/// single Constant across all OutlinableRegions similar to \p C.
315+
/// \param [out] Inputs - The list containing the global value numbers of the
316+
/// arguments needed for the region of code.
317+
static void findConstants(IRSimilarityCandidate &C, DenseSet<unsigned> &NotSame,
318+
std::vector<unsigned> &Inputs) {
319+
DenseSet<unsigned> Seen;
320+
// Iterate over the instructions, and find what constants will need to be
321+
// extracted into arguments.
322+
for (IRInstructionDataList::iterator IDIt = C.begin(), EndIDIt = C.end();
323+
IDIt != EndIDIt; IDIt++) {
324+
for (Value *V : (*IDIt).OperVals) {
325+
// Since these are stored before any outlining, they will be in the
326+
// global value numbering.
327+
unsigned GVN = C.getGVN(V).getValue();
328+
if (Constant *CST = dyn_cast<Constant>(V))
329+
if (NotSame.find(GVN) != NotSame.end() &&
330+
Seen.find(GVN) == Seen.end()) {
331+
Inputs.push_back(GVN);
332+
Seen.insert(GVN);
333+
}
334+
}
335+
}
336+
}
337+
338+
/// Find the GVN for the inputs that have been found by the CodeExtractor.
313339
///
314340
/// \param [in] C - The IRSimilarityCandidate containing the region we are
315341
/// analyzing.
316342
/// \param [in] CurrentInputs - The set of inputs found by the
317343
/// CodeExtractor.
318-
/// \param [out] CurrentInputNumbers - The global value numbers for the
319-
/// extracted arguments.
344+
/// \param [out] EndInputNumbers - The global value numbers for the extracted
345+
/// arguments.
320346
static void mapInputsToGVNs(IRSimilarityCandidate &C,
321347
SetVector<Value *> &CurrentInputs,
322348
std::vector<unsigned> &EndInputNumbers) {
@@ -332,16 +358,20 @@ static void mapInputsToGVNs(IRSimilarityCandidate &C,
332358
/// Find the input GVNs and the output values for a region of Instructions.
333359
/// Using the code extractor, we collect the inputs to the extracted function.
334360
///
335-
/// The \p Region can be identifed as needing to be ignored in this function.
361+
/// The \p Region can be identified as needing to be ignored in this function.
336362
/// It should be checked whether it should be ignored after a call to this
337363
/// function.
338364
///
339365
/// \param [in,out] Region - The region of code to be analyzed.
340366
/// \param [out] InputGVNs - The global value numbers for the extracted
341367
/// arguments.
368+
/// \param [in] NotSame - The global value numbers in the region that do not
369+
/// have the same constant value in the regions structurally similar to
370+
/// \p Region.
342371
/// \param [out] ArgInputs - The values of the inputs to the extracted function.
343372
static void getCodeExtractorArguments(OutlinableRegion &Region,
344373
std::vector<unsigned> &InputGVNs,
374+
DenseSet<unsigned> &NotSame,
345375
SetVector<Value *> &ArgInputs) {
346376
IRSimilarityCandidate &C = *Region.Candidate;
347377

@@ -389,13 +419,18 @@ static void getCodeExtractorArguments(OutlinableRegion &Region,
389419
return;
390420
}
391421

422+
findConstants(C, NotSame, InputGVNs);
392423
mapInputsToGVNs(C, OverallInputs, InputGVNs);
424+
425+
// Sort the GVNs, since we now have constants included in the \ref InputGVNs
426+
// we need to make sure they are in a deterministic order.
427+
stable_sort(InputGVNs.begin(), InputGVNs.end());
393428
}
394429

395430
/// Look over the inputs and map each input argument to an argument in the
396-
/// overall function for the regions. This creates a way to replace the
397-
/// arguments of the extracted function, with the arguments of the new overall
398-
/// function.
431+
/// overall function for the OutlinableRegions. This creates a way to replace
432+
/// the arguments of the extracted function with the arguments of the new
433+
/// overall function.
399434
///
400435
/// \param [in,out] Region - The region of code to be analyzed.
401436
/// \param [in] InputsGVNs - The global value numbering of the input values
@@ -417,7 +452,10 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
417452
unsigned OriginalIndex = 0;
418453

419454
// Find the mapping of the extracted arguments to the arguments for the
420-
// overall function.
455+
// overall function. Since there may be extra arguments in the overall
456+
// function to account for the extracted constants, we have two different
457+
// counters as we find extracted arguments, and as we come across overall
458+
// arguments.
421459
for (unsigned InputVal : InputGVNs) {
422460
Optional<Value *> InputOpt = C.fromGVN(InputVal);
423461
assert(InputOpt.hasValue() && "Global value number not found?");
@@ -426,9 +464,16 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
426464
if (!Group.InputTypesSet)
427465
Group.ArgumentTypes.push_back(Input->getType());
428466

429-
// It is not a constant, check if it is a sunken alloca. If it is not,
430-
// create the mapping from extracted to overall. If it is, create the
431-
// mapping of the index to the value.
467+
// Check if we have a constant. If we do, add it to the overall argument
468+
// number to Constant map for the region, and continue to the next input.
469+
if (Constant *CST = dyn_cast<Constant>(Input)) {
470+
Region.AggArgToConstant.insert(std::make_pair(TypeIndex, CST));
471+
TypeIndex++;
472+
continue;
473+
}
474+
475+
// It is not a constant, we create the mapping from extracted argument list
476+
// to the overall argument list.
432477
assert(ArgInputs.count(Input) && "Input cannot be found!");
433478

434479
Region.ExtractedArgToAgg.insert(std::make_pair(OriginalIndex, TypeIndex));
@@ -437,10 +482,10 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
437482
TypeIndex++;
438483
}
439484

440-
// If we do not have definitions for the OutlinableGroup holding the region,
441-
// set the length of the inputs here. We should have the same inputs for
442-
// all of the different regions contained in the OutlinableGroup since they
443-
// are all structurally similar to one another
485+
// If the function type definitions for the OutlinableGroup holding the region
486+
// have not been set, set the length of the inputs here. We should have the
487+
// same inputs for all of the different regions contained in the
488+
// OutlinableGroup since they are all structurally similar to one another.
444489
if (!Group.InputTypesSet) {
445490
Group.NumAggregateInputs = TypeIndex;
446491
Group.InputTypesSet = true;
@@ -449,11 +494,12 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region,
449494
Region.NumExtractedInputs = OriginalIndex;
450495
}
451496

452-
void IROutliner::findAddInputsOutputs(Module &M, OutlinableRegion &Region) {
497+
void IROutliner::findAddInputsOutputs(Module &M, OutlinableRegion &Region,
498+
DenseSet<unsigned> &NotSame) {
453499
std::vector<unsigned> Inputs;
454500
SetVector<Value *> ArgInputs;
455501

456-
getCodeExtractorArguments(Region, Inputs, ArgInputs);
502+
getCodeExtractorArguments(Region, Inputs, NotSame, ArgInputs);
457503

458504
if (Region.IgnoreRegion)
459505
return;
@@ -474,6 +520,7 @@ void IROutliner::findAddInputsOutputs(Module &M, OutlinableRegion &Region) {
474520
/// \returns a call instruction with the replaced function.
475521
CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) {
476522
std::vector<Value *> NewCallArgs;
523+
DenseMap<unsigned, unsigned>::iterator ArgPair;
477524

478525
OutlinableGroup &Group = *Region.Parent;
479526
CallInst *Call = Region.Call;
@@ -484,12 +531,72 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) {
484531
// If the arguments are the same size, there are no values that need to be
485532
// made arguments, or different output registers to handle. We can simply
486533
// replace the called function in this case.
487-
assert(AggFunc->arg_size() == Call->arg_size() &&
488-
"Can only replace calls with the same number of arguments!");
534+
if (AggFunc->arg_size() == Call->arg_size()) {
535+
LLVM_DEBUG(dbgs() << "Replace call to " << *Call << " with call to "
536+
<< *AggFunc << " with same number of arguments\n");
537+
Call->setCalledFunction(AggFunc);
538+
return Call;
539+
}
540+
541+
// We have a different number of arguments than the new function, so
542+
// we need to use our previous mappings of extracted argument to overall
543+
// function argument, and constants to overall function argument to create the
544+
// new argument list.
545+
for (unsigned AggArgIdx = 0; AggArgIdx < AggFunc->arg_size(); AggArgIdx++) {
546+
547+
ArgPair = Region.AggArgToExtracted.find(AggArgIdx);
548+
if (ArgPair != Region.AggArgToExtracted.end()) {
549+
Value *ArgumentValue = Call->getArgOperand(ArgPair->second);
550+
// If we found the mapping from the extracted function to the overall
551+
// function, we simply add it to the argument list. We use the same
552+
// value, it just needs to honor the new order of arguments.
553+
LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to value "
554+
<< *ArgumentValue << "\n");
555+
NewCallArgs.push_back(ArgumentValue);
556+
continue;
557+
}
558+
559+
// If it is a constant, we simply add it to the argument list as a value.
560+
if (Region.AggArgToConstant.find(AggArgIdx) !=
561+
Region.AggArgToConstant.end()) {
562+
Constant *CST = Region.AggArgToConstant.find(AggArgIdx)->second;
563+
LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to value "
564+
<< *CST << "\n");
565+
NewCallArgs.push_back(CST);
566+
continue;
567+
}
568+
569+
// Add a nullptr value if the argument is not found in the extracted
570+
// function. If we cannot find a value, it means it is not in use
571+
// for the region, so we should not pass anything to it.
572+
LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to nullptr\n");
573+
NewCallArgs.push_back(ConstantPointerNull::get(
574+
static_cast<PointerType *>(AggFunc->getArg(AggArgIdx)->getType())));
575+
}
489576

490577
LLVM_DEBUG(dbgs() << "Replace call to " << *Call << " with call to "
491-
<< *AggFunc << " with same number of arguments\n");
492-
Call->setCalledFunction(AggFunc);
578+
<< *AggFunc << " with new set of arguments\n");
579+
// Create the new call instruction and erase the old one.
580+
Call = CallInst::Create(AggFunc->getFunctionType(), AggFunc, NewCallArgs, "",
581+
Call);
582+
583+
// It is possible that the call to the outlined function is either the first
584+
// instruction in the new block, the last instruction, or both. If either of
585+
// these is the case, we need to make sure that we replace the instruction in
586+
// the IRInstructionData struct with the new call.
587+
CallInst *OldCall = Region.Call;
588+
if (Region.NewFront->Inst == OldCall)
589+
Region.NewFront->Inst = Call;
590+
if (Region.NewBack->Inst == OldCall)
591+
Region.NewBack->Inst = Call;
592+
593+
// Transfer any debug information.
594+
Call->setDebugLoc(Region.Call->getDebugLoc());
595+
596+
// Remove the old instruction.
597+
OldCall->eraseFromParent();
598+
Region.Call = Call;
599+
493600
return Call;
494601
}
495602

@@ -518,6 +625,37 @@ static void replaceArgumentUses(OutlinableRegion &Region) {
518625
}
519626
}
520627

628+
/// Within an extracted function, replace the constants that need to be lifted
629+
/// into arguments with the actual argument.
630+
///
631+
/// \param [in] Region - The region of extracted code to be changed.
632+
void replaceConstants(OutlinableRegion &Region) {
633+
OutlinableGroup &Group = *Region.Parent;
634+
// Iterate over the constants that need to be elevated into arguments
635+
for (std::pair<unsigned, Constant *> &Const : Region.AggArgToConstant) {
636+
unsigned AggArgIdx = Const.first;
637+
Function *OutlinedFunction = Group.OutlinedFunction;
638+
assert(OutlinedFunction && "Overall Function is not defined?");
639+
Constant *CST = Const.second;
640+
Argument *Arg = Group.OutlinedFunction->getArg(AggArgIdx);
641+
// Identify the argument it will be elevated to, and replace instances of
642+
// that constant in the function.
643+
644+
// TODO: If in the future constants do not have one global value number,
645+
// i.e. a constant 1 could be mapped to several values, this check will
646+
// have to be more strict. It cannot be using only replaceUsesWithIf.
647+
648+
LLVM_DEBUG(dbgs() << "Replacing uses of constant " << *CST
649+
<< " in function " << *OutlinedFunction << " with "
650+
<< *Arg << "\n");
651+
CST->replaceUsesWithIf(Arg, [OutlinedFunction](Use &U) {
652+
if (Instruction *I = dyn_cast<Instruction>(U.getUser()))
653+
return I->getFunction() == OutlinedFunction;
654+
return false;
655+
});
656+
}
657+
}
658+
521659
/// Fill the new function that will serve as the replacement function for all of
522660
/// the extracted regions of a certain structure from the first region in the
523661
/// list of regions. Replace this first region's extracted function with the
@@ -544,6 +682,7 @@ static void fillOverallFunction(Module &M, OutlinableGroup &CurrentGroup,
544682
CurrentGroup.OutlinedFunction->addFnAttr(A);
545683

546684
replaceArgumentUses(*CurrentOS);
685+
replaceConstants(*CurrentOS);
547686

548687
// Replace the call to the extracted function with the outlined function.
549688
CurrentOS->Call = replaceCalledFunction(M, *CurrentOS);
@@ -738,7 +877,7 @@ unsigned IROutliner::doOutline(Module &M) {
738877
OS->CE = new (ExtractorAllocator.Allocate())
739878
CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
740879
false, "outlined");
741-
findAddInputsOutputs(M, *OS);
880+
findAddInputsOutputs(M, *OS, NotSame);
742881
if (!OS->IgnoreRegion)
743882
OutlinedRegions.push_back(OS);
744883
else

0 commit comments

Comments
 (0)