Skip to content

[MergeFunc] Fix crash caused by bitcasting ArrayType #133259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/IR/IRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,13 @@ class IRBuilderBase {
// isSigned parameter.
Value *CreateIntCast(Value *, Type *, const char *) = delete;

/// Cast between aggregate types that must have identical structure but may
/// differ in their leaf types. The leaf values are recursively extracted,
/// casted, and then reinserted into a value of type DestTy. The leaf types
/// must be castable using a bitcast or ptrcast, because signedness is
/// not specified.
Value *CreateAggregateCast(Value *V, Type *DestTy);

//===--------------------------------------------------------------------===//
// Instruction creation methods: Compare Instructions
//===--------------------------------------------------------------------===//
Expand Down
47 changes: 5 additions & 42 deletions llvm/lib/CodeGen/GlobalMergeFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,44 +140,6 @@ static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
return true;
}

static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
Type *SrcTy = V->getType();
if (SrcTy->isStructTy()) {
assert(DestTy->isStructTy());
assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
Value *Result = PoisonValue::get(DestTy);
for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
Value *Element =
createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
DestTy->getStructElementType(I));

Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
}
return Result;
}
assert(!DestTy->isStructTy());
if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
auto *DestAT = dyn_cast<ArrayType>(DestTy);
assert(DestAT);
assert(SrcAT->getNumElements() == DestAT->getNumElements());
Value *Result = PoisonValue::get(DestTy);
for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
Value *Element =
createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
DestAT->getElementType());

Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
}
return Result;
}
assert(!DestTy->isArrayTy());
if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
return Builder.CreateIntToPtr(V, DestTy);
if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
return Builder.CreatePtrToInt(V, DestTy);
return Builder.CreateBitCast(V, DestTy);
}

void GlobalMergeFunc::analyze(Module &M) {
++NumAnalyzedModues;
for (Function &Func : M) {
Expand Down Expand Up @@ -268,7 +230,7 @@ static Function *createMergedFunction(FuncMergeInfo &FI,
if (OrigC->getType() != NewArg->getType()) {
IRBuilder<> Builder(Inst->getParent(), Inst->getIterator());
Inst->setOperand(OpndIndex,
createCast(Builder, NewArg, OrigC->getType()));
Builder.CreateAggregateCast(NewArg, OrigC->getType()));
} else {
Inst->setOperand(OpndIndex, NewArg);
}
Expand Down Expand Up @@ -297,15 +259,16 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,

// Add arguments which are passed through Thunk.
for (Argument &AI : Thunk->args()) {
Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx)));
Args.push_back(
Builder.CreateAggregateCast(&AI, ToFuncTy->getParamType(ParamIdx)));
++ParamIdx;
}

// Add new arguments defined by Params.
for (auto *Param : Params) {
assert(ParamIdx < ToFuncTy->getNumParams());
Args.push_back(
createCast(Builder, Param, ToFuncTy->getParamType(ParamIdx)));
Builder.CreateAggregateCast(Param, ToFuncTy->getParamType(ParamIdx)));
++ParamIdx;
}

Expand All @@ -319,7 +282,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
if (Thunk->getReturnType()->isVoidTy())
Builder.CreateRetVoid();
else
Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType()));
Builder.CreateRet(Builder.CreateAggregateCast(CI, Thunk->getReturnType()));
}

// Check if the old merged/optimized IndexOperandHashMap is compatible with
Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/IR/IRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,40 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const {
}
}

Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
Type *SrcTy = V->getType();
if (SrcTy == DestTy)
return V;

if (SrcTy->isAggregateType()) {
unsigned NumElements;
if (SrcTy->isStructTy()) {
assert(DestTy->isStructTy() && "Expected StructType");
assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements() &&
"Expected StructTypes with equal number of elements");
NumElements = SrcTy->getStructNumElements();
} else {
assert(SrcTy->isArrayTy() && DestTy->isArrayTy() && "Expected ArrayType");
assert(SrcTy->getArrayNumElements() == DestTy->getArrayNumElements() &&
"Expected ArrayTypes with equal number of elements");
NumElements = SrcTy->getArrayNumElements();
}

Value *Result = PoisonValue::get(DestTy);
for (unsigned I = 0; I < NumElements; ++I) {
Type *ElementTy = SrcTy->isStructTy() ? DestTy->getStructElementType(I)
: DestTy->getArrayElementType();
Value *Element =
CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)), ElementTy);

Result = CreateInsertValue(Result, Element, ArrayRef(I));
}
return Result;
}

return CreateBitOrPointerCast(V, DestTy);
}

CallInst *
IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
const Twine &Name, FMFSource FMFSource,
Expand Down
31 changes: 2 additions & 29 deletions llvm/lib/Transforms/IPO/MergeFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,33 +511,6 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
}
}

// Helper for writeThunk,
// Selects proper bitcast operation,
// but a bit simpler then CastInst::getCastOpcode.
static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
Type *SrcTy = V->getType();
if (SrcTy->isStructTy()) {
assert(DestTy->isStructTy());
assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
Value *Result = PoisonValue::get(DestTy);
for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
Value *Element =
createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
DestTy->getStructElementType(I));

Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
}
return Result;
}
assert(!DestTy->isStructTy());
if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
return Builder.CreateIntToPtr(V, DestTy);
else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
return Builder.CreatePtrToInt(V, DestTy);
else
return Builder.CreateBitCast(V, DestTy);
}

// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
// parameter debug info, from the entry block.
void MergeFunctions::eraseInstsUnrelatedToPDI(
Expand Down Expand Up @@ -789,7 +762,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
unsigned i = 0;
FunctionType *FFTy = F->getFunctionType();
for (Argument &AI : H->args()) {
Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
Args.push_back(Builder.CreateAggregateCast(&AI, FFTy->getParamType(i)));
++i;
}

Expand All @@ -804,7 +777,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
if (H->getReturnType()->isVoidTy()) {
RI = Builder.CreateRetVoid();
} else {
RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
RI = Builder.CreateRet(Builder.CreateAggregateCast(CI, H->getReturnType()));
}

if (MergeFunctionsPDI) {
Expand Down
76 changes: 76 additions & 0 deletions llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
; RUN: opt -S -passes=mergefunc < %s | FileCheck %s

target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"

%A = type { double }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a test that requires a bitcast e.g. where one leaf type is double and the other i64? Not sure if such types would get handled by mergefunc though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added ptrcast test instead of bitcast. I checked the code in FunctionComparator and searched the tests, and couldn't find a type that causes a bitcast in the thunk function.

; the intermediary struct causes A_arr and B_arr to be different types
%A_struct = type { %A }
%A_arr = type { [1 x %A_struct] }

%B = type { double }
%B_struct = type { %B }
%B_arr = type { [1 x %B_struct] }

; conversion between C_arr and D_arr is possible, but requires ptrcast
%C = type { i64 }
%C_struct = type { %C }
%C_arr = type { [1 x %C_struct] }

%D = type { ptr }
%D_struct = type { %D }
%D_arr = type { [1 x %D_struct] }

declare void @noop()

define %A_arr @a() {
; CHECK-LABEL: define %A_arr @a() {
; CHECK-NEXT: call void @noop()
; CHECK-NEXT: ret %A_arr zeroinitializer
;
call void @noop()
ret %A_arr zeroinitializer
}

define %C_arr @c() {
; CHECK-LABEL: define %C_arr @c() {
; CHECK-NEXT: call void @noop()
; CHECK-NEXT: ret %C_arr zeroinitializer
;
call void @noop()
ret %C_arr zeroinitializer
}

define %B_arr @b() {
; CHECK-LABEL: define %B_arr @b() {
; CHECK-NEXT: [[TMP1:%.*]] = tail call %A_arr @a
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue %A_arr [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue [1 x %A_struct] [[TMP2]], 0
; CHECK-NEXT: [[TMP4:%.*]] = extractvalue %A_struct [[TMP3]], 0
; CHECK-NEXT: [[TMP5:%.*]] = extractvalue %A [[TMP4]], 0
; CHECK-NEXT: [[TMP6:%.*]] = insertvalue %B poison, double [[TMP5]], 0
; CHECK-NEXT: [[TMP7:%.*]] = insertvalue %B_struct poison, %B [[TMP6]], 0
; CHECK-NEXT: [[TMP8:%.*]] = insertvalue [1 x %B_struct] poison, %B_struct [[TMP7]], 0
; CHECK-NEXT: [[TMP9:%.*]] = insertvalue %B_arr poison, [1 x %B_struct] [[TMP8]], 0
; CHECK-NEXT: ret %B_arr [[TMP9]]
;
call void @noop()
ret %B_arr zeroinitializer
}

define %D_arr @d() {
; CHECK-LABEL: define %D_arr @d() {
; CHECK-NEXT: [[TMP1:%.*]] = tail call %C_arr @c
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue %C_arr [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue [1 x %C_struct] [[TMP2]], 0
; CHECK-NEXT: [[TMP4:%.*]] = extractvalue %C_struct [[TMP3]], 0
; CHECK-NEXT: [[TMP5:%.*]] = extractvalue %C [[TMP4]], 0
; CHECK-NEXT: [[TMP10:%.*]] = inttoptr i64 [[TMP5]] to ptr
; CHECK-NEXT: [[TMP6:%.*]] = insertvalue %D poison, ptr [[TMP10]], 0
; CHECK-NEXT: [[TMP7:%.*]] = insertvalue %D_struct poison, %D [[TMP6]], 0
; CHECK-NEXT: [[TMP8:%.*]] = insertvalue [1 x %D_struct] poison, %D_struct [[TMP7]], 0
; CHECK-NEXT: [[TMP9:%.*]] = insertvalue %D_arr poison, [1 x %D_struct] [[TMP8]], 0
; CHECK-NEXT: ret %D_arr [[TMP9]]
;
call void @noop()
ret %D_arr zeroinitializer
}