Skip to content

Commit 8cf517f

Browse files
committed
Fixup! using helper function to get the reduction opcode.
1 parent 51a97d7 commit 8cf517f

File tree

1 file changed

+19
-45
lines changed

1 file changed

+19
-45
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,66 +1192,40 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
11921192
ICA.getArgTypes()[0], CmpInst::BAD_ICMP_PREDICATE,
11931193
CostKind);
11941194
case Intrinsic::vp_reduce_add:
1195-
return getArithmeticReductionCost(Instruction::Add,
1196-
cast<VectorType>(ICA.getArgTypes()[1]),
1197-
std::nullopt, CostKind);
11981195
case Intrinsic::vp_reduce_fadd:
1199-
return getArithmeticReductionCost(Instruction::FAdd,
1200-
cast<VectorType>(ICA.getArgTypes()[1]),
1201-
ICA.getFlags(), CostKind);
12021196
case Intrinsic::vp_reduce_mul:
1203-
return getArithmeticReductionCost(Instruction::Mul,
1204-
cast<VectorType>(ICA.getArgTypes()[1]),
1205-
std::nullopt, CostKind);
12061197
case Intrinsic::vp_reduce_fmul:
1207-
return getArithmeticReductionCost(Instruction::FMul,
1208-
cast<VectorType>(ICA.getArgTypes()[1]),
1209-
ICA.getFlags(), CostKind);
12101198
case Intrinsic::vp_reduce_and:
1211-
return getArithmeticReductionCost(Instruction::And,
1212-
cast<VectorType>(ICA.getArgTypes()[1]),
1213-
std::nullopt, CostKind);
12141199
case Intrinsic::vp_reduce_or:
1215-
return getArithmeticReductionCost(Instruction::Or,
1216-
cast<VectorType>(ICA.getArgTypes()[1]),
1217-
std::nullopt, CostKind);
1218-
case Intrinsic::vp_reduce_xor:
1219-
return getArithmeticReductionCost(Instruction::Xor,
1220-
cast<VectorType>(ICA.getArgTypes()[1]),
1221-
std::nullopt, CostKind);
1200+
case Intrinsic::vp_reduce_xor: {
1201+
std::optional<Intrinsic::ID> RedID =
1202+
VPIntrinsic::getFunctionalIntrinsicIDForVP(ICA.getID());
1203+
assert(RedID.has_value());
1204+
unsigned RedOp = getArithmeticReductionInstruction(*RedID);
1205+
if (RedOp == Instruction::FAdd || RedOp == Instruction::FMul)
1206+
return getArithmeticReductionCost(RedOp,
1207+
cast<VectorType>(ICA.getArgTypes()[1]),
1208+
ICA.getFlags(), CostKind);
1209+
return getArithmeticReductionCost(
1210+
RedOp, cast<VectorType>(ICA.getArgTypes()[1]), std::nullopt, CostKind);
1211+
}
12221212
case Intrinsic::vp_reduce_smax:
1223-
return getMinMaxReductionCost(Intrinsic::smax,
1224-
cast<VectorType>(ICA.getArgTypes()[1]),
1225-
ICA.getFlags(), CostKind);
12261213
case Intrinsic::vp_reduce_smin:
1227-
return getMinMaxReductionCost(Intrinsic::smin,
1228-
cast<VectorType>(ICA.getArgTypes()[1]),
1229-
ICA.getFlags(), CostKind);
12301214
case Intrinsic::vp_reduce_umax:
1231-
return getMinMaxReductionCost(Intrinsic::umax,
1232-
cast<VectorType>(ICA.getArgTypes()[1]),
1233-
ICA.getFlags(), CostKind);
12341215
case Intrinsic::vp_reduce_umin:
1235-
return getMinMaxReductionCost(Intrinsic::umin,
1236-
cast<VectorType>(ICA.getArgTypes()[1]),
1237-
ICA.getFlags(), CostKind);
12381216
case Intrinsic::vp_reduce_fmax:
1239-
return getMinMaxReductionCost(Intrinsic::maxnum,
1240-
cast<VectorType>(ICA.getArgTypes()[1]),
1241-
ICA.getFlags(), CostKind);
12421217
case Intrinsic::vp_reduce_fmaximum:
1243-
return getMinMaxReductionCost(Intrinsic::maximum,
1244-
cast<VectorType>(ICA.getArgTypes()[1]),
1245-
ICA.getFlags(), CostKind);
12461218
case Intrinsic::vp_reduce_fmin:
1247-
return getMinMaxReductionCost(Intrinsic::minnum,
1248-
cast<VectorType>(ICA.getArgTypes()[1]),
1249-
ICA.getFlags(), CostKind);
1250-
case Intrinsic::vp_reduce_fminimum:
1251-
return getMinMaxReductionCost(Intrinsic::minimum,
1219+
case Intrinsic::vp_reduce_fminimum: {
1220+
std::optional<Intrinsic::ID> RedID =
1221+
VPIntrinsic::getFunctionalIntrinsicIDForVP(ICA.getID());
1222+
assert(RedID.has_value());
1223+
Intrinsic::ID MinMaxID = getMinMaxReductionIntrinsicOp(*RedID);
1224+
return getMinMaxReductionCost(MinMaxID,
12521225
cast<VectorType>(ICA.getArgTypes()[1]),
12531226
ICA.getFlags(), CostKind);
12541227
}
1228+
}
12551229

12561230
if (ST->hasVInstructions() && RetTy->isVectorTy()) {
12571231
if (auto LT = getTypeLegalizationCost(RetTy);

0 commit comments

Comments
 (0)