Skip to content

[RISCV][TTI] Refactor getCastInstrCost to exit early #86619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 65 additions & 68 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -897,76 +897,73 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
const Instruction *I) {
if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
// FIXME: Need to compute legalizing cost for illegal types.
if (!isTypeLegal(Src) || !isTypeLegal(Dst))
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

// Skip if element size of Dst or Src is bigger than ELEN.
if (Src->getScalarSizeInBits() > ST->getELen() ||
Dst->getScalarSizeInBits() > ST->getELen())
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");

// FIXME: Need to consider vsetvli and lmul.
int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
(int)Log2_32(Src->getScalarSizeInBits());
switch (ISD) {
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
if (Src->getScalarSizeInBits() == 1) {
// We do not use vsext/vzext to extend from mask vector.
// Instead we use the following instructions to extend from mask vector:
// vmv.v.i v8, 0
// vmerge.vim v8, v8, -1, v0
return 2;
}
return 1;
case ISD::TRUNCATE:
if (Dst->getScalarSizeInBits() == 1) {
// We do not use several vncvt to truncate to mask vector. So we could
// not use PowDiff to calculate it.
// Instead we use the following instructions to truncate to mask vector:
// vand.vi v8, v8, 1
// vmsne.vi v0, v8, 0
return 2;
}
[[fallthrough]];
case ISD::FP_EXTEND:
case ISD::FP_ROUND:
// Counts of narrow/widen instructions.
return std::abs(PowDiff);
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
// The cost of convert from or to mask vector is different from other
// cases. We could not use PowDiff to calculate it.
// For mask vector to fp, we should use the following instructions:
// vmv.v.i v8, 0
// vmerge.vim v8, v8, -1, v0
// vfcvt.f.x.v v8, v8

// And for fp vector to mask, we use:
// vfncvt.rtz.x.f.w v9, v8
// vand.vi v8, v9, 1
// vmsne.vi v0, v8, 0
return 3;
}
if (std::abs(PowDiff) <= 1)
return 1;
// Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
// so it only need two conversion.
if (Src->isIntOrIntVectorTy())
return 2;
// Counts of narrow/widen instructions.
return std::abs(PowDiff);
bool IsVectorType = isa<VectorType>(Dst) && isa<VectorType>(Src);
bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
(Src->getScalarSizeInBits() <= ST->getELen()) &&
(Dst->getScalarSizeInBits() <= ST->getELen());

// FIXME: Need to compute legalizing cost for illegal types.
if (!IsVectorType || !IsTypeLegal)
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");

// FIXME: Need to consider vsetvli and lmul.
int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
(int)Log2_32(Src->getScalarSizeInBits());
switch (ISD) {
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
if (Src->getScalarSizeInBits() == 1) {
// We do not use vsext/vzext to extend from mask vector.
// Instead we use the following instructions to extend from mask vector:
// vmv.v.i v8, 0
// vmerge.vim v8, v8, -1, v0
return 2;
}
return 1;
case ISD::TRUNCATE:
if (Dst->getScalarSizeInBits() == 1) {
// We do not use several vncvt to truncate to mask vector. So we could
// not use PowDiff to calculate it.
// Instead we use the following instructions to truncate to mask vector:
// vand.vi v8, v8, 1
// vmsne.vi v0, v8, 0
return 2;
}
[[fallthrough]];
case ISD::FP_EXTEND:
case ISD::FP_ROUND:
// Counts of narrow/widen instructions.
return std::abs(PowDiff);
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
// The cost of convert from or to mask vector is different from other
// cases. We could not use PowDiff to calculate it.
// For mask vector to fp, we should use the following instructions:
// vmv.v.i v8, 0
// vmerge.vim v8, v8, -1, v0
// vfcvt.f.x.v v8, v8

// And for fp vector to mask, we use:
// vfncvt.rtz.x.f.w v9, v8
// vand.vi v8, v9, 1
// vmsne.vi v0, v8, 0
return 3;
}
if (std::abs(PowDiff) <= 1)
return 1;
// Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
// so it only need two conversion.
if (Src->isIntOrIntVectorTy())
return 2;
// Counts of narrow/widen instructions.
return std::abs(PowDiff);
}
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're missing a return if none of the cases in the switch match

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp:967:1: warning: non-void function does not return a value in all control paths [-Wreturn-type]
  967 | }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the trouble. I fix it with commit 5dc0c75


unsigned RISCVTTIImpl::getEstimatedVLFor(VectorType *Ty) {
Expand Down