Skip to content

[NFCI][SYCL] Refactor selection of FP builtin calls #16966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions llvm/include/llvm/Analysis/TargetLibraryInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define LLVM_ANALYSIS_TARGETLIBRARYINFO_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/IntrinsicInst.h"
Expand Down Expand Up @@ -83,6 +84,47 @@ class VecDesc {
NotLibFunc
};

/// Contains all possible FPBuiltin replacement choices by
/// selectFnForFPBuiltinCalls function.
struct FPBuiltinReplacement {
enum Kind {
Unexpected0dot5,
UnrecognizedFPAttrs,
NoSuitableReplacement,
ReplaceWithLLVMIR,
ReplaceWithAltMathFunction,
ReplaceWithApproxNVPTXCallsOrFallback
};

FPBuiltinReplacement(Kind K, const StringRef &ImplName = StringRef())
: RepKind(K), AltMathFunctionImplName(ImplName) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to assign to AltMathFunctionImplName even if K is not set to ReplaceWithAltMathFunction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think not. In this case we should to either error out or attempt to lower to standard to LLVM instructions/intrinsics or target specific intrinsics if applicable (like nvvm approx math intrinsics).

// Check that ImplName is non-empty only if K is
// ReplaceWithAltMathFunction.
assert((K != Kind::ReplaceWithAltMathFunction || !ImplName.empty()) &&
"Expected non-empty function name");
}
FPBuiltinReplacement(const FPBuiltinReplacement &O)
: RepKind(O()), AltMathFunctionImplName(O.altMathFunctionImplName()) {}
FPBuiltinReplacement &operator=(const FPBuiltinReplacement &O) {
this->RepKind = O();
this->AltMathFunctionImplName = O.altMathFunctionImplName();
return *this;
}
~FPBuiltinReplacement() {}
Kind operator()() const { return RepKind; }
bool isReplaceble() const { return RepKind > Kind::NoSuitableReplacement; }
const StringRef &altMathFunctionImplName() const {
return AltMathFunctionImplName;
}

private:
/// In case of RepKind = Kind::ReplaceWithAltMathFunction
/// AltMathFunctionImplName also contains the name of the alternate math
/// function implementation.
Kind RepKind;
StringRef AltMathFunctionImplName;
};

/// Implementation of the target library information.
///
/// This class constructs tables that hold the target library information and
Expand Down Expand Up @@ -224,6 +266,16 @@ class TargetLibraryInfoImpl {
/// given alternate math library.
void addAltMathFunctionsFromLib(enum AltMathLibrary AltLib);

// Select an alternate math library implementation that meets the criteria
// described by an FPBuiltinIntrinsic call.
StringRef
selectFPBuiltinImplementation(const FPBuiltinIntrinsic *Builtin) const;

/// Returns the replacement choice for the given FPBuiltinIntrinsic call.
FPBuiltinReplacement
selectFnForFPBuiltinCalls(const FPBuiltinIntrinsic &BuiltinCall,
const TargetTransformInfo &TTI) const;

/// Select an alternate math library implementation that meets the criteria
/// described by an FPBuiltinIntrinsic call.
StringRef selectFPBuiltinImplementation(FPBuiltinIntrinsic *Builtin) const;
Expand Down Expand Up @@ -649,6 +701,13 @@ class TargetLibraryInfo {
bool isKnownVectorFunctionInLibrary(StringRef F) const {
return this->isFunctionVectorizable(F);
}

/// Returns the replacement choice for the given FPBuiltinIntrinsic call.
FPBuiltinReplacement
selectFnForFPBuiltinCalls(const FPBuiltinIntrinsic &BuiltinCall,
const TargetTransformInfo &TTI) const {
return Impl->selectFnForFPBuiltinCalls(BuiltinCall, TTI);
}
};

/// Analysis pass providing the \c TargetLibraryInfo.
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ class FPBuiltinIntrinsic : public IntrinsicInst {
/// Check the callsite attributes for this FPBuiltinIntrinsic against a list
/// of FP attributes that the caller knows how to process to see if the
/// current intrinsic has unrecognized attributes
bool hasUnrecognizedFPAttrs(const StringSet<> HandledAttrs);
bool hasUnrecognizedFPAttrs(const StringSet<> HandledAttrs) const;

/// Methods for support type inquiry through isa, cast, and dyn_cast:
/// @{
Expand Down
62 changes: 61 additions & 1 deletion llvm/lib/Analysis/TargetLibraryInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,7 @@ void TargetLibraryInfoImpl::addAltMathFunctionsFromLib(
/// Select an alternate math library implementation that meets the criteria
/// described by an FPBuiltinIntrinsic call.
StringRef TargetLibraryInfoImpl::selectFPBuiltinImplementation(
FPBuiltinIntrinsic *Builtin) const {
const FPBuiltinIntrinsic *Builtin) const {
// TODO: Handle the case of no specified accuracy.
if (Builtin->getRequiredAccuracy() == std::nullopt)
return StringRef();
Expand All @@ -1353,6 +1353,66 @@ StringRef TargetLibraryInfoImpl::selectFPBuiltinImplementation(
return I->FnImplName;
}

FPBuiltinReplacement TargetLibraryInfoImpl::selectFnForFPBuiltinCalls(
const FPBuiltinIntrinsic &BuiltinCall,
const TargetTransformInfo &TTI) const {
auto DefaultOpIsCorrectlyRounded = [](const FPBuiltinIntrinsic &BuiltinCall) {
switch (BuiltinCall.getIntrinsicID()) {
case Intrinsic::fpbuiltin_fadd:
case Intrinsic::fpbuiltin_fsub:
case Intrinsic::fpbuiltin_fmul:
case Intrinsic::fpbuiltin_fdiv:
case Intrinsic::fpbuiltin_frem:
case Intrinsic::fpbuiltin_sqrt:
case Intrinsic::fpbuiltin_ldexp:
return true;
default:
return false;
}
};
StringSet<> RecognizedAttrs = {FPBuiltinIntrinsic::FPBUILTIN_MAX_ERROR};
if (BuiltinCall.hasUnrecognizedFPAttrs(std::move(RecognizedAttrs)))
return FPBuiltinReplacement(FPBuiltinReplacement::UnrecognizedFPAttrs);
Triple T(BuiltinCall.getModule()->getTargetTriple());
const auto Accuracy = BuiltinCall.getRequiredAccuracy();
// For fpbuiltin.sqrt, it should always use the native operation for
// x86-based targets because the native instruction is faster (even faster
// than the low-accuracy SVML implementation).
if (T.isX86() && BuiltinCall.getIntrinsicID() == Intrinsic::fpbuiltin_sqrt &&
TTI.haveFastSqrt(BuiltinCall.getOperand(0)->getType()))
return FPBuiltinReplacement(FPBuiltinReplacement::ReplaceWithLLVMIR);
// Several functions for SYCL and CUDA requires "0.5" accuracy levels,
// which means correctly rounded results. For now x86 host and NVPTX
// AltMathLibrary doesn't have such ability. For such accuracy level,
// the fpbuiltins should be replaced by equivalent IR operation or
// llvmbuiltins.
if ((T.isX86() || T.isNVPTX()) && Accuracy == 0.5) {
if (DefaultOpIsCorrectlyRounded(BuiltinCall))
return FPBuiltinReplacement(FPBuiltinReplacement::ReplaceWithLLVMIR);
return FPBuiltinReplacement(FPBuiltinReplacement::Unexpected0dot5);
}
// AltMathLibrary don't have implementation for CUDA approximate precision
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: Does 'Accuracy > 0.5' map to 'approximate precision'?

Thanks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'Accuracy > 0.5' maps on 'not-precise'. Regarding if it maps on nvvm approx intrinsics/nvptx approx instruction - it depends on the operation and what ptx spec mandates.

More details is in the discussion with Joshua in #16714 .

// builtins. Lets map them on NVPTX intrinsics. If no appropriate intrinsics
// are known - skip to emit an error.
if (T.isNVPTX() && Accuracy > 0.5) {
return FPBuiltinReplacement(
FPBuiltinReplacement::ReplaceWithApproxNVPTXCallsOrFallback);
}

/// Call TLI to select a function implementation to call
const StringRef OutAltMathFunctionImplName =
selectFPBuiltinImplementation(&BuiltinCall);
if (OutAltMathFunctionImplName.empty()) {
// Operations that require correct rounding by default can always be
// replaced with the LLVM IR equivalent representation.
if (DefaultOpIsCorrectlyRounded(BuiltinCall))
return FPBuiltinReplacement(FPBuiltinReplacement::ReplaceWithLLVMIR);
return FPBuiltinReplacement(FPBuiltinReplacement::NoSuitableReplacement);
}
return FPBuiltinReplacement(FPBuiltinReplacement::ReplaceWithAltMathFunction,
OutAltMathFunctionImplName);
}

static bool compareByScalarFnName(const VecDesc &LHS, const VecDesc &RHS) {
return LHS.getScalarFnName() < RHS.getScalarFnName();
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/IntrinsicInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ std::optional<float> FPBuiltinIntrinsic::getRequiredAccuracy() const {
}

bool FPBuiltinIntrinsic::hasUnrecognizedFPAttrs(
const StringSet<> recognizedAttrs) {
const StringSet<> recognizedAttrs) const {
AttributeSet FnAttrs = getAttributes().getFnAttrs();
for (const Attribute &Attr : FnAttrs) {
if (!Attr.isStringAttribute())
Expand Down
66 changes: 22 additions & 44 deletions llvm/lib/Transforms/Scalar/FPBuiltinFnSelection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,54 +164,28 @@ static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
dbgs() << BuiltinCall.getRequiredAccuracy().value() << "\n";
});

StringSet<> RecognizedAttrs = {FPBuiltinIntrinsic::FPBUILTIN_MAX_ERROR};
if (BuiltinCall.hasUnrecognizedFPAttrs(RecognizedAttrs)) {
const FPBuiltinReplacement Replacement =
TLI.selectFnForFPBuiltinCalls(BuiltinCall, TTI);

switch (Replacement()) {
default:
llvm_unreachable("Unexpected replacement");
case FPBuiltinReplacement::Unexpected0dot5:
report_fatal_error("Unexpected fpbuiltin requiring 0.5 max error.");
return false;
case FPBuiltinReplacement::UnrecognizedFPAttrs:
report_fatal_error(
Twine(BuiltinCall.getCalledFunction()->getName()) +
Twine(" was called with unrecognized floating-point attributes.\n"),
false);
return false;
}

Triple T(BuiltinCall.getModule()->getTargetTriple());
// for fpbuiltin.sqrt, it should always use the native operation for
// x86-based targets because the native instruction is faster (even faster
// than the low-accuracy SVML implementation).
if (T.isX86() && BuiltinCall.getIntrinsicID() == Intrinsic::fpbuiltin_sqrt &&
TTI.haveFastSqrt(BuiltinCall.getOperand(0)->getType()))
return replaceWithLLVMIR(BuiltinCall);

// Several functions for "sycl" and "cuda" requires "0.5" accuracy levels,
// which means correctly rounded results. For now x86 host and NVPTX
// AltMathLibrary doesn't have such ability. For such accuracy level, the
// fpbuiltins should be replaced by equivalent IR operation or llvmbuiltins.
if ((T.isX86() || T.isNVPTX()) &&
BuiltinCall.getRequiredAccuracy().value() == 0.5) {
switch (BuiltinCall.getIntrinsicID()) {
case Intrinsic::fpbuiltin_fadd:
case Intrinsic::fpbuiltin_fsub:
case Intrinsic::fpbuiltin_fmul:
case Intrinsic::fpbuiltin_fdiv:
case Intrinsic::fpbuiltin_frem:
case Intrinsic::fpbuiltin_sqrt:
case Intrinsic::fpbuiltin_ldexp:
return replaceWithLLVMIR(BuiltinCall);
default:
report_fatal_error("Unexpected fpbuiltin requiring 0.5 max error.");
}
}

// AltMathLibrary don't have implementation for CUDA approximate precision
// builtins. Lets map them on NVPTX intrinsics. If no appropriate intrinsics
// are known - skip to emit an error.
if (T.isNVPTX() && BuiltinCall.getRequiredAccuracy().value() > 0.5)
case FPBuiltinReplacement::ReplaceWithApproxNVPTXCallsOrFallback: {
if (replaceWithApproxNVPTXCallsOrFallback(
BuiltinCall, BuiltinCall.getRequiredAccuracy()))
return true;

/// Call TLI to select a function implementation to call
StringRef ImplName = TLI.selectFPBuiltinImplementation(&BuiltinCall);
if (ImplName.empty()) {
[[fallthrough]];
}
case FPBuiltinReplacement::NoSuitableReplacement: {
LLVM_DEBUG(dbgs() << "No matching implementation found!\n");
std::string RequiredAccuracy;
if (BuiltinCall.getRequiredAccuracy() == std::nullopt)
Expand All @@ -228,10 +202,14 @@ static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
false);
return false;
}

LLVM_DEBUG(dbgs() << "Selected " << ImplName << "\n");

return replaceWithAltMathFunction(BuiltinCall, ImplName);
case FPBuiltinReplacement::ReplaceWithLLVMIR:
return replaceWithLLVMIR(BuiltinCall);
case FPBuiltinReplacement::ReplaceWithAltMathFunction:
LLVM_DEBUG(dbgs() << "Selected " << Replacement.altMathFunctionImplName()
<< "\n");
return replaceWithAltMathFunction(BuiltinCall,
Replacement.altMathFunctionImplName());
}
}

static bool runImpl(const TargetLibraryInfo &TLI,
Expand Down