Skip to content

Commit 4aaa925

Browse files
authored
[llvm][CodeExtractor] fix bug in parameter naming (llvm#114237)
The code extractor tries to apply the names of source input and output values to function arguments. Not all input and output values get added as arguments: some are instead placed inside of a struct passed to the function. The existing renaming code skipped trying to set these struct-packed arguments names (as there is no corresponding function argument to rename), but it still incremented the iterator over the function arguments. This could result in dereferencing an end iterator if struct-packed inputs/outputs preceded non-struct-packed inputs/outputs. This patch rewrites this loop to avoid the end iterator dereference.
1 parent e28d7f7 commit 4aaa925

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -823,17 +823,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
823823

824824
std::vector<Type *> ParamTy;
825825
std::vector<Type *> AggParamTy;
826+
std::vector<std::tuple<unsigned, Value *>> NumberedInputs;
827+
std::vector<std::tuple<unsigned, Value *>> NumberedOutputs;
826828
ValueSet StructValues;
827829
const DataLayout &DL = M->getDataLayout();
828830

829831
// Add the types of the input values to the function's argument list
832+
unsigned ArgNum = 0;
830833
for (Value *value : inputs) {
831834
LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
832835
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
833836
AggParamTy.push_back(value->getType());
834837
StructValues.insert(value);
835-
} else
838+
} else {
836839
ParamTy.push_back(value->getType());
840+
NumberedInputs.emplace_back(ArgNum++, value);
841+
}
837842
}
838843

839844
// Add the types of the output values to the function's argument list.
@@ -842,9 +847,11 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
842847
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
843848
AggParamTy.push_back(output->getType());
844849
StructValues.insert(output);
845-
} else
850+
} else {
846851
ParamTy.push_back(
847852
PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
853+
NumberedOutputs.emplace_back(ArgNum++, output);
854+
}
848855
}
849856

850857
assert(
@@ -1053,15 +1060,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
10531060
}
10541061

10551062
// Set names for input and output arguments.
1056-
if (NumScalarParams) {
1057-
ScalarAI = newFunction->arg_begin();
1058-
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
1059-
if (!StructValues.contains(inputs[i]))
1060-
ScalarAI->setName(inputs[i]->getName());
1061-
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
1062-
if (!StructValues.contains(outputs[i]))
1063-
ScalarAI->setName(outputs[i]->getName() + ".out");
1064-
}
1063+
for (auto [i, argVal] : NumberedInputs)
1064+
newFunction->getArg(i)->setName(argVal->getName());
1065+
for (auto [i, argVal] : NumberedOutputs)
1066+
newFunction->getArg(i)->setName(argVal->getName() + ".out");
10651067

10661068
// Rewrite branches to basic blocks outside of the loop to new dummy blocks
10671069
// within the new function. This must be done before we lose track of which

llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,53 @@ TEST(CodeExtractor, PartialAggregateArgs) {
556556
EXPECT_FALSE(verifyFunction(*Func));
557557
}
558558

559+
/// Regression test to ensure we don't crash trying to set the name of the ptr
560+
/// argument
561+
TEST(CodeExtractor, PartialAggregateArgs2) {
562+
LLVMContext Ctx;
563+
SMDiagnostic Err;
564+
std::unique_ptr<Module> M(parseAssemblyString(R"ir(
565+
declare void @usei(i32)
566+
declare void @usep(ptr)
567+
568+
define void @foo(i32 %a, i32 %b, ptr %p) {
569+
entry:
570+
br label %extract
571+
572+
extract:
573+
call void @usei(i32 %a)
574+
call void @usei(i32 %b)
575+
call void @usep(ptr %p)
576+
br label %exit
577+
578+
exit:
579+
ret void
580+
}
581+
)ir",
582+
Err, Ctx));
583+
584+
Function *Func = M->getFunction("foo");
585+
SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
586+
587+
// Create the CodeExtractor with arguments aggregation enabled.
588+
CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
589+
/* AggregateArgs */ true);
590+
EXPECT_TRUE(CE.isEligible());
591+
592+
CodeExtractorAnalysisCache CEAC(*Func);
593+
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
594+
BasicBlock *CommonExit = nullptr;
595+
CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
596+
CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
597+
// Exclude the last input from the argument aggregate.
598+
CE.excludeArgFromAggregate(Inputs[2]);
599+
600+
Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
601+
EXPECT_TRUE(Outlined);
602+
EXPECT_FALSE(verifyFunction(*Outlined));
603+
EXPECT_FALSE(verifyFunction(*Func));
604+
}
605+
559606
TEST(CodeExtractor, OpenMPAggregateArgs) {
560607
LLVMContext Ctx;
561608
SMDiagnostic Err;

0 commit comments

Comments
 (0)