@@ -91,6 +91,10 @@ RISCVTTIImpl::getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
91
91
case RISCV::VMV_S_X:
92
92
case RISCV::VFMV_F_S:
93
93
case RISCV::VFMV_S_F:
94
+ case RISCV::VMOR_MM:
95
+ case RISCV::VMXOR_MM:
96
+ case RISCV::VMAND_MM:
97
+ case RISCV::VMANDN_MM:
94
98
case RISCV::VMNAND_MM:
95
99
case RISCV::VCPOP_M:
96
100
Cost += 1 ;
@@ -1383,7 +1387,13 @@ InstructionCost RISCVTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
1383
1387
getRISCVInstructionCost (RISCV::VMSLT_VV, LT.second , CostKind);
1384
1388
}
1385
1389
1386
- if ((Opcode == Instruction::FCmp) && ValTy->isVectorTy ()) {
1390
+ if ((Opcode == Instruction::FCmp) && ValTy->isVectorTy () &&
1391
+ CmpInst::isFPPredicate (VecPred)) {
1392
+
1393
+ // FCMP_FALSE / FCMP_TRUE produce a constant mask: VMXOR_MM gives all-false, VMXNOR_MM gives all-true; both are costed as one VMXOR_MM.
1394
+ if ((VecPred == CmpInst::FCMP_FALSE) || (VecPred == CmpInst::FCMP_TRUE))
1395
+ return getRISCVInstructionCost (RISCV::VMXOR_MM, LT.second , CostKind);
1396
+
1387
1397
// If we do not support the input floating point vector type, use the base
1388
1398
// one which will calculate as:
1389
1399
// ScalarizeCost + Num * Cost for fixed vector,
@@ -1393,16 +1403,34 @@ InstructionCost RISCVTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
1393
1403
(ValTy->getScalarSizeInBits () == 64 && !ST->hasVInstructionsF64 ()))
1394
1404
return BaseT::getCmpSelInstrCost (Opcode, ValTy, CondTy, VecPred, CostKind,
1395
1405
I);
1406
+
1407
+ // Assuming vector fp compare and mask instructions are all the same cost
1408
+ // until a need arises to differentiate them.
1396
1409
switch (VecPred) {
1397
- // Support natively.
1398
- case CmpInst::FCMP_OEQ:
1399
- case CmpInst::FCMP_OGT:
1400
- case CmpInst::FCMP_OGE:
1401
- case CmpInst::FCMP_OLT:
1402
- case CmpInst::FCMP_OLE:
1403
- case CmpInst::FCMP_UNE:
1404
- return LT.first * 1 ;
1405
- // TODO: Other comparisons?
1410
+ case CmpInst::FCMP_ONE: // vmflt.vv + vmflt.vv + vmor.mm
1411
+ case CmpInst::FCMP_ORD: // vmfeq.vv + vmfeq.vv + vmand.mm
1412
+ case CmpInst::FCMP_UNO: // vmfne.vv + vmfne.vv + vmor.mm
1413
+ case CmpInst::FCMP_UEQ: // vmflt.vv + vmflt.vv + vmnor.mm
1414
+ return LT.first * getRISCVInstructionCost (
1415
+ {RISCV::VMFLT_VV, RISCV::VMFLT_VV, RISCV::VMOR_MM},
1416
+ LT.second , CostKind);
1417
+
1418
+ case CmpInst::FCMP_UGT: // vmfle.vv + vmnot.m
1419
+ case CmpInst::FCMP_UGE: // vmflt.vv + vmnot.m
1420
+ case CmpInst::FCMP_ULT: // vmfle.vv + vmnot.m
1421
+ case CmpInst::FCMP_ULE: // vmflt.vv + vmnot.m
1422
+ return LT.first *
1423
+ getRISCVInstructionCost ({RISCV::VMFLT_VV, RISCV::VMNAND_MM},
1424
+ LT.second , CostKind);
1425
+
1426
+ case CmpInst::FCMP_OEQ: // vmfeq.vv
1427
+ case CmpInst::FCMP_OGT: // vmflt.vv
1428
+ case CmpInst::FCMP_OGE: // vmfle.vv
1429
+ case CmpInst::FCMP_OLT: // vmflt.vv
1430
+ case CmpInst::FCMP_OLE: // vmfle.vv
1431
+ case CmpInst::FCMP_UNE: // vmfne.vv
1432
+ return LT.first *
1433
+ getRISCVInstructionCost (RISCV::VMFLT_VV, LT.second , CostKind);
1406
1434
default :
1407
1435
break ;
1408
1436
}
0 commit comments