Skip to content

Commit 966e68b

Browse files
committed
[RISCV][TTI] Scale the cost of FP-Int conversion with LMUL
Widening or narrowing the source data type to match the destination data type may require multiple steps. To model these costs, the patch generates the interim type by following the logic in RISCVTargetLowering::lowerVPFPIntConvOp. Address review comments: 1. Re-implement FP conversion for the VFHMIN cases. 2. Separate the VFHMIN cases from cast.ll. Use FloatTy instead of DoubleTy to promote fp16 to fp32.
1 parent 2d6f4d2 commit 966e68b

File tree

3 files changed

+2350
-1649
lines changed

3 files changed

+2350
-1649
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 97 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,34 +1099,106 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
10991099
return Cost;
11001100
}
11011101
case ISD::FP_TO_SINT:
1102-
case ISD::FP_TO_UINT:
1103-
// For fp vector to mask, we use:
1104-
// vfncvt.rtz.x.f.w v9, v8
1105-
// vand.vi v8, v9, 1
1106-
// vmsne.vi v0, v8, 0
1107-
if (Dst->getScalarSizeInBits() == 1)
1108-
return 3;
1102+
case ISD::FP_TO_UINT: {
1103+
unsigned IsSigned = ISD == ISD::FP_TO_SINT;
1104+
unsigned FCVT = IsSigned ? RISCV::VFCVT_RTZ_X_F_V : RISCV::VFCVT_RTZ_XU_F_V;
1105+
unsigned FWCVT =
1106+
IsSigned ? RISCV::VFWCVT_RTZ_X_F_V : RISCV::VFWCVT_RTZ_XU_F_V;
1107+
unsigned FNCVT =
1108+
IsSigned ? RISCV::VFNCVT_RTZ_X_F_W : RISCV::VFNCVT_RTZ_XU_F_W;
1109+
unsigned SrcEltSize = Src->getScalarSizeInBits();
1110+
unsigned DstEltSize = Dst->getScalarSizeInBits();
1111+
if ((SrcEltSize == 16) &&
1112+
(!ST->hasVInstructionsF16() || ((DstEltSize >> 1) > SrcEltSize))) {
1113+
// pre-widening to f32 and then convert f32 to integer
1114+
VectorType *VecF32Ty =
1115+
VectorType::get(Type::getFloatTy(Dst->getContext()),
1116+
cast<VectorType>(Dst)->getElementCount());
1117+
std::pair<InstructionCost, MVT> VecF32LT =
1118+
getTypeLegalizationCost(VecF32Ty);
1119+
InstructionCost WidenCost = getRISCVInstructionCost(
1120+
RISCV::VFWCVT_F_F_V, VecF32LT.second, CostKind);
1121+
InstructionCost ConvCost =
1122+
getCastInstrCost(Opcode, Dst, VecF32Ty, CCH, CostKind, I);
1123+
return VecF32LT.first * WidenCost + ConvCost;
1124+
}
1125+
if (DstEltSize == SrcEltSize)
1126+
return getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
1127+
if ((DstEltSize >> 1) == SrcEltSize)
1128+
return getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
1129+
InstructionCost TruncCost = 0;
1130+
if ((SrcEltSize >> 1) > DstEltSize) {
1131+
// For fp vector to mask, we use:
1132+
// vfncvt.rtz.x.f.w v9, v8
1133+
// vand.vi v8, v9, 1 generated by Trunc
1134+
// vmsne.vi v0, v8, 0 generated by Trunc
1135+
VectorType *VecTy =
1136+
VectorType::get(IntegerType::get(Dst->getContext(), SrcEltSize >> 1),
1137+
cast<VectorType>(Dst)->getElementCount());
1138+
TruncCost =
1139+
getCastInstrCost(Instruction::Trunc, Dst, VecTy, CCH, CostKind, I);
1140+
}
1141+
if (SrcEltSize > DstEltSize) {
1142+
// First do a narrowing conversion to an integer half the size, then
1143+
// truncate if needed.
1144+
MVT ElementVT = MVT::getIntegerVT(SrcEltSize >> 1);
1145+
MVT VecVT = DstLT.second.changeVectorElementType(ElementVT);
1146+
InstructionCost ConvCost =
1147+
getRISCVInstructionCost(FNCVT, VecVT, CostKind);
1148+
return ConvCost + TruncCost;
1149+
}
11091150

1110-
if (std::abs(PowDiff) <= 1)
1111-
return 1;
1151+
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
1152+
}
1153+
case ISD::SINT_TO_FP:
1154+
case ISD::UINT_TO_FP: {
1155+
unsigned IsSigned = ISD == ISD::SINT_TO_FP;
1156+
unsigned FCVT = IsSigned ? RISCV::VFCVT_F_X_V : RISCV::VFCVT_F_XU_V;
1157+
unsigned FWCVT = IsSigned ? RISCV::VFWCVT_F_X_V : RISCV::VFWCVT_F_XU_V;
1158+
unsigned FNCVT = IsSigned ? RISCV::VFNCVT_F_X_W : RISCV::VFNCVT_F_XU_W;
1159+
unsigned SrcEltSize = Src->getScalarSizeInBits();
1160+
unsigned DstEltSize = Dst->getScalarSizeInBits();
11121161

1113-
// Counts of narrow/widen instructions.
1114-
return std::abs(PowDiff);
1162+
if ((DstEltSize == 16) &&
1163+
(!ST->hasVInstructionsF16() || ((SrcEltSize >> 1) > DstEltSize))) {
1164+
// convert to f32 and then f32 to f16
1165+
VectorType *VecF32Ty =
1166+
VectorType::get(Type::getFloatTy(Dst->getContext()),
1167+
cast<VectorType>(Dst)->getElementCount());
1168+
std::pair<InstructionCost, MVT> VecF32LT =
1169+
getTypeLegalizationCost(VecF32Ty);
1170+
InstructionCost FP32ConvCost =
1171+
getCastInstrCost(Opcode, VecF32Ty, Src, CCH, CostKind, I);
1172+
return FP32ConvCost +
1173+
VecF32LT.first * getRISCVInstructionCost(RISCV::VFNCVT_F_F_W,
1174+
DstLT.second, CostKind);
1175+
}
11151176

1116-
case ISD::SINT_TO_FP:
1117-
case ISD::UINT_TO_FP:
1118-
// For mask vector to fp, we should use the following instructions:
1119-
// vmv.v.i v8, 0
1120-
// vmerge.vim v8, v8, -1, v0
1121-
// vfcvt.f.x.v v8, v8
1122-
if (Src->getScalarSizeInBits() == 1)
1123-
return 3;
1124-
1125-
if (std::abs(PowDiff) <= 1)
1126-
return 1;
1127-
// Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
1128-
// so it only need two conversion.
1129-
return 2;
1177+
InstructionCost PreWidenCost = 0;
1178+
if ((DstEltSize >> 1) > SrcEltSize) {
1179+
// Do pre-widening before converting:
1180+
// 1. Backend could lower (v[sz]ext i8 to double) to
1181+
// vfcvt(v[sz]ext.f8 i8),
1182+
// 2. For mask vector to fp, we should use the following instructions:
1183+
// vmv.v.i v8, 0
1184+
// vmerge.vim v8, v8, -1, v0
1185+
SrcEltSize = DstEltSize >> 1;
1186+
VectorType *VecTy =
1187+
VectorType::get(IntegerType::get(Dst->getContext(), SrcEltSize),
1188+
cast<VectorType>(Dst)->getElementCount());
1189+
unsigned Op = IsSigned ? Instruction::SExt : Instruction::ZExt;
1190+
PreWidenCost = getCastInstrCost(Op, VecTy, Src, CCH, CostKind, I);
1191+
}
1192+
if (DstEltSize == SrcEltSize)
1193+
return PreWidenCost +
1194+
getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
1195+
if ((DstEltSize >> 1) == SrcEltSize)
1196+
return PreWidenCost +
1197+
getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
1198+
if ((SrcEltSize >> 1) == DstEltSize)
1199+
return getRISCVInstructionCost(FNCVT, DstLT.second, CostKind);
1200+
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
1201+
}
11301202
}
11311203
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
11321204
}

0 commit comments

Comments
 (0)