[RISCV][CostModel] Add cost for fabs/fsqrt of type bf16/f16 #118608

Merged 9 commits, Jan 10, 2025
68 changes: 57 additions & 11 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -13,6 +13,7 @@
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include <cmath>
@@ -1035,21 +1036,66 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
}
break;
}
case Intrinsic::fabs:
case Intrinsic::fabs: {
auto LT = getTypeLegalizationCost(RetTy);
if (ST->hasVInstructions() && LT.second.isVector()) {
// lui a0, 8
// addi a0, a0, -1
// vsetvli a1, zero, e16, m1, ta, ma
// vand.vx v8, v8, a0
// f16 with zvfhmin and bf16 with zvfbfmin
if (LT.second.getVectorElementType() == MVT::bf16 ||
(LT.second.getVectorElementType() == MVT::f16 &&
!ST->hasVInstructionsF16()))
return LT.first * getRISCVInstructionCost(RISCV::VAND_VX, LT.second,
CostKind) +
2;
else
return LT.first *
getRISCVInstructionCost(RISCV::VFSGNJX_VV, LT.second, CostKind);
}
break;
}
case Intrinsic::sqrt: {
auto LT = getTypeLegalizationCost(RetTy);
// TODO: add f16/bf16, bf16 with zvfbfmin && f16 with zvfhmin
if (ST->hasVInstructions() && LT.second.isVector()) {
unsigned Op;
switch (ICA.getID()) {
case Intrinsic::fabs:
Op = RISCV::VFSGNJX_VV;
break;
case Intrinsic::sqrt:
Op = RISCV::VFSQRT_V;
break;
SmallVector<unsigned, 4> ConvOp;
SmallVector<unsigned, 2> FsqrtOp;
MVT ConvType = LT.second;
MVT FsqrtType = LT.second;
// f16 with zvfhmin and bf16 with zvfbfmin and the type of nxv32[b]f16
// will be split.
if (LT.second.getVectorElementType() == MVT::bf16) {
if (LT.second == MVT::nxv32bf16) {
ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVTBF16_F_F_V,
RISCV::VFNCVTBF16_F_F_W, RISCV::VFNCVTBF16_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V};
ConvType = MVT::nxv16f16;
FsqrtType = MVT::nxv16f32;
Comment on lines +1069 to +1074
Contributor @arcbbb, Jan 10, 2025:

To handle nxv32bf16, you need 2 iterations of (vfwcvt nxv16bf16), (vfsqrt nxv16f32), and (vfncvt nxv16bf16).
The cost could therefore be LT.first * 2 * (vfwcvt + vfsqrt + vfncvt).

Contributor Author:

The ConvOp list is already split: each conversion opcode appears twice, so there is no need to multiply by 2.

Contributor:

Ah, I see now. I overlooked the duplicated opcode.
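
The exchange above can be checked with a small standalone sketch. It assumes that the list form of the cost query charges every entry once, so a duplicated opcode is counted twice. The `Opc` enum, `costOf`, and the unit costs are hypothetical stand-ins, not LLVM code, and the sketch ignores the fact that real costs depend on the type and cost kind.

```cpp
#include <cassert>
#include <vector>

// Hypothetical opcode tags standing in for the RISCV:: opcodes in the diff.
enum class Opc { VFWCVT, VFSQRT, VFNCVT };

// Assumed unit costs, for illustration only. The real values come from
// getRISCVInstructionCost and vary with the type and cost kind.
constexpr unsigned costOf(Opc Op) {
  switch (Op) {
  case Opc::VFWCVT:
    return 1;
  case Opc::VFSQRT:
    return 3;
  case Opc::VFNCVT:
    return 1;
  }
  return 0;
}

// Sum the cost of every entry, so a duplicated opcode is charged twice.
inline unsigned sumCost(const std::vector<Opc> &Ops) {
  unsigned Total = 0;
  for (Opc Op : Ops)
    Total += costOf(Op);
  return Total;
}

// The author's form: the two halves of nxv32[b]f16 are spelled out in the lists.
inline unsigned costWithDuplicatedLists() {
  const std::vector<Opc> ConvOp = {Opc::VFWCVT, Opc::VFWCVT, Opc::VFNCVT,
                                   Opc::VFNCVT};
  const std::vector<Opc> FsqrtOp = {Opc::VFSQRT, Opc::VFSQRT};
  return sumCost(FsqrtOp) + sumCost(ConvOp);
}

// The reviewer's form: one iteration, multiplied by 2.
inline unsigned costWithExplicitFactor() {
  return 2 * (costOf(Opc::VFWCVT) + costOf(Opc::VFSQRT) + costOf(Opc::VFNCVT));
}

// Both spellings describe the same instruction sequence.
inline void checkFormsAgree() {
  assert(costWithDuplicatedLists() == costWithExplicitFactor());
}
```

Under these assumed unit costs both forms give the same total, which is why no extra factor of 2 is needed once the opcodes are duplicated in the lists.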

} else {
ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFNCVTBF16_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V};
FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType);
Contributor:

I am not sure why getTypeToPromoteTo is needed here and below.

Collaborator:

It maps the f16 type to the f32 type, provided RISCVISelLowering.cpp has a setOperationAction call that marks ISD::FSQRT as Promote for that type.

Contributor:

Understood. Thanks for the explanation!
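
As a rough illustration of the lookup described in this thread, here is a toy model of a per-(operation, type) action table. `ToyLowering`, `ToyOp`, `ToyVT`, `setPromoted`, and `makeZvfhminLikeTable` are all hypothetical names invented for this sketch. Only the method name `getTypeToPromoteTo` mirrors the call in the diff, and the real LLVM implementation may behave differently.

```cpp
#include <cassert>
#include <map>
#include <utility>

// Toy stand-ins; none of these are LLVM types.
enum class ToyOp { FSQRT, FABS };
enum class ToyVT { nxv4f16, nxv4f32 };
enum class ToyAction { Legal, Promote };

class ToyLowering {
  using Key = std::pair<ToyOp, ToyVT>;
  std::map<Key, ToyAction> Actions;
  std::map<Key, ToyVT> PromoteTo;

public:
  // Record that Op on From should be carried out in the wider type To.
  void setPromoted(ToyOp Op, ToyVT From, ToyVT To) {
    Actions[{Op, From}] = ToyAction::Promote;
    PromoteTo[{Op, From}] = To;
  }

  // Anything without an entry is treated as Legal in this toy model.
  ToyAction getAction(ToyOp Op, ToyVT VT) const {
    auto It = Actions.find({Op, VT});
    return It == Actions.end() ? ToyAction::Legal : It->second;
  }

  // Only meaningful when the operation is marked Promote for VT.
  ToyVT getTypeToPromoteTo(ToyOp Op, ToyVT VT) const {
    assert(getAction(Op, VT) == ToyAction::Promote && "not a promoted op");
    return PromoteTo.at({Op, VT});
  }
};

// Build a table where only FSQRT on the f16 vector type is promoted.
inline ToyLowering makeZvfhminLikeTable() {
  ToyLowering TLI;
  TLI.setPromoted(ToyOp::FSQRT, ToyVT::nxv4f16, ToyVT::nxv4f32);
  return TLI;
}
```

In this model the cost code can ask the table for the wider type instead of hard-coding it, which matches the role the explanation above gives to the call in the non-split branches.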

}
} else if (LT.second.getVectorElementType() == MVT::f16 &&
!ST->hasVInstructionsF16()) {
if (LT.second == MVT::nxv32f16) {
ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFWCVT_F_F_V,
RISCV::VFNCVT_F_F_W, RISCV::VFNCVT_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V};
ConvType = MVT::nxv16f16;
FsqrtType = MVT::nxv16f32;
Comment on lines +1082 to +1087
Contributor:

ditto

Contributor Author @LiqinWeng, Jan 10, 2025:

The ConvOp list is already split, so there is no need to multiply by 2. The same applies to FsqrtOp.

} else {
ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFNCVT_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V};
FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType);
}
} else {
FsqrtOp = {RISCV::VFSQRT_V};
}
return LT.first * getRISCVInstructionCost(Op, LT.second, CostKind);

return LT.first * (getRISCVInstructionCost(FsqrtOp, FsqrtType, CostKind) +
getRISCVInstructionCost(ConvOp, ConvType, CostKind));
}
break;
}
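
For the fabs case in the diff above, the commented instruction sequence builds the constant 0x7fff (`lui a0, 8` gives 8 << 12 = 0x8000, and `addi a0, a0, -1` subtracts one) and ANDs it into every 16-bit lane to clear the sign bit. The sketch below models that on raw 16-bit patterns. The function names are hypothetical, and `maskedFabsCost` only restates the shape of the returned expression, with `NumParts` standing in for `LT.first` and `VandCost` for the `VAND_VX` cost.

```cpp
#include <cassert>
#include <cstdint>

// lui a0, 8 yields 8 << 12 = 0x8000; addi a0, a0, -1 turns it into 0x7fff.
constexpr std::uint16_t signClearMask() {
  return static_cast<std::uint16_t>((8u << 12) - 1u);
}

// Clear bit 15, the sign bit in both the f16 and bf16 encodings.
constexpr std::uint16_t fabsBits16(std::uint16_t Bits) {
  return static_cast<std::uint16_t>(Bits & signClearMask());
}

// Shape of the cost expression for the masked path: the multiplication
// binds tighter than the addition, so the two scalar instructions that
// build the mask are counted once, not once per legalized part.
inline unsigned maskedFabsCost(unsigned NumParts, unsigned VandCost) {
  assert(NumParts > 0 && "expected at least one legalized part");
  return NumParts * VandCost + 2;
}

static_assert(signClearMask() == 0x7fff, "mask should keep the low 15 bits");
```

For example, -1.0 is 0xBC00 in f16 and 0xBF80 in bf16; masking gives 0x3C00 and 0x3F80, the encodings of 1.0 in each format.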