@@ -1027,39 +1027,45 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
1027
1027
1028
1028
std::optional<SmallVector<APInt>>
1029
1029
llvm::ConstantFoldICmp (unsigned Pred, const Register Op1, const Register Op2,
1030
+ unsigned DstScalarSizeInBits, unsigned ExtOp,
1030
1031
const MachineRegisterInfo &MRI) {
1031
- LLT Ty = MRI.getType (Op1);
1032
- if (Ty != MRI.getType (Op2))
1033
- return std::nullopt;
1032
+ assert (ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
1033
+ ExtOp == TargetOpcode::G_ANYEXT);
1034
+
1035
+ const LLT Ty = MRI.getType (Op1);
1034
1036
1035
- auto TryFoldScalar = [&MRI, Pred](Register LHS,
1036
- Register RHS) -> std::optional<APInt> {
1037
+ auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional<APInt> {
1037
1038
auto LHSCst = getIConstantVRegVal (LHS, MRI);
1038
1039
auto RHSCst = getIConstantVRegVal (RHS, MRI);
1039
1040
if (!LHSCst || !RHSCst)
1040
1041
return std::nullopt;
1041
1042
1043
+ const APInt FalseCst = APInt::getZero (DstScalarSizeInBits);
1044
+ const APInt TrueCst = (ExtOp == TargetOpcode::G_SEXT)
1045
+ ? APInt::getAllOnes (DstScalarSizeInBits)
1046
+ : APInt::getOneBitSet (DstScalarSizeInBits, 0 );
1047
+
1042
1048
switch (Pred) {
1043
1049
case CmpInst::Predicate::ICMP_EQ:
1044
- return APInt ( /* numBits= */ 1 , LHSCst->eq (*RHSCst)) ;
1050
+ return LHSCst->eq (*RHSCst) ? TrueCst : FalseCst ;
1045
1051
case CmpInst::Predicate::ICMP_NE:
1046
- return APInt ( /* numBits= */ 1 , LHSCst->ne (*RHSCst)) ;
1052
+ return LHSCst->ne (*RHSCst) ? TrueCst : FalseCst ;
1047
1053
case CmpInst::Predicate::ICMP_UGT:
1048
- return APInt ( /* numBits= */ 1 , LHSCst->ugt (*RHSCst)) ;
1054
+ return LHSCst->ugt (*RHSCst) ? TrueCst : FalseCst ;
1049
1055
case CmpInst::Predicate::ICMP_UGE:
1050
- return APInt ( /* numBits= */ 1 , LHSCst->uge (*RHSCst)) ;
1056
+ return LHSCst->uge (*RHSCst) ? TrueCst : FalseCst ;
1051
1057
case CmpInst::Predicate::ICMP_ULT:
1052
- return APInt ( /* numBits= */ 1 , LHSCst->ult (*RHSCst)) ;
1058
+ return LHSCst->ult (*RHSCst) ? TrueCst : FalseCst ;
1053
1059
case CmpInst::Predicate::ICMP_ULE:
1054
- return APInt ( /* numBits= */ 1 , LHSCst->ule (*RHSCst)) ;
1060
+ return LHSCst->ule (*RHSCst) ? TrueCst : FalseCst ;
1055
1061
case CmpInst::Predicate::ICMP_SGT:
1056
- return APInt ( /* numBits= */ 1 , LHSCst->sgt (*RHSCst)) ;
1062
+ return LHSCst->sgt (*RHSCst) ? TrueCst : FalseCst ;
1057
1063
case CmpInst::Predicate::ICMP_SGE:
1058
- return APInt ( /* numBits= */ 1 , LHSCst->sge (*RHSCst)) ;
1064
+ return LHSCst->sge (*RHSCst) ? TrueCst : FalseCst ;
1059
1065
case CmpInst::Predicate::ICMP_SLT:
1060
- return APInt ( /* numBits= */ 1 , LHSCst->slt (*RHSCst)) ;
1066
+ return LHSCst->slt (*RHSCst) ? TrueCst : FalseCst ;
1061
1067
case CmpInst::Predicate::ICMP_SLE:
1062
- return APInt ( /* numBits= */ 1 , LHSCst->sle (*RHSCst)) ;
1068
+ return LHSCst->sle (*RHSCst) ? TrueCst : FalseCst ;
1063
1069
default :
1064
1070
return std::nullopt;
1065
1071
}
0 commit comments