Skip to content

AMDGPU: Verify function type matches when matching libcalls #119043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ bool AMDGPULibCalls::fold(CallInst *CI) {

// Further check the number of arguments to see if they match.
// TODO: Check calling convention matches too
if (!FInfo.isCompatibleSignature(CI->getFunctionType()))
if (!FInfo.isCompatibleSignature(*Callee->getParent(), CI->getFunctionType()))
return false;

LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');
Expand Down
64 changes: 47 additions & 17 deletions llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,17 +620,17 @@ bool ItaniumParamParser::parseItaniumParam(StringRef& param,
// parse type
char const TC = param.front();
if (isDigit(TC)) {
res.ArgType = StringSwitch<AMDGPULibFunc::EType>
(eatLengthPrefixedName(param))
.Case("ocl_image1darray" , AMDGPULibFunc::IMG1DA)
.Case("ocl_image1dbuffer", AMDGPULibFunc::IMG1DB)
.Case("ocl_image2darray" , AMDGPULibFunc::IMG2DA)
.Case("ocl_image1d" , AMDGPULibFunc::IMG1D)
.Case("ocl_image2d" , AMDGPULibFunc::IMG2D)
.Case("ocl_image3d" , AMDGPULibFunc::IMG3D)
.Case("ocl_event" , AMDGPULibFunc::DUMMY)
.Case("ocl_sampler" , AMDGPULibFunc::DUMMY)
.Default(AMDGPULibFunc::DUMMY);
res.ArgType =
StringSwitch<AMDGPULibFunc::EType>(eatLengthPrefixedName(param))
.Case("ocl_image1darray", AMDGPULibFunc::IMG1DA)
.Case("ocl_image1dbuffer", AMDGPULibFunc::IMG1DB)
.Case("ocl_image2darray", AMDGPULibFunc::IMG2DA)
.StartsWith("ocl_image1d", AMDGPULibFunc::IMG1D)
.StartsWith("ocl_image2d", AMDGPULibFunc::IMG2D)
.StartsWith("ocl_image3d", AMDGPULibFunc::IMG3D)
.Case("ocl_event", AMDGPULibFunc::DUMMY)
.Case("ocl_sampler", AMDGPULibFunc::DUMMY)
.Default(AMDGPULibFunc::DUMMY);
} else {
drop_front(param);
switch (TC) {
Expand Down Expand Up @@ -969,7 +969,7 @@ static Type* getIntrinsicParamType(
return T;
}

FunctionType *AMDGPUMangledLibFunc::getFunctionType(Module &M) const {
FunctionType *AMDGPUMangledLibFunc::getFunctionType(const Module &M) const {
LLVMContext& C = M.getContext();
std::vector<Type*> Args;
ParamIterator I(Leads, manglingRules[FuncId]);
Expand Down Expand Up @@ -997,9 +997,39 @@ std::string AMDGPUMangledLibFunc::getName() const {
return std::string(OS.str());
}

bool AMDGPULibFunc::isCompatibleSignature(const FunctionType *FuncTy) const {
// TODO: Validate types make sense
return !FuncTy->isVarArg() && FuncTy->getNumParams() == getNumArgs();
bool AMDGPULibFunc::isCompatibleSignature(const Module &M,
const FunctionType *CallTy) const {
const FunctionType *FuncTy = getFunctionType(M);

// FIXME: UnmangledFuncInfo does not have any type information other than the
// number of arguments.
if (!FuncTy)
return getNumArgs() == CallTy->getNumParams();

// Normally the types should exactly match.
if (FuncTy == CallTy)
return true;

const unsigned NumParams = FuncTy->getNumParams();
if (NumParams != CallTy->getNumParams())
return false;

for (unsigned I = 0; I != NumParams; ++I) {
Type *FuncArgTy = FuncTy->getParamType(I);
Type *CallArgTy = CallTy->getParamType(I);
if (FuncArgTy == CallArgTy)
continue;

// Some cases permit implicit splatting a scalar value to a vector argument.
auto *FuncVecTy = dyn_cast<VectorType>(FuncArgTy);
if (FuncVecTy && FuncVecTy->getElementType() == CallArgTy &&
allowsImplicitVectorSplat(I))
continue;

return false;
}

return true;
}

Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
Expand All @@ -1012,7 +1042,7 @@ Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
if (F->hasFnAttribute(Attribute::NoBuiltin))
return nullptr;

if (!fInfo.isCompatibleSignature(F->getFunctionType()))
if (!fInfo.isCompatibleSignature(*M, F->getFunctionType()))
return nullptr;

return F;
Expand All @@ -1028,7 +1058,7 @@ FunctionCallee AMDGPULibFunc::getOrInsertFunction(Module *M,
if (F->hasFnAttribute(Attribute::NoBuiltin))
return nullptr;
if (!F->isDeclaration() &&
fInfo.isCompatibleSignature(F->getFunctionType()))
fInfo.isCompatibleSignature(*M, F->getFunctionType()))
return F;
}

Expand Down
26 changes: 21 additions & 5 deletions llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class AMDGPULibFuncImpl : public AMDGPULibFuncBase {
void setName(StringRef N) { Name = std::string(N); }
void setPrefix(ENamePrefix pfx) { FKind = pfx; }

virtual FunctionType *getFunctionType(Module &M) const = 0;
virtual FunctionType *getFunctionType(const Module &M) const = 0;

protected:
EFuncId FuncId;
Expand Down Expand Up @@ -391,8 +391,22 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
return Impl->parseFuncName(MangledName);
}

/// Return true if it's legal to splat a scalar value passed in parameter \p
/// ArgIdx to a vector argument.
bool allowsImplicitVectorSplat(int ArgIdx) const {
switch (getId()) {
case EI_LDEXP:
return ArgIdx == 1;
case EI_FMIN:
case EI_FMAX:
return true;
default:
return false;
}
}

// Validate the call type matches the expected libfunc type.
bool isCompatibleSignature(const FunctionType *FuncTy) const;
bool isCompatibleSignature(const Module &M, const FunctionType *FuncTy) const;

/// \return The mangled function name for mangled library functions
/// and unmangled function name for unmangled library functions.
Expand All @@ -401,7 +415,7 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
void setName(StringRef N) { Impl->setName(N); }
void setPrefix(ENamePrefix PFX) { Impl->setPrefix(PFX); }

FunctionType *getFunctionType(Module &M) const {
FunctionType *getFunctionType(const Module &M) const {
return Impl->getFunctionType(M);
}
static Function *getFunction(llvm::Module *M, const AMDGPULibFunc &fInfo);
Expand All @@ -428,7 +442,7 @@ class AMDGPUMangledLibFunc : public AMDGPULibFuncImpl {

std::string getName() const override;
unsigned getNumArgs() const override;
FunctionType *getFunctionType(Module &M) const override;
FunctionType *getFunctionType(const Module &M) const override;
static StringRef getUnmangledName(StringRef MangledName);

bool parseFuncName(StringRef &mangledName) override;
Expand Down Expand Up @@ -458,7 +472,9 @@ class AMDGPUUnmangledLibFunc : public AMDGPULibFuncImpl {
}
std::string getName() const override { return Name; }
unsigned getNumArgs() const override;
FunctionType *getFunctionType(Module &M) const override { return FuncTy; }
FunctionType *getFunctionType(const Module &M) const override {
return FuncTy;
}

bool parseFuncName(StringRef &Name) override;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -passes=amdgpu-simplifylib %s | FileCheck %s

; Make sure we can produce a valid FunctionType for the expected
; signature of image functions.

declare i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4))

define i32 @call_ocl_image2d_depth(ptr addrspace(4) %img) {
; CHECK-LABEL: define i32 @call_ocl_image2d_depth(
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4) [[IMG]])
; CHECK-NEXT: ret i32 [[RESULT]]
;
%result = call i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4) %img)
ret i32 %result
}

declare i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4))

define i32 @call_ocl_image3d_depth(ptr addrspace(4) %img) {
; CHECK-LABEL: define i32 @call_ocl_image3d_depth(
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4) [[IMG]])
; CHECK-NEXT: ret i32 [[RESULT]]
;
%result = call i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4) %img)
ret i32 %result
}

declare i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4))

define i32 @call_get_image_width14ocl_image1d_ro(ptr addrspace(4) %img) {
; CHECK-LABEL: define i32 @call_get_image_width14ocl_image1d_ro(
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4) [[IMG]])
; CHECK-NEXT: ret i32 [[RESULT]]
;
%result = call i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4) %img)
ret i32 %result
}

declare <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4))

define <2 x i32> @call_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) %img) {
; CHECK-LABEL: define <2 x i32> @call_Z13get_image_dim20ocl_image2d_array_ro(
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
; CHECK-NEXT: [[RESULT:%.*]] = call <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) [[IMG]])
; CHECK-NEXT: ret <2 x i32> [[RESULT]]
;
%result = call <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) %img)
ret <2 x i32> %result
}

declare i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4))

define i32 @call_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) %img) {
; CHECK-LABEL: define i32 @call_Z15get_image_width20ocl_image1d_array_ro(
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) [[IMG]])
; CHECK-NEXT: ret i32 [[RESULT]]
;
%result = call i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) %img)
ret i32 %result
}
41 changes: 41 additions & 0 deletions llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,47 @@ define float @test_ldexp_f32_strictfp(float %x, i32 %y) #4 {
ret float %ldexp
}

;---------------------------------------------------------------------
; Invalid signatures
;---------------------------------------------------------------------

; Declared with wrong type, second argument is float
declare float @_Z5ldexpff(float noundef, float noundef)

define float @call_wrong_typed_ldexp_f32_second_arg(float %x, float %wrongtype) {
; CHECK-LABEL: define float @call_wrong_typed_ldexp_f32_second_arg
; CHECK-SAME: (float [[X:%.*]], float [[WRONGTYPE:%.*]]) {
; CHECK-NEXT: [[CALL:%.*]] = call float @_Z5ldexpff(float [[X]], float [[WRONGTYPE]])
; CHECK-NEXT: ret float [[CALL]]
;
%call = call float @_Z5ldexpff(float %x, float %wrongtype)
ret float %call
}

declare <2 x float> @_Z5ldexpDv2_fS_(<2 x float>, <2 x float>)

define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg(<2 x float> %x, <2 x float> %wrongtype) {
; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg
; CHECK-SAME: (<2 x float> [[X:%.*]], <2 x float> [[WRONGTYPE:%.*]]) {
; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> [[X]], <2 x float> [[WRONGTYPE]])
; CHECK-NEXT: ret <2 x float> [[CALL]]
;
%call = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> %x, <2 x float> %wrongtype)
ret <2 x float> %call
}

declare <2 x float> @_Z5ldexpDv2_ff(<2 x float>, float)

define <2 x float> @call_wrong_typed_ldexp_v2f32_f32(<2 x float> %x, float %wrongtype) {
; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_f32
; CHECK-SAME: (<2 x float> [[X:%.*]], float [[WRONGTYPE:%.*]]) {
; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> [[X]], float [[WRONGTYPE]])
; CHECK-NEXT: ret <2 x float> [[CALL]]
;
%call = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> %x, float %wrongtype)
ret <2 x float> %call
}

attributes #0 = { nobuiltin }
attributes #1 = { "no-builtins" }
attributes #2 = { nounwind memory(none) }
Expand Down
Loading