Skip to content

Commit 7545c63

Browse files
authored
[RISCV][TTI] Scale the cost of the sext/zext with LMUL (#86617)
Use the destination data type to determine the LMUL when computing the latency/throughput cost.
1 parent fa1b807 commit 7545c63

File tree

6 files changed

+566
-556
lines changed

6 files changed

+566
-556
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -909,23 +909,33 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
909909
if (!IsTypeLegal)
910910
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
911911

912+
std::pair<InstructionCost, MVT> DstLT = getTypeLegalizationCost(Dst);
913+
912914
int ISD = TLI->InstructionOpcodeToISD(Opcode);
913915
assert(ISD && "Invalid opcode");
914916

915-
// FIXME: Need to consider vsetvli and lmul.
916917
int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
917918
(int)Log2_32(Src->getScalarSizeInBits());
918919
switch (ISD) {
919920
case ISD::SIGN_EXTEND:
920-
case ISD::ZERO_EXTEND:
921-
if (Src->getScalarSizeInBits() == 1) {
921+
case ISD::ZERO_EXTEND: {
922+
const unsigned SrcEltSize = Src->getScalarSizeInBits();
923+
if (SrcEltSize == 1) {
922924
// We do not use vsext/vzext to extend from mask vector.
923925
// Instead we use the following instructions to extend from mask vector:
924926
// vmv.v.i v8, 0
925927
// vmerge.vim v8, v8, -1, v0
926-
return 2;
928+
return getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM},
929+
DstLT.second, CostKind);
927930
}
928-
return 1;
931+
if (PowDiff > 3)
932+
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
933+
unsigned SExtOp[] = {RISCV::VSEXT_VF2, RISCV::VSEXT_VF4, RISCV::VSEXT_VF8};
934+
unsigned ZExtOp[] = {RISCV::VZEXT_VF2, RISCV::VZEXT_VF4, RISCV::VZEXT_VF8};
935+
unsigned Op =
936+
(ISD == ISD::SIGN_EXTEND) ? SExtOp[PowDiff - 1] : ZExtOp[PowDiff - 1];
937+
return getRISCVInstructionCost(Op, DstLT.second, CostKind);
938+
}
929939
case ISD::TRUNCATE:
930940
if (Dst->getScalarSizeInBits() == 1) {
931941
// We do not use several vncvt to truncate to mask vector. So we could

0 commit comments

Comments
 (0)