Skip to content

Commit 817f453

Browse files
authored
[RISCV][TTI] Refactor getCastInstrCost to exit early (#86619)
To reduce the indentation by using early returns, this patch hoist the return for illegal type and non vector type earlier. It should mostly be an NFC.
1 parent c6a65e4 commit 817f453

File tree

1 file changed

+65
-68
lines changed

1 file changed

+65
-68
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 65 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -897,76 +897,73 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
897897
TTI::CastContextHint CCH,
898898
TTI::TargetCostKind CostKind,
899899
const Instruction *I) {
900-
if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
901-
// FIXME: Need to compute legalizing cost for illegal types.
902-
if (!isTypeLegal(Src) || !isTypeLegal(Dst))
903-
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
904-
905-
// Skip if element size of Dst or Src is bigger than ELEN.
906-
if (Src->getScalarSizeInBits() > ST->getELen() ||
907-
Dst->getScalarSizeInBits() > ST->getELen())
908-
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
909-
910-
int ISD = TLI->InstructionOpcodeToISD(Opcode);
911-
assert(ISD && "Invalid opcode");
912-
913-
// FIXME: Need to consider vsetvli and lmul.
914-
int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
915-
(int)Log2_32(Src->getScalarSizeInBits());
916-
switch (ISD) {
917-
case ISD::SIGN_EXTEND:
918-
case ISD::ZERO_EXTEND:
919-
if (Src->getScalarSizeInBits() == 1) {
920-
// We do not use vsext/vzext to extend from mask vector.
921-
// Instead we use the following instructions to extend from mask vector:
922-
// vmv.v.i v8, 0
923-
// vmerge.vim v8, v8, -1, v0
924-
return 2;
925-
}
926-
return 1;
927-
case ISD::TRUNCATE:
928-
if (Dst->getScalarSizeInBits() == 1) {
929-
// We do not use several vncvt to truncate to mask vector. So we could
930-
// not use PowDiff to calculate it.
931-
// Instead we use the following instructions to truncate to mask vector:
932-
// vand.vi v8, v8, 1
933-
// vmsne.vi v0, v8, 0
934-
return 2;
935-
}
936-
[[fallthrough]];
937-
case ISD::FP_EXTEND:
938-
case ISD::FP_ROUND:
939-
// Counts of narrow/widen instructions.
940-
return std::abs(PowDiff);
941-
case ISD::FP_TO_SINT:
942-
case ISD::FP_TO_UINT:
943-
case ISD::SINT_TO_FP:
944-
case ISD::UINT_TO_FP:
945-
if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
946-
// The cost of convert from or to mask vector is different from other
947-
// cases. We could not use PowDiff to calculate it.
948-
// For mask vector to fp, we should use the following instructions:
949-
// vmv.v.i v8, 0
950-
// vmerge.vim v8, v8, -1, v0
951-
// vfcvt.f.x.v v8, v8
952-
953-
// And for fp vector to mask, we use:
954-
// vfncvt.rtz.x.f.w v9, v8
955-
// vand.vi v8, v9, 1
956-
// vmsne.vi v0, v8, 0
957-
return 3;
958-
}
959-
if (std::abs(PowDiff) <= 1)
960-
return 1;
961-
// Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
962-
// so it only need two conversion.
963-
if (Src->isIntOrIntVectorTy())
964-
return 2;
965-
// Counts of narrow/widen instructions.
966-
return std::abs(PowDiff);
900+
bool IsVectorType = isa<VectorType>(Dst) && isa<VectorType>(Src);
901+
bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
902+
(Src->getScalarSizeInBits() <= ST->getELen()) &&
903+
(Dst->getScalarSizeInBits() <= ST->getELen());
904+
905+
// FIXME: Need to compute legalizing cost for illegal types.
906+
if (!IsVectorType || !IsTypeLegal)
907+
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
908+
909+
int ISD = TLI->InstructionOpcodeToISD(Opcode);
910+
assert(ISD && "Invalid opcode");
911+
912+
// FIXME: Need to consider vsetvli and lmul.
913+
int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
914+
(int)Log2_32(Src->getScalarSizeInBits());
915+
switch (ISD) {
916+
case ISD::SIGN_EXTEND:
917+
case ISD::ZERO_EXTEND:
918+
if (Src->getScalarSizeInBits() == 1) {
919+
// We do not use vsext/vzext to extend from mask vector.
920+
// Instead we use the following instructions to extend from mask vector:
921+
// vmv.v.i v8, 0
922+
// vmerge.vim v8, v8, -1, v0
923+
return 2;
967924
}
925+
return 1;
926+
case ISD::TRUNCATE:
927+
if (Dst->getScalarSizeInBits() == 1) {
928+
// We do not use several vncvt to truncate to mask vector. So we could
929+
// not use PowDiff to calculate it.
930+
// Instead we use the following instructions to truncate to mask vector:
931+
// vand.vi v8, v8, 1
932+
// vmsne.vi v0, v8, 0
933+
return 2;
934+
}
935+
[[fallthrough]];
936+
case ISD::FP_EXTEND:
937+
case ISD::FP_ROUND:
938+
// Counts of narrow/widen instructions.
939+
return std::abs(PowDiff);
940+
case ISD::FP_TO_SINT:
941+
case ISD::FP_TO_UINT:
942+
case ISD::SINT_TO_FP:
943+
case ISD::UINT_TO_FP:
944+
if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
945+
// The cost of convert from or to mask vector is different from other
946+
// cases. We could not use PowDiff to calculate it.
947+
// For mask vector to fp, we should use the following instructions:
948+
// vmv.v.i v8, 0
949+
// vmerge.vim v8, v8, -1, v0
950+
// vfcvt.f.x.v v8, v8
951+
952+
// And for fp vector to mask, we use:
953+
// vfncvt.rtz.x.f.w v9, v8
954+
// vand.vi v8, v9, 1
955+
// vmsne.vi v0, v8, 0
956+
return 3;
957+
}
958+
if (std::abs(PowDiff) <= 1)
959+
return 1;
960+
// Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
961+
// so it only need two conversion.
962+
if (Src->isIntOrIntVectorTy())
963+
return 2;
964+
// Counts of narrow/widen instructions.
965+
return std::abs(PowDiff);
968966
}
969-
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
970967
}
971968

972969
unsigned RISCVTTIImpl::getEstimatedVLFor(VectorType *Ty) {

0 commit comments

Comments
 (0)