Skip to content

Commit 88cc35b

Browse files
authored
[InstCombine] Fold binop (select cond, a, b), (select cond, b, a) to binop a, b (#74953)
``` CommutativeBinOp(select(V, A, B), select(V, B, A) --> CommutativeBinOp(A, B) CommutativeIntrinsicCall(select(V, A, B), select(V, B, A), ...) --> CommutativeIntrinsicCall(A, B, ...) ``` https://alive2.llvm.org/ce/z/8CDUZ4 Closes #73904
1 parent f68435f commit 88cc35b

File tree

4 files changed

+595
-0
lines changed

4 files changed

+595
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
15361536
}
15371537

15381538
if (II->isCommutative()) {
1539+
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
1540+
return I;
1541+
15391542
if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
15401543
return NewCall;
15411544
}
@@ -4217,3 +4220,23 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
42174220
Call.setCalledFunction(FTy, NestF);
42184221
return &Call;
42194222
}
4223+
4224+
// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
4225+
Instruction *
4226+
InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
4227+
assert(II.isCommutative());
4228+
4229+
Value *A, *B, *C;
4230+
bool LHSIsSelect =
4231+
match(II.getOperand(0), m_Select(m_Value(A), m_Value(B), m_Value(C)));
4232+
bool RHSIsSymmetricalSelect = match(
4233+
II.getOperand(1), m_Select(m_Specific(A), m_Specific(C), m_Specific(B)));
4234+
4235+
if (LHSIsSelect && RHSIsSymmetricalSelect) {
4236+
replaceOperand(II, 0, B);
4237+
replaceOperand(II, 1, C);
4238+
return ⅈ
4239+
}
4240+
4241+
return nullptr;
4242+
}

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
276276
bool transformConstExprCastCall(CallBase &Call);
277277
Instruction *transformCallThroughTrampoline(CallBase &Call,
278278
IntrinsicInst &Tramp);
279+
Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);
279280

280281
Value *simplifyMaskedLoad(IntrinsicInst &II);
281282
Instruction *simplifyMaskedStore(IntrinsicInst &II);

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,14 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
11321132
};
11331133

11341134
if (LHSIsSelect && RHSIsSelect && A == D) {
1135+
// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
1136+
if (I.isCommutative() && B == F && C == E) {
1137+
Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
1138+
if (auto *BO = dyn_cast<BinaryOperator>(BI))
1139+
BO->copyIRFlags(&I);
1140+
return BI;
1141+
}
1142+
11351143
// (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
11361144
Cond = A;
11371145
True = simplifyBinOp(Opcode, B, E, FMF, Q);

0 commit comments

Comments
 (0)