Skip to content

[IR] Check callee param attributes as well in CallBase::getParamAttr() #91394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions llvm/include/llvm/IR/InstrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1997,13 +1997,19 @@ class CallBase : public Instruction {
/// Get the attribute of a given kind from a given arg
Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const {
assert(ArgNo < arg_size() && "Out of bounds");
return getAttributes().getParamAttr(ArgNo, Kind);
Attribute A = getAttributes().getParamAttr(ArgNo, Kind);
if (A.isValid())
return A;
return getParamAttrOnCalledFunction(ArgNo, Kind);
}

/// Get the attribute of a given kind from a given arg
Attribute getParamAttr(unsigned ArgNo, StringRef Kind) const {
assert(ArgNo < arg_size() && "Out of bounds");
return getAttributes().getParamAttr(ArgNo, Kind);
Attribute A = getAttributes().getParamAttr(ArgNo, Kind);
if (A.isValid())
return A;
return getParamAttrOnCalledFunction(ArgNo, Kind);
}

/// Return true if the data operand at index \p i has the attribute \p
Expand Down Expand Up @@ -2647,6 +2653,8 @@ class CallBase : public Instruction {
return hasFnAttrOnCalledFunction(Kind);
}
template <typename AK> Attribute getFnAttrOnCalledFunction(AK Kind) const;
template <typename AK>
Attribute getParamAttrOnCalledFunction(unsigned ArgNo, AK Kind) const;

/// Determine whether the return value has the given attribute. Supports
/// Attribute::AttrKind and StringRef as \p AttrKind types.
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,22 @@ template Attribute
CallBase::getFnAttrOnCalledFunction(Attribute::AttrKind Kind) const;
template Attribute CallBase::getFnAttrOnCalledFunction(StringRef Kind) const;

template <typename AK>
Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
AK Kind) const {
Value *V = getCalledOperand();

if (auto *F = dyn_cast<Function>(V))
return F->getAttributes().getParamAttr(ArgNo, Kind);

return Attribute();
}
template Attribute
CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
Attribute::AttrKind Kind) const;
template Attribute CallBase::getParamAttrOnCalledFunction(unsigned ArgNo,
StringRef Kind) const;

void CallBase::getOperandBundlesAsDefs(
SmallVectorImpl<OperandBundleDef> &Defs) const {
for (unsigned i = 0, e = getNumOperandBundles(); i != e; ++i)
Expand Down
46 changes: 46 additions & 0 deletions llvm/unittests/IR/AttributesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,4 +340,50 @@ TEST(Attributes, ConstantRangeAttributeCAPI) {
}
}

TEST(Attributes, CalleeAttributes) {
const char *IRString = R"IR(
declare void @f1(i32 %i)
declare void @f2(i32 range(i32 1, 2) %i)
define void @g1(i32 %i) {
call void @f1(i32 %i)
ret void
}
define void @g2(i32 %i) {
call void @f2(i32 %i)
ret void
}
define void @g3(i32 %i) {
call void @f1(i32 range(i32 3, 4) %i)
ret void
}
define void @g4(i32 %i) {
call void @f2(i32 range(i32 3, 4) %i)
ret void
}
)IR";

SMDiagnostic Err;
LLVMContext Context;
std::unique_ptr<Module> M = parseAssemblyString(IRString, Err, Context);
ASSERT_TRUE(M);

{
auto *I = cast<CallBase>(&M->getFunction("g1")->getEntryBlock().front());
ASSERT_FALSE(I->getParamAttr(0, Attribute::Range).isValid());
}
{
auto *I = cast<CallBase>(&M->getFunction("g2")->getEntryBlock().front());
ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid());
}
{
auto *I = cast<CallBase>(&M->getFunction("g3")->getEntryBlock().front());
ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid());
}
{
auto *I = cast<CallBase>(&M->getFunction("g4")->getEntryBlock().front());
ASSERT_TRUE(I->getParamAttr(0, Attribute::Range).isValid());
}
}

} // end anonymous namespace