Skip to content

[InstCombine] Fold (x < y) ? -1 : zext(x > y) and (x > y) ? 1 : sext(x < y) to ucmp/scmp(x, y) #105272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3560,7 +3560,9 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {

// This function tries to fold the following operations:
// (x < y) ? -1 : zext(x != y)
// (x < y) ? -1 : zext(x > y)
// (x > y) ? 1 : sext(x != y)
// (x > y) ? 1 : sext(x < y)
// Into ucmp/scmp(x, y), where signedness is determined by the signedness
// of the comparison in the original sequence.
Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
Expand Down Expand Up @@ -3589,16 +3591,23 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
ICmpInst::isSigned(Pred) ? Intrinsic::scmp : Intrinsic::ucmp;

bool Replace = false;
ICmpInst::Predicate ExtendedCmpPredicate;
// (x < y) ? -1 : zext(x != y)
// (x < y) ? -1 : zext(x > y)
if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
match(FV, m_ZExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
m_Specific(RHS)))))
match(FV, m_ZExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
m_Specific(RHS)))) &&
(ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
Replace = true;

// (x > y) ? 1 : sext(x != y)
// (x > y) ? 1 : sext(x < y)
if (ICmpInst::isGT(Pred) && match(TV, m_One()) &&
match(FV, m_SExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
m_Specific(RHS)))))
match(FV, m_SExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
m_Specific(RHS)))) &&
(ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
Replace = true;

if (Replace)
Expand Down
28 changes: 28 additions & 0 deletions llvm/test/Transforms/InstCombine/scmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,20 @@ define i8 @scmp_from_select_lt(i32 %x, i32 %y) {
ret i8 %r
}

; Fold (x s< y) ? -1 : zext(x s> y) into scmp(x, y)
; The select condition is the signed less-than compare and the false arm is
; the zero-extended signed greater-than compare of the same operands, so the
; select yields -1 when x s< y, 1 when x s> y, and 0 otherwise.
define i8 @scmp_from_select_lt_and_gt(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_from_select_lt_and_gt(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: ret i8 [[R]]
;
%gt_bool = icmp sgt i32 %x, %y
%gt = zext i1 %gt_bool to i8
%lt = icmp slt i32 %x, %y
%r = select i1 %lt, i8 -1, i8 %gt
ret i8 %r
}

; Vector version
define <4 x i8> @scmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: define <4 x i8> @scmp_from_select_vec_lt(
Expand Down Expand Up @@ -315,3 +329,17 @@ define i8 @scmp_of_sub_and_zero_neg3(i32 %x, i32 %y) {
%r = call i8 @llvm.ucmp(i32 %diff, i32 0)
ret i8 %r
}

; Fold (x s> y) ? 1 : sext(x s< y) into scmp(x, y)
; The select condition is the signed greater-than compare and the false arm is
; the sign-extended signed less-than compare of the same operands, so the
; select yields 1 when x s> y, -1 when x s< y, and 0 otherwise.
define i8 @scmp_from_select_gt_and_lt(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_from_select_gt_and_lt(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: ret i8 [[R]]
;
%lt_bool = icmp slt i32 %x, %y
%lt = sext i1 %lt_bool to i8
%gt = icmp sgt i32 %x, %y
%r = select i1 %gt, i8 1, i8 %lt
ret i8 %r
}
42 changes: 11 additions & 31 deletions llvm/test/Transforms/InstCombine/select-select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ define float @foo1(float %a) {

define float @foo2(float %a) {
; CHECK-LABEL: @foo2(
; CHECK-NEXT: [[B:%.*]] = fcmp ule float [[C:%.*]], 0.000000e+00
; CHECK-NEXT: [[D:%.*]] = fcmp olt float [[C]], 1.000000e+00
; CHECK-NEXT: [[E:%.*]] = select i1 [[D]], float [[C]], float 1.000000e+00
; CHECK-NEXT: [[B:%.*]] = fcmp ule float [[A:%.*]], 0.000000e+00
; CHECK-NEXT: [[TMP1:%.*]] = fcmp olt float [[A]], 1.000000e+00
; CHECK-NEXT: [[E:%.*]] = select i1 [[TMP1]], float [[A]], float 1.000000e+00
; CHECK-NEXT: [[F:%.*]] = select i1 [[B]], float 0.000000e+00, float [[E]]
; CHECK-NEXT: ret float [[F]]
;
Expand Down Expand Up @@ -330,10 +330,7 @@ define i8 @strong_order_cmp_eq_ugt(i32 %a, i32 %b) {

define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_slt_sgt(
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.lt = icmp slt i32 %a, %b
Expand All @@ -345,10 +342,7 @@ define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {

define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_ult_ugt(
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.lt = icmp ult i32 %a, %b
Expand All @@ -360,10 +354,7 @@ define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {

define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_sgt_slt(
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.gt = icmp sgt i32 %a, %b
Expand All @@ -375,10 +366,7 @@ define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {

define i8 @strong_order_cmp_ugt_ult(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_ugt_ult(
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.gt = icmp ugt i32 %a, %b
Expand Down Expand Up @@ -460,8 +448,7 @@ define i8 @strong_order_cmp_ugt_ult_zext_not_oneuse(i32 %a, i32 %b) {
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
; CHECK-NEXT: call void @use8(i8 [[ZEXT]])
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.gt = icmp ugt i32 %a, %b
Expand All @@ -477,8 +464,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
; CHECK-NEXT: call void @use8(i8 [[SEXT]])
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.lt = icmp slt i32 %a, %b
Expand All @@ -491,10 +477,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {

define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector(
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
; CHECK-NEXT: [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 -1, i8 -1>, <2 x i8> [[ZEXT]]
; CHECK-NEXT: [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_LT]]
;
%cmp.gt = icmp ugt <2 x i32> %a, %b
Expand All @@ -506,10 +489,7 @@ define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {

define <2 x i8> @strong_order_cmp_ugt_ult_vector_poison(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector_poison(
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
; CHECK-NEXT: [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 poison, i8 -1>, <2 x i8> [[ZEXT]]
; CHECK-NEXT: [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_LT]]
;
%cmp.gt = icmp ugt <2 x i32> %a, %b
Expand Down
32 changes: 30 additions & 2 deletions llvm/test/Transforms/InstCombine/ucmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ define i8 @ucmp_from_select_lt(i32 %x, i32 %y) {
ret i8 %r
}

; Fold (x u< y) ? -1 : zext(x u> y) into ucmp(x, y)
; The select condition is the unsigned less-than compare and the false arm is
; the zero-extended unsigned greater-than compare of the same operands, so the
; select yields -1 when x u< y, 1 when x u> y, and 0 otherwise.
define i8 @ucmp_from_select_lt_and_gt(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @ucmp_from_select_lt_and_gt(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: ret i8 [[R]]
;
%gt_bool = icmp ugt i32 %x, %y
%gt = zext i1 %gt_bool to i8
%lt = icmp ult i32 %x, %y
%r = select i1 %lt, i8 -1, i8 %gt
ret i8 %r
}

; Vector version
define <4 x i8> @ucmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: define <4 x i8> @ucmp_from_select_vec_lt(
Expand Down Expand Up @@ -349,13 +363,13 @@ define i8 @ucmp_from_select_le_neg1(i32 %x, i32 %y) {
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ult i32 [[X]], [[Y]]
; CHECK-NEXT: [[NE:%.*]] = sext i1 [[NE_BOOL]] to i8
; CHECK-NEXT: [[LE_NOT:%.*]] = icmp ugt i32 [[X]], [[Y]]
; CHECK-NEXT: [[LE_NOT:%.*]] = icmp ult i32 [[X]], [[Y]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[LE_NOT]], i8 1, i8 [[NE]]
; CHECK-NEXT: ret i8 [[R]]
;
%ne_bool = icmp ult i32 %x, %y
%ne = sext i1 %ne_bool to i8
%le = icmp ule i32 %x, %y
%le = icmp uge i32 %x, %y
%r = select i1 %le, i8 %ne, i8 1
ret i8 %r
}
Expand Down Expand Up @@ -513,3 +527,17 @@ define i8 @ucmp_from_select_ge_neg4(i32 %x, i32 %y) {
%r = select i1 %ge, i8 %ne, i8 3
ret i8 %r
}

; Fold (x u> y) ? 1 : sext(x u< y) into ucmp(x, y)
; The select condition is the unsigned greater-than compare and the false arm
; is the sign-extended unsigned less-than compare of the same operands, so the
; select yields 1 when x u> y, -1 when x u< y, and 0 otherwise.
define i8 @ucmp_from_select_gt_and_lt(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @ucmp_from_select_gt_and_lt(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: ret i8 [[R]]
;
%lt_bool = icmp ult i32 %x, %y
%lt = sext i1 %lt_bool to i8
%gt = icmp ugt i32 %x, %y
%r = select i1 %gt, i8 1, i8 %lt
ret i8 %r
}
Loading