Skip to content

Commit 0398f01

Browse files
committed
Address comment. Rewrite the flow
1 parent 966e68b commit 0398f01

File tree

1 file changed

+35
-40
lines changed

1 file changed

+35
-40
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,47 +1108,44 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
11081108
IsSigned ? RISCV::VFNCVT_RTZ_X_F_W : RISCV::VFNCVT_RTZ_XU_F_W;
11091109
unsigned SrcEltSize = Src->getScalarSizeInBits();
11101110
unsigned DstEltSize = Dst->getScalarSizeInBits();
1111+
InstructionCost Cost = 0;
11111112
if ((SrcEltSize == 16) &&
11121113
(!ST->hasVInstructionsF16() || ((DstEltSize >> 1) > SrcEltSize))) {
1114+
// If the target only supports vfhmin or it is fp16-to-i64 conversion
11131115
// pre-widening to f32 and then convert f32 to integer
11141116
VectorType *VecF32Ty =
11151117
VectorType::get(Type::getFloatTy(Dst->getContext()),
11161118
cast<VectorType>(Dst)->getElementCount());
11171119
std::pair<InstructionCost, MVT> VecF32LT =
11181120
getTypeLegalizationCost(VecF32Ty);
1119-
InstructionCost WidenCost = getRISCVInstructionCost(
1120-
RISCV::VFWCVT_F_F_V, VecF32LT.second, CostKind);
1121-
InstructionCost ConvCost =
1122-
getCastInstrCost(Opcode, Dst, VecF32Ty, CCH, CostKind, I);
1123-
return VecF32LT.first * WidenCost + ConvCost;
1121+
Cost +=
1122+
VecF32LT.first * getRISCVInstructionCost(RISCV::VFWCVT_F_F_V,
1123+
VecF32LT.second, CostKind);
1124+
Cost += getCastInstrCost(Opcode, Dst, VecF32Ty, CCH, CostKind, I);
1125+
return Cost;
11241126
}
11251127
if (DstEltSize == SrcEltSize)
1126-
return getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
1127-
if ((DstEltSize >> 1) == SrcEltSize)
1128-
return getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
1129-
InstructionCost TruncCost = 0;
1128+
Cost += getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
1129+
else if (DstEltSize > SrcEltSize)
1130+
Cost += getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
1131+
else { // (SrcEltSize > DstEltSize)
1132+
// First do a narrowing conversion to an integer half the size, then
1133+
// truncate if needed.
1134+
MVT ElementVT = MVT::getIntegerVT(SrcEltSize >> 1);
1135+
MVT VecVT = DstLT.second.changeVectorElementType(ElementVT);
1136+
Cost += getRISCVInstructionCost(FNCVT, VecVT, CostKind);
1137+
}
11301138
if ((SrcEltSize >> 1) > DstEltSize) {
1131-
// For fp vector to mask, we use:
1132-
// vfncvt.rtz.x.f.w v9, v8
1133-
// vand.vi v8, v9, 1 generated by Trunc
1134-
// vmsne.vi v0, v8, 0 generated by Trunc
1139+
// For mask type, we use:
1140+
// vand.vi v8, v9, 1
1141+
// vmsne.vi v0, v8, 0
11351142
VectorType *VecTy =
11361143
VectorType::get(IntegerType::get(Dst->getContext(), SrcEltSize >> 1),
11371144
cast<VectorType>(Dst)->getElementCount());
1138-
TruncCost =
1145+
Cost +=
11391146
getCastInstrCost(Instruction::Trunc, Dst, VecTy, CCH, CostKind, I);
11401147
}
1141-
if (SrcEltSize > DstEltSize) {
1142-
// First do a narrowing conversion to an integer half the size, then
1143-
// truncate if needed.
1144-
MVT ElementVT = MVT::getIntegerVT(SrcEltSize >> 1);
1145-
MVT VecVT = DstLT.second.changeVectorElementType(ElementVT);
1146-
InstructionCost ConvCost =
1147-
getRISCVInstructionCost(FNCVT, VecVT, CostKind);
1148-
return ConvCost + TruncCost;
1149-
}
1150-
1151-
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
1148+
return Cost;
11521149
}
11531150
case ISD::SINT_TO_FP:
11541151
case ISD::UINT_TO_FP: {
@@ -1159,22 +1156,22 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
11591156
unsigned SrcEltSize = Src->getScalarSizeInBits();
11601157
unsigned DstEltSize = Dst->getScalarSizeInBits();
11611158

1159+
InstructionCost Cost = 0;
11621160
if ((DstEltSize == 16) &&
11631161
(!ST->hasVInstructionsF16() || ((SrcEltSize >> 1) > DstEltSize))) {
1164-
// convert to f32 and then f32 to f16
1162+
// If the target only supports vfhmin or it is i64-to-fp16 conversion
1163+
// it is converted to f32 and then converted to f16
11651164
VectorType *VecF32Ty =
11661165
VectorType::get(Type::getFloatTy(Dst->getContext()),
11671166
cast<VectorType>(Dst)->getElementCount());
11681167
std::pair<InstructionCost, MVT> VecF32LT =
11691168
getTypeLegalizationCost(VecF32Ty);
1170-
InstructionCost FP32ConvCost =
1171-
getCastInstrCost(Opcode, VecF32Ty, Src, CCH, CostKind, I);
1172-
return FP32ConvCost +
1173-
VecF32LT.first * getRISCVInstructionCost(RISCV::VFNCVT_F_F_W,
1169+
Cost = VecF32LT.first * getRISCVInstructionCost(RISCV::VFNCVT_F_F_W,
11741170
DstLT.second, CostKind);
1171+
Cost += getCastInstrCost(Opcode, VecF32Ty, Src, CCH, CostKind, I);
1172+
return Cost;
11751173
}
11761174

1177-
InstructionCost PreWidenCost = 0;
11781175
if ((DstEltSize >> 1) > SrcEltSize) {
11791176
// Do pre-widening before converting:
11801177
// 1. Backend could lower (v[sz]ext i8 to double) to
@@ -1187,17 +1184,15 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
11871184
VectorType::get(IntegerType::get(Dst->getContext(), SrcEltSize),
11881185
cast<VectorType>(Dst)->getElementCount());
11891186
unsigned Op = IsSigned ? Instruction::SExt : Instruction::ZExt;
1190-
PreWidenCost = getCastInstrCost(Op, VecTy, Src, CCH, CostKind, I);
1187+
Cost += getCastInstrCost(Op, VecTy, Src, CCH, CostKind, I);
11911188
}
11921189
if (DstEltSize == SrcEltSize)
1193-
return PreWidenCost +
1194-
getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
1195-
if ((DstEltSize >> 1) == SrcEltSize)
1196-
return PreWidenCost +
1197-
getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
1198-
if ((SrcEltSize >> 1) == DstEltSize)
1199-
return getRISCVInstructionCost(FNCVT, DstLT.second, CostKind);
1200-
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
1190+
Cost += getRISCVInstructionCost(FCVT, DstLT.second, CostKind);
1191+
else if (DstEltSize > SrcEltSize)
1192+
Cost += getRISCVInstructionCost(FWCVT, DstLT.second, CostKind);
1193+
else
1194+
Cost += getRISCVInstructionCost(FNCVT, DstLT.second, CostKind);
1195+
return Cost;
12011196
}
12021197
}
12031198
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

0 commit comments

Comments
 (0)