Skip to content

Commit b446c20

Browse files
authored
AMDGPU: Verify function type matches when matching libcalls (#119043)
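Previously, a call to a mangled ldexp(float, float) was recognized as a candidate for replacement with the intrinsic. The pass now verifies that the call's function type matches the expected libcall signature; for ldexp, this means checking that the second parameter is in fact an integer. Fixes: SWDEV-501389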
1 parent 9ba7e2d commit b446c20

File tree

5 files changed

+175
-23
lines changed

5 files changed

+175
-23
lines changed

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ bool AMDGPULibCalls::fold(CallInst *CI) {
654654

655655
// Further check that the call's signature (argument count and types) matches.
656656
// TODO: Check calling convention matches too
657-
if (!FInfo.isCompatibleSignature(CI->getFunctionType()))
657+
if (!FInfo.isCompatibleSignature(*Callee->getParent(), CI->getFunctionType()))
658658
return false;
659659

660660
LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');

llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -620,17 +620,17 @@ bool ItaniumParamParser::parseItaniumParam(StringRef& param,
620620
// parse type
621621
char const TC = param.front();
622622
if (isDigit(TC)) {
623-
res.ArgType = StringSwitch<AMDGPULibFunc::EType>
624-
(eatLengthPrefixedName(param))
625-
.Case("ocl_image1darray" , AMDGPULibFunc::IMG1DA)
626-
.Case("ocl_image1dbuffer", AMDGPULibFunc::IMG1DB)
627-
.Case("ocl_image2darray" , AMDGPULibFunc::IMG2DA)
628-
.Case("ocl_image1d" , AMDGPULibFunc::IMG1D)
629-
.Case("ocl_image2d" , AMDGPULibFunc::IMG2D)
630-
.Case("ocl_image3d" , AMDGPULibFunc::IMG3D)
631-
.Case("ocl_event" , AMDGPULibFunc::DUMMY)
632-
.Case("ocl_sampler" , AMDGPULibFunc::DUMMY)
633-
.Default(AMDGPULibFunc::DUMMY);
623+
res.ArgType =
624+
StringSwitch<AMDGPULibFunc::EType>(eatLengthPrefixedName(param))
625+
.Case("ocl_image1darray", AMDGPULibFunc::IMG1DA)
626+
.Case("ocl_image1dbuffer", AMDGPULibFunc::IMG1DB)
627+
.Case("ocl_image2darray", AMDGPULibFunc::IMG2DA)
628+
.StartsWith("ocl_image1d", AMDGPULibFunc::IMG1D)
629+
.StartsWith("ocl_image2d", AMDGPULibFunc::IMG2D)
630+
.StartsWith("ocl_image3d", AMDGPULibFunc::IMG3D)
631+
.Case("ocl_event", AMDGPULibFunc::DUMMY)
632+
.Case("ocl_sampler", AMDGPULibFunc::DUMMY)
633+
.Default(AMDGPULibFunc::DUMMY);
634634
} else {
635635
drop_front(param);
636636
switch (TC) {
@@ -969,7 +969,7 @@ static Type* getIntrinsicParamType(
969969
return T;
970970
}
971971

972-
FunctionType *AMDGPUMangledLibFunc::getFunctionType(Module &M) const {
972+
FunctionType *AMDGPUMangledLibFunc::getFunctionType(const Module &M) const {
973973
LLVMContext& C = M.getContext();
974974
std::vector<Type*> Args;
975975
ParamIterator I(Leads, manglingRules[FuncId]);
@@ -997,9 +997,39 @@ std::string AMDGPUMangledLibFunc::getName() const {
997997
return std::string(OS.str());
998998
}
999999

1000-
bool AMDGPULibFunc::isCompatibleSignature(const FunctionType *FuncTy) const {
1001-
// TODO: Validate types make sense
1002-
return !FuncTy->isVarArg() && FuncTy->getNumParams() == getNumArgs();
1000+
bool AMDGPULibFunc::isCompatibleSignature(const Module &M,
1001+
const FunctionType *CallTy) const {
1002+
const FunctionType *FuncTy = getFunctionType(M);
1003+
1004+
// FIXME: UnmangledFuncInfo does not have any type information other than the
1005+
// number of arguments.
1006+
if (!FuncTy)
1007+
return getNumArgs() == CallTy->getNumParams();
1008+
1009+
// Normally the types should exactly match.
1010+
if (FuncTy == CallTy)
1011+
return true;
1012+
1013+
const unsigned NumParams = FuncTy->getNumParams();
1014+
if (NumParams != CallTy->getNumParams())
1015+
return false;
1016+
1017+
for (unsigned I = 0; I != NumParams; ++I) {
1018+
Type *FuncArgTy = FuncTy->getParamType(I);
1019+
Type *CallArgTy = CallTy->getParamType(I);
1020+
if (FuncArgTy == CallArgTy)
1021+
continue;
1022+
1023+
// Some cases permit implicit splatting a scalar value to a vector argument.
1024+
auto *FuncVecTy = dyn_cast<VectorType>(FuncArgTy);
1025+
if (FuncVecTy && FuncVecTy->getElementType() == CallArgTy &&
1026+
allowsImplicitVectorSplat(I))
1027+
continue;
1028+
1029+
return false;
1030+
}
1031+
1032+
return true;
10031033
}
10041034

10051035
Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
@@ -1012,7 +1042,7 @@ Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
10121042
if (F->hasFnAttribute(Attribute::NoBuiltin))
10131043
return nullptr;
10141044

1015-
if (!fInfo.isCompatibleSignature(F->getFunctionType()))
1045+
if (!fInfo.isCompatibleSignature(*M, F->getFunctionType()))
10161046
return nullptr;
10171047

10181048
return F;
@@ -1028,7 +1058,7 @@ FunctionCallee AMDGPULibFunc::getOrInsertFunction(Module *M,
10281058
if (F->hasFnAttribute(Attribute::NoBuiltin))
10291059
return nullptr;
10301060
if (!F->isDeclaration() &&
1031-
fInfo.isCompatibleSignature(F->getFunctionType()))
1061+
fInfo.isCompatibleSignature(*M, F->getFunctionType()))
10321062
return F;
10331063
}
10341064

llvm/lib/Target/AMDGPU/AMDGPULibFunc.h

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ class AMDGPULibFuncImpl : public AMDGPULibFuncBase {
352352
void setName(StringRef N) { Name = std::string(N); }
353353
void setPrefix(ENamePrefix pfx) { FKind = pfx; }
354354

355-
virtual FunctionType *getFunctionType(Module &M) const = 0;
355+
virtual FunctionType *getFunctionType(const Module &M) const = 0;
356356

357357
protected:
358358
EFuncId FuncId;
@@ -391,8 +391,22 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
391391
return Impl->parseFuncName(MangledName);
392392
}
393393

394+
/// Return true if it's legal to splat a scalar value passed in parameter \p
395+
/// ArgIdx to a vector argument.
396+
bool allowsImplicitVectorSplat(int ArgIdx) const {
397+
switch (getId()) {
398+
case EI_LDEXP:
399+
return ArgIdx == 1;
400+
case EI_FMIN:
401+
case EI_FMAX:
402+
return true;
403+
default:
404+
return false;
405+
}
406+
}
407+
394408
// Validate the call type matches the expected libfunc type.
395-
bool isCompatibleSignature(const FunctionType *FuncTy) const;
409+
bool isCompatibleSignature(const Module &M, const FunctionType *FuncTy) const;
396410

397411
/// \return The mangled function name for mangled library functions
398412
/// and unmangled function name for unmangled library functions.
@@ -401,7 +415,7 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
401415
void setName(StringRef N) { Impl->setName(N); }
402416
void setPrefix(ENamePrefix PFX) { Impl->setPrefix(PFX); }
403417

404-
FunctionType *getFunctionType(Module &M) const {
418+
FunctionType *getFunctionType(const Module &M) const {
405419
return Impl->getFunctionType(M);
406420
}
407421
static Function *getFunction(llvm::Module *M, const AMDGPULibFunc &fInfo);
@@ -428,7 +442,7 @@ class AMDGPUMangledLibFunc : public AMDGPULibFuncImpl {
428442

429443
std::string getName() const override;
430444
unsigned getNumArgs() const override;
431-
FunctionType *getFunctionType(Module &M) const override;
445+
FunctionType *getFunctionType(const Module &M) const override;
432446
static StringRef getUnmangledName(StringRef MangledName);
433447

434448
bool parseFuncName(StringRef &mangledName) override;
@@ -458,7 +472,9 @@ class AMDGPUUnmangledLibFunc : public AMDGPULibFuncImpl {
458472
}
459473
std::string getName() const override { return Name; }
460474
unsigned getNumArgs() const override;
461-
FunctionType *getFunctionType(Module &M) const override { return FuncTy; }
475+
FunctionType *getFunctionType(const Module &M) const override {
476+
return FuncTy;
477+
}
462478

463479
bool parseFuncName(StringRef &Name) override;
464480

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -passes=amdgpu-simplifylib %s | FileCheck %s
3+
4+
; Make sure we can produce a valid FunctionType for the expected
5+
; signature of image functions.
6+
7+
declare i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4))
8+
9+
define i32 @call_ocl_image2d_depth(ptr addrspace(4) %img) {
10+
; CHECK-LABEL: define i32 @call_ocl_image2d_depth(
11+
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
12+
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4) [[IMG]])
13+
; CHECK-NEXT: ret i32 [[RESULT]]
14+
;
15+
%result = call i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4) %img)
16+
ret i32 %result
17+
}
18+
19+
declare i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4))
20+
21+
define i32 @call_ocl_image3d_depth(ptr addrspace(4) %img) {
22+
; CHECK-LABEL: define i32 @call_ocl_image3d_depth(
23+
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
24+
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4) [[IMG]])
25+
; CHECK-NEXT: ret i32 [[RESULT]]
26+
;
27+
%result = call i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4) %img)
28+
ret i32 %result
29+
}
30+
31+
declare i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4))
32+
33+
define i32 @call_get_image_width14ocl_image1d_ro(ptr addrspace(4) %img) {
34+
; CHECK-LABEL: define i32 @call_get_image_width14ocl_image1d_ro(
35+
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
36+
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4) [[IMG]])
37+
; CHECK-NEXT: ret i32 [[RESULT]]
38+
;
39+
%result = call i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4) %img)
40+
ret i32 %result
41+
}
42+
43+
declare <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4))
44+
45+
define <2 x i32> @call_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) %img) {
46+
; CHECK-LABEL: define <2 x i32> @call_Z13get_image_dim20ocl_image2d_array_ro(
47+
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
48+
; CHECK-NEXT: [[RESULT:%.*]] = call <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) [[IMG]])
49+
; CHECK-NEXT: ret <2 x i32> [[RESULT]]
50+
;
51+
%result = call <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) %img)
52+
ret <2 x i32> %result
53+
}
54+
55+
declare i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4))
56+
57+
define i32 @call_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) %img) {
58+
; CHECK-LABEL: define i32 @call_Z15get_image_width20ocl_image1d_array_ro(
59+
; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
60+
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) [[IMG]])
61+
; CHECK-NEXT: ret i32 [[RESULT]]
62+
;
63+
%result = call i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) %img)
64+
ret i32 %result
65+
}

llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,47 @@ define float @test_ldexp_f32_strictfp(float %x, i32 %y) #4 {
242242
ret float %ldexp
243243
}
244244

245+
;---------------------------------------------------------------------
246+
; Invalid signatures
247+
;---------------------------------------------------------------------
248+
249+
; Declared with wrong type, second argument is float
250+
declare float @_Z5ldexpff(float noundef, float noundef)
251+
252+
define float @call_wrong_typed_ldexp_f32_second_arg(float %x, float %wrongtype) {
253+
; CHECK-LABEL: define float @call_wrong_typed_ldexp_f32_second_arg
254+
; CHECK-SAME: (float [[X:%.*]], float [[WRONGTYPE:%.*]]) {
255+
; CHECK-NEXT: [[CALL:%.*]] = call float @_Z5ldexpff(float [[X]], float [[WRONGTYPE]])
256+
; CHECK-NEXT: ret float [[CALL]]
257+
;
258+
%call = call float @_Z5ldexpff(float %x, float %wrongtype)
259+
ret float %call
260+
}
261+
262+
declare <2 x float> @_Z5ldexpDv2_fS_(<2 x float>, <2 x float>)
263+
264+
define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg(<2 x float> %x, <2 x float> %wrongtype) {
265+
; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg
266+
; CHECK-SAME: (<2 x float> [[X:%.*]], <2 x float> [[WRONGTYPE:%.*]]) {
267+
; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> [[X]], <2 x float> [[WRONGTYPE]])
268+
; CHECK-NEXT: ret <2 x float> [[CALL]]
269+
;
270+
%call = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> %x, <2 x float> %wrongtype)
271+
ret <2 x float> %call
272+
}
273+
274+
declare <2 x float> @_Z5ldexpDv2_ff(<2 x float>, float)
275+
276+
define <2 x float> @call_wrong_typed_ldexp_v2f32_f32(<2 x float> %x, float %wrongtype) {
277+
; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_f32
278+
; CHECK-SAME: (<2 x float> [[X:%.*]], float [[WRONGTYPE:%.*]]) {
279+
; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> [[X]], float [[WRONGTYPE]])
280+
; CHECK-NEXT: ret <2 x float> [[CALL]]
281+
;
282+
%call = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> %x, float %wrongtype)
283+
ret <2 x float> %call
284+
}
285+
245286
attributes #0 = { nobuiltin }
246287
attributes #1 = { "no-builtins" }
247288
attributes #2 = { nounwind memory(none) }

0 commit comments

Comments
 (0)