Skip to content

[InstCombine] Improve select equiv fold for plain condition #83405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,51 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) {
return C1I.isOne() || C1I.isAllOnes() || C2I.isOne() || C2I.isAllOnes();
}

/// Try to simplify a selection chain with partially identical conditions, e.g.:
/// %s1 = select i1 %c1, i32 23, i32 45
/// %s2 = select i1 %c2, i32 666, i32 %s1
/// %s3 = select i1 %c1, i32 789, i32 %s2
/// -->
/// %s2 = select i1 %c2, i32 666, i32 45
/// %s3 = select i1 %c1, i32 789, i32 %s2
static bool simplifySeqSelectWithSameCond(SelectInst &SI,
const SimplifyQuery &SQ,
InstCombinerImpl &IC) {
Value *CondVal = SI.getCondition();
auto trySimplifySeqSelect = [=, &SI, &IC](unsigned OpIndex) {
assert((OpIndex == 1 || OpIndex == 2) && "Unexpected operand index");
SelectInst *SINext = &SI;
Type *SelType = SINext->getType();
Value *ValOp = SINext->getOperand(OpIndex);
Value *CondNext;
// There is no need to propagate FMF flags because we update the operand of
// SINext directly.
// It is not profitable to build a new select for SINext with multi-arms.
while (match(ValOp, m_Select(m_Value(CondNext), m_Value(), m_Value()))) {
if (CondNext == CondVal && SINext->hasOneUse()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be match(ValOp, m_OneUse(m_Select(....)) then drop the SINext->hasOneUse(), otherwise this won't fold the base select if it has multiple uses.

Copy link
Contributor Author

@vfdff vfdff Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the following case, it can be optimized because %sel1 is a single-use node.
If we adjust it according to the above change, then it will not be optimized because the operand of %sel1 (%sel0) has multiple uses.

define i8 @sequence_select_with_same_cond_multi_arms_src(i1 %cond0, i1 %cond1, i8 %a, i8 %b) {
  %sel0 = select i1 %cond0, i8 %a, i8 %b
  %sel1 = select i1 %cond1, i8 %sel0, i8 2; %sel1 used in single node
  %sel2 = select i1 %cond1, i8 %sel0, i8 3
  %sel3 = select i1 %cond0, i8 %sel1, i8 %sel2
  ret i8 %sel3
}

define i8 @sequence_select_with_same_cond_multi_arms_tgt(i1 %cond0, i1 %cond1, i8 %a, i8 %b) {
  %sel0 = select i1 %cond0, i8 %a, i8 %b
  %sel1 = select i1 %cond1, i8 %a, i8 2; %sel1 used in single node
  %sel2 = select i1 %cond1, i8 %sel0, i8 3
  %sel3 = select i1 %cond0, i8 %sel1, i8 %sel2
  ret i8 %sel3
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, what about V == SINext || SINext->hasOneUse() then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean ValOp == SINext || SINext->hasOneUse()? The condition ValOp == SINext is already guaranteed by ValOp = SINext->getOperand(OpIndex), and the CondNext == CondVal check is also needed; consider this case: https://alive2.llvm.org/ce/z/TKLmMG

define i32 @src(i32 %a, i1 %c1, i1 %c2, i1 %c3){
  %s1 = select i1 %c1, i32 23, i32 45
  %s2 = select i1 %c2, i32 666, i32 %s1 ; this node can be optimized iff %c1 == %c3
  %s3 = select i1 %c3, i32 789, i32 %s2
  ret i32 %s3
}

define i32 @tgt(i32 %a, i1 %c1, i1 %c2, i1 %c3){
  %s1 = select i1 %c1, i32 23, i32 45
  %s2 = select i1 %c2, i32 666, i32 45
  %s3 = select i1 %c3, i32 789, i32 %s2
  ret i32 %s3
}

IC.replaceOperand(*SINext, OpIndex,

This comment was marked as outdated.

This comment was marked as outdated.

This comment was marked as outdated.

This comment was marked as outdated.

This comment was marked as outdated.

cast<SelectInst>(ValOp)->getOperand(OpIndex));
return true;
}

SINext = cast<SelectInst>(ValOp);
SelType = SINext->getType();
ValOp = SINext->getOperand(OpIndex);
}

This comment was marked as outdated.

This comment was marked as outdated.

return false;
};

// Try to simplify the true value of select.
if (trySimplifySeqSelect(/*OpIndex=*/1))
return true;

// Try to simplify the false value of select.
if (trySimplifySeqSelect(/*OpIndex=*/2))
return true;

return false;
}

/// Try to fold the select into one of the operands to allow further
/// optimization.
Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
Expand Down Expand Up @@ -567,6 +612,9 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
if (Instruction *R = TryFoldSelectIntoOp(SI, FalseVal, TrueVal, true))
return R;

if (simplifySeqSelectWithSameCond(SI, SQ, *this))
return &SI;

return nullptr;
}

Expand Down
61 changes: 50 additions & 11 deletions llvm/test/Transforms/InstCombine/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4527,9 +4527,8 @@ define i32 @src_select_xxory_eq0_xorxy_y(i32 %x, i32 %y) {

define i32 @sequence_select_with_same_cond_false(i1 %c1, i1 %c2){
; CHECK-LABEL: @sequence_select_with_same_cond_false(
; CHECK-NEXT: [[S1:%.*]] = select i1 [[C1:%.*]], i32 23, i32 45
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 666, i32 [[S1]]
; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1]], i32 789, i32 [[S2]]
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 666, i32 45
; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1:%.*]], i32 789, i32 [[S2]]
; CHECK-NEXT: ret i32 [[S3]]
;
%s1 = select i1 %c1, i32 23, i32 45
Expand All @@ -4540,9 +4539,8 @@ define i32 @sequence_select_with_same_cond_false(i1 %c1, i1 %c2){

define i32 @sequence_select_with_same_cond_true(i1 %c1, i1 %c2){
; CHECK-LABEL: @sequence_select_with_same_cond_true(
; CHECK-NEXT: [[S1:%.*]] = select i1 [[C1:%.*]], i32 45, i32 23
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 [[S1]], i32 666
; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1]], i32 [[S2]], i32 789
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 45, i32 666
; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1:%.*]], i32 [[S2]], i32 789
; CHECK-NEXT: ret i32 [[S3]]
;
%s1 = select i1 %c1, i32 45, i32 23
Expand All @@ -4553,9 +4551,8 @@ define i32 @sequence_select_with_same_cond_true(i1 %c1, i1 %c2){

define double @sequence_select_with_same_cond_double(double %a, i1 %c1, i1 %c2, double %r1, double %r2){
; CHECK-LABEL: @sequence_select_with_same_cond_double(
; CHECK-NEXT: [[S1:%.*]] = select i1 [[C1:%.*]], double 1.000000e+00, double 0.000000e+00
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], double [[S1]], double 2.000000e+00
; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1]], double [[S2]], double 3.000000e+00
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], double 1.000000e+00, double 2.000000e+00
; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1:%.*]], double [[S2]], double 3.000000e+00
; CHECK-NEXT: ret double [[S3]]
;
%s1 = select i1 %c1, double 1.0, double 0.0
Expand All @@ -4564,19 +4561,44 @@ define double @sequence_select_with_same_cond_double(double %a, i1 %c1, i1 %c2,
ret double %s3
}

; Confirm the FMF flag is propagated
define float @sequence_select_with_same_cond_float_and_fmf_flag1(i1 %c1, i1 %c2){
; CHECK-LABEL: @sequence_select_with_same_cond_float_and_fmf_flag1(
; CHECK-NEXT: [[S2:%.*]] = select fast i1 [[C2:%.*]], float 6.660000e+02, float 4.500000e+01
; CHECK-NEXT: [[S3:%.*]] = select fast i1 [[C1:%.*]], float 7.890000e+02, float [[S2]]
; CHECK-NEXT: ret float [[S3]]
;
%s1 = select i1 %c1, float 23.0, float 45.0
%s2 = select fast i1 %c2, float 666.0, float %s1 ; has fast flag
%s3 = select fast i1 %c1, float 789.0, float %s2
ret float %s3
}

define float @sequence_select_with_same_cond_float_and_fmf_flag2(i1 %c1, i1 %c2){
; CHECK-LABEL: @sequence_select_with_same_cond_float_and_fmf_flag2(
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], float 6.660000e+02, float 4.500000e+01
; CHECK-NEXT: [[S3:%.*]] = select fast i1 [[C1:%.*]], float 7.890000e+02, float [[S2]]
; CHECK-NEXT: ret float [[S3]]
;
%s1 = select fast i1 %c1, float 23.0, float 45.0
%s2 = select i1 %c2, float 666.0, float %s1 ; has no fast flag
%s3 = select fast i1 %c1, float 789.0, float %s2
ret float %s3
}

declare void @use32(i32)

define i32 @sequence_select_with_same_cond_extra_use(i1 %c1, i1 %c2){
; CHECK-LABEL: @sequence_select_with_same_cond_extra_use(
; CHECK-NEXT: [[S1:%.*]] = select i1 [[C1:%.*]], i32 23, i32 45
; CHECK-NEXT: call void @use32(i32 [[S1]])
; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 666, i32 [[S1]]
; CHECK-NEXT: call void @use32(i32 [[S2]])
; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1]], i32 789, i32 [[S2]]
; CHECK-NEXT: ret i32 [[S3]]
;
%s1 = select i1 %c1, i32 23, i32 45
call void @use32(i32 %s1)
%s2 = select i1 %c2, i32 666, i32 %s1
call void @use32(i32 %s2)
%s3 = select i1 %c1, i32 789, i32 %s2
ret i32 %s3
}
Expand Down Expand Up @@ -4612,3 +4634,20 @@ define i8 @test_replace_freeze_oneuse(i1 %x, i8 %y) {
%sel = select i1 %x, i8 0, i8 %shl.fr
ret i8 %sel
}

; first, %sel2 changes into select i1 %cond1, i8 %sel0, i8 3, then %sel1 is OneUse
; second, %sel1 changes into select i1 %cond1, i8 %a, i8 2
define i8 @sequence_select_with_same_cond_multi_arms(i1 %cond0, i1 %cond1, i8 %a, i8 %b) {
; CHECK-LABEL: @sequence_select_with_same_cond_multi_arms(
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[COND0:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1:%.*]], i8 [[A]], i8 2
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[COND1]], i8 [[SEL0]], i8 3
; CHECK-NEXT: [[SEL3:%.*]] = select i1 [[COND0]], i8 [[SEL1]], i8 [[SEL2]]
; CHECK-NEXT: ret i8 [[SEL3]]
;
%sel0 = select i1 %cond0, i8 %a, i8 %b

This comment was marked as outdated.

%sel1 = select i1 %cond1, i8 %sel0, i8 2; %sel1 used in multi arms
%sel2 = select i1 %cond1, i8 %sel1, i8 3
%sel3 = select i1 %cond0, i8 %sel1, i8 %sel2
ret i8 %sel3
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some tests that show flag preservation behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added sequence_select_with_same_cond_float_and_fmf_flag1 and sequence_select_with_same_cond_float_and_fmf_flag2, thanks

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to have tests with different flags (and more specific than fast). Need to see the union / intersect behavior

Loading