[DAG] Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB) #90860

Merged
merged 1 commit into from
May 8, 2024
60 changes: 60 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2838,6 +2838,66 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
     return DAG.getNode(ISD::ADD, DL, VT, Not, N0.getOperand(0));
   }

+  // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
+  // This can help if the inner add has multiple uses.
+  APInt CM, CA;
+  if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(N1)) {
+    if (VT.getScalarSizeInBits() <= 64) {
+      if (sd_match(N0, m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
+                                      m_ConstInt(CM)))) &&
+          TLI.isLegalAddImmediate(
+              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
+        SDNodeFlags Flags;
+        // If all the inputs are nuw, the outputs can be nuw. If all the input
+        // are _also_ nsw the outputs can be too.
Contributor:
Can flags be preserved?

Collaborator Author:
We need to drop the flags for general reassociations. I think that would apply here, unless you know of a reason why it wouldn't?

Contributor:
I mean, I always just stick it in Alive and see if it works.

Collaborator Author:
This is the generic combine: https://alive2.llvm.org/ce/z/k7Yvo3. In general the flags need removing.

In specific cases the flags might be retained, but it feels to me a bit niche and difficult to specify: https://alive2.llvm.org/ce/z/aK5Az3.

Let me know what you think.

Contributor:
I think using constants in the proof is making this look more permissive than it should.

Collaborator Author:
The first link is the general proof with any values. noundef %CA/%CB/%CM are the constants. In general the flags from the add/mul need to be dropped.
The second link was a specific example where certain constants can keep nuw/nsw, but it is specific to those constants.

It does look like if all the input adds/mul are nsw/nuw then we might be able to keep some flags on the remaining instructions. I'm not sure if it's worth it considering all the other transforms in DAG, but I can try and add that to the patch.

Contributor:
I played around with it a bit and I'm not sure what the rule is; it's probably best to keep that in a separate patch if it's worth it.

Contributor:
Oh, it seems to just be an AND of the flags on all instructions.

Collaborator Author:
Oh - sorry. This is the version with all nuw: https://alive2.llvm.org/ce/z/nxaMNR
And the version with both: https://alive2.llvm.org/ce/z/65Xyyf
It didn't apply with just nsw though: https://alive2.llvm.org/ce/z/7eZnnp

I can remove that patch if you like, it would be just a case of dropping the last patch from this review.

Contributor:
Keep it - as long as some tests actually hit this
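
The flag rule discussed in this thread can be checked by brute force at a small bit width. The following is a minimal sketch and not part of the patch: it models `nuw`/`nsw` as "this operation did not wrap" on 4-bit values and asks whether the folded form can wrap when the original did not. The names `Op`, `wrapping_add`, `wrapping_mul` and `flags_survive_fold`, and the choice of 4 bits, are assumptions made for illustration; the Alive2 links above remain the authoritative proofs.

```cpp
#include <cassert>

constexpr int kBits = 4;            // small width keeps the search exhaustive
constexpr int kMod = 1 << kBits;    // number of distinct values (16)
constexpr int kSMin = -(kMod / 2);  // smallest signed value (-8)
constexpr int kSMax = kMod / 2 - 1; // largest signed value (7)

// Interpret a value in [0, kMod) as a two's-complement signed number.
int to_signed(int v) {
  assert(v >= 0 && v < kMod && "value must already be reduced modulo 2^kBits");
  return v >= kMod / 2 ? v - kMod : v;
}

// Result of one wrapping operation plus the no-wrap flags it would satisfy.
struct Op {
  int value; // wrapped result in [0, kMod)
  bool nuw;  // the unsigned result did not wrap
  bool nsw;  // the signed result did not overflow
};

Op wrapping_add(int x, int y) {
  const int u = x + y;
  const int s = to_signed(x) + to_signed(y);
  return {u % kMod, u < kMod, s >= kSMin && s <= kSMax};
}

Op wrapping_mul(int x, int y) {
  const int u = x * y;
  const int s = to_signed(x) * to_signed(y);
  return {u % kMod, u < kMod, s >= kSMin && s <= kSMax};
}

// Returns true if, whenever every operation of the original expression
// satisfies the requested flags, every operation of the folded one does too.
bool flags_survive_fold(bool want_nuw, bool want_nsw) {
  const auto ok = [=](const Op &op) {
    return (!want_nuw || op.nuw) && (!want_nsw || op.nsw);
  };
  for (int a = 0; a < kMod; ++a)
    for (int ca = 0; ca < kMod; ++ca)
      for (int cm = 0; cm < kMod; ++cm)
        for (int cb = 0; cb < kMod; ++cb) {
          // Original: add(mul(add(A, CA), CM), CB).
          const Op inner = wrapping_add(a, ca);
          const Op prod = wrapping_mul(inner.value, cm);
          const Op outer = wrapping_add(prod.value, cb);
          if (!ok(inner) || !ok(prod) || !ok(outer))
            continue; // the original is already poison; nothing to preserve
          // Folded: add(mul(A, CM), CM*CA + CB), constant folded with wrap.
          const int folded = (cm * ca + cb) % kMod;
          const Op new_prod = wrapping_mul(a, cm);
          const Op new_sum = wrapping_add(new_prod.value, folded);
          if (!ok(new_prod) || !ok(new_sum))
            return false;
        }
  return true;
}
```

Under this model `flags_survive_fold(true, false)` and `flags_survive_fold(true, true)` hold, while `flags_survive_fold(false, true)` does not. One counterexample for nsw alone is A = 5, CA = -5, CM = 2, CB = 0: the original never overflows, but the new `A * CM` is 10, which exceeds the 4-bit signed maximum of 7.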

+        if (N->getFlags().hasNoUnsignedWrap() &&
+            N0->getFlags().hasNoUnsignedWrap() &&
+            N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
+          Flags.setNoUnsignedWrap(true);
+          if (N->getFlags().hasNoSignedWrap() &&
+              N0->getFlags().hasNoSignedWrap() &&
+              N0.getOperand(0)->getFlags().hasNoSignedWrap())
+            Flags.setNoSignedWrap(true);
+        }
+        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
+                                  DAG.getConstant(CM, DL, VT), Flags);
+        return DAG.getNode(
+            ISD::ADD, DL, VT, Mul,
+            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
+      }
+      // Also look in case there is an intermediate add.
+      if (sd_match(N0, m_OneUse(m_Add(
+                           m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
+                                          m_ConstInt(CM))),
+                           m_Value(B)))) &&
+          TLI.isLegalAddImmediate(
+              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
+        SDNodeFlags Flags;
+        // If all the inputs are nuw, the outputs can be nuw. If all the input
+        // are _also_ nsw the outputs can be too.
+        SDValue OMul =
+            N0.getOperand(0) == B ? N0.getOperand(1) : N0.getOperand(0);
+        if (N->getFlags().hasNoUnsignedWrap() &&
+            N0->getFlags().hasNoUnsignedWrap() &&
+            OMul->getFlags().hasNoUnsignedWrap() &&
+            OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
+          Flags.setNoUnsignedWrap(true);
+          if (N->getFlags().hasNoSignedWrap() &&
+              N0->getFlags().hasNoSignedWrap() &&
+              OMul->getFlags().hasNoSignedWrap() &&
+              OMul.getOperand(0)->getFlags().hasNoSignedWrap())
+            Flags.setNoSignedWrap(true);
+        }
+        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
+                                  DAG.getConstant(CM, DL, VT), Flags);
+        SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N1), VT, Mul, B, Flags);
+        return DAG.getNode(
+            ISD::ADD, DL, VT, Add,
+            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
+      }
+    }
+  }
+
   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
     return Combined;
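
As a value transformation, the fold relies only on multiplication distributing over addition modulo 2^N, which holds even when intermediate results wrap. The following is a minimal sketch of that identity, assuming 32-bit unsigned arithmetic as a stand-in for the DAG's i32 operations. The helper names `before_fold` and `after_fold` are hypothetical and do not appear in the patch.

```cpp
#include <cstdint>

// Original shape: add(mul(add(A, CA), CM), CB), evaluated modulo 2^32.
constexpr std::uint32_t before_fold(std::uint32_t a, std::uint32_t ca,
                                    std::uint32_t cm, std::uint32_t cb) {
  return (a + ca) * cm + cb;
}

// Folded shape: add(mul(A, CM), CM*CA + CB), evaluated modulo 2^32.
constexpr std::uint32_t after_fold(std::uint32_t a, std::uint32_t ca,
                                   std::uint32_t cm, std::uint32_t cb) {
  return a * cm + (cm * ca + cb);
}

// Spot checks with CA = 4, CM = 324, CB = 4 (the constants used in the
// AArch64 tests), including inputs whose intermediate results wrap around.
static_assert(before_fold(0u, 4, 324, 4) == after_fold(0u, 4, 324, 4),
              "A = 0");
static_assert(before_fold(0xFFFFFFFFu, 4, 324, 4) ==
                  after_fold(0xFFFFFFFFu, 4, 324, 4),
              "A + CA wraps");
static_assert(before_fold(0x80000000u, 4, 324, 4) ==
                  after_fold(0x80000000u, 4, 324, 4),
              "A * CM wraps");
```

The checks cover values only. Whether the nuw/nsw flags may be kept is a separate question, which is why the patch recomputes `Flags` instead of copying them.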

67 changes: 31 additions & 36 deletions llvm/test/CodeGen/AArch64/addimm-mulimm.ll
@@ -166,9 +166,9 @@ define signext i32 @addmuladd_multiuse(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: mov w8, #324 // =0x144
+; CHECK-NEXT: mov w9, #1300 // =0x514
+; CHECK-NEXT: madd w8, w0, w8, w9
 ; CHECK-NEXT: add w9, w0, #4
-; CHECK-NEXT: mov w10, #4 // =0x4
-; CHECK-NEXT: madd w8, w9, w8, w10
 ; CHECK-NEXT: eor w0, w9, w8
 ; CHECK-NEXT: ret
   %tmp0 = add i32 %a, 4
@@ -198,11 +198,10 @@ define signext i32 @addmuladd_multiuse2(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: mov w8, #324 // =0x144
-; CHECK-NEXT: add w9, w0, #4
-; CHECK-NEXT: mov w11, #4 // =0x4
-; CHECK-NEXT: lsl w10, w9, #2
-; CHECK-NEXT: madd w8, w9, w8, w11
-; CHECK-NEXT: add w9, w10, #4
+; CHECK-NEXT: lsl w9, w0, #2
+; CHECK-NEXT: mov w10, #1300 // =0x514
+; CHECK-NEXT: madd w8, w0, w8, w10
+; CHECK-NEXT: add w9, w9, #20
 ; CHECK-NEXT: eor w0, w8, w9
 ; CHECK-NEXT: ret
   %tmp0 = add i32 %a, 4
@@ -233,8 +232,8 @@ define signext i32 @addaddmuladd_multiuse(i32 signext %a, i32 %b) {
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: mov w8, #324 // =0x144
 ; CHECK-NEXT: add w9, w0, #4
-; CHECK-NEXT: madd w8, w9, w8, w1
-; CHECK-NEXT: add w8, w8, #4
+; CHECK-NEXT: madd w8, w0, w8, w1
+; CHECK-NEXT: add w8, w8, #1300
 ; CHECK-NEXT: eor w0, w9, w8
 ; CHECK-NEXT: ret
   %tmp0 = add i32 %a, 4
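
Both matches in the combine are guarded by `TLI.isLegalAddImmediate`, so the fold only fires when the combined constant can still be used as an add immediate on the target; in the hunk above, `#1300` can. As a rough illustration of that kind of check, the sketch below assumes the AArch64 ADD (immediate) operand form: a 12-bit unsigned value, optionally shifted left by 12 bits. The helper name `fits_aarch64_add_imm` is hypothetical, and this simplified model is not LLVM's implementation of the hook, which may accept other values too.

```cpp
#include <cstdint>

// Simplified model (an assumption for illustration): an immediate fits if it
// is a 12-bit unsigned value, or a 12-bit value shifted left by 12 bits.
constexpr bool fits_aarch64_add_imm(std::uint64_t imm) {
  return (imm >> 12) == 0 || ((imm & 0xfff) == 0 && (imm >> 24) == 0);
}

static_assert(fits_aarch64_add_imm(1300), "4 * 324 + 4 fits in 12 bits");
static_assert(fits_aarch64_add_imm(0x1000), "4096 is 1 shifted left by 12");
static_assert(!fits_aarch64_add_imm(0x1001), "needs bits in both halves");
```

When a combined constant is rejected by the hook, the patch leaves the original `add`/`mul`/`add` chain alone.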
@@ -249,12 +248,11 @@ define signext i32 @addaddmuladd_multiuse2(i32 signext %a, i32 %b) {
 ; CHECK-LABEL: addaddmuladd_multiuse2:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: mov w8, #324 // =0x144
-; CHECK-NEXT: add w9, w0, #4
-; CHECK-NEXT: mov w10, #162 // =0xa2
-; CHECK-NEXT: madd w8, w9, w8, w1
-; CHECK-NEXT: madd w9, w9, w10, w1
-; CHECK-NEXT: add w8, w8, #4
-; CHECK-NEXT: add w9, w9, #4
+; CHECK-NEXT: mov w9, #162 // =0xa2
+; CHECK-NEXT: madd w8, w0, w8, w1
+; CHECK-NEXT: madd w9, w0, w9, w1
+; CHECK-NEXT: add w8, w8, #1300
+; CHECK-NEXT: add w9, w9, #652
 ; CHECK-NEXT: eor w0, w9, w8
 ; CHECK-NEXT: ret
   %tmp0 = add i32 %a, 4
@@ -319,17 +317,17 @@ define void @addmuladd_gep(ptr %p, i64 %a) {
 define i32 @addmuladd_gep2(ptr %p, i32 %a) {
 ; CHECK-LABEL: addmuladd_gep2:
 ; CHECK: // %bb.0:
+; CHECK-NEXT: mov w8, #3240 // =0xca8
 ; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
-; CHECK-NEXT: sxtw x8, w1
-; CHECK-NEXT: mov w9, #3240 // =0xca8
-; CHECK-NEXT: add x8, x8, #1
-; CHECK-NEXT: madd x9, x8, x9, x0
-; CHECK-NEXT: ldr w9, [x9, #20]
-; CHECK-NEXT: tbnz w9, #31, .LBB22_2
+; CHECK-NEXT: smaddl x8, w1, w8, x0
+; CHECK-NEXT: ldr w8, [x8, #3260]
+; CHECK-NEXT: tbnz w8, #31, .LBB22_2
 ; CHECK-NEXT: // %bb.1:
 ; CHECK-NEXT: mov w0, wzr
 ; CHECK-NEXT: ret
 ; CHECK-NEXT: .LBB22_2: // %then
+; CHECK-NEXT: sxtw x8, w1
+; CHECK-NEXT: add x8, x8, #1
 ; CHECK-NEXT: str x8, [x0]
 ; CHECK-NEXT: mov w0, #1 // =0x1
 ; CHECK-NEXT: ret
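
The new immediates in the updated CHECK lines are each an instance of CM*CA + CB from the fold. The following is a small sketch that recomputes them, with the (CA, CM, CB) triples read off the old instruction sequences in the hunks above. The helper names `folded_constant` and `eval_folded` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Constant operand of the folded add: CM * CA + CB.
constexpr std::int64_t folded_constant(std::int64_t ca, std::int64_t cm,
                                       std::int64_t cb) {
  return cm * ca + cb;
}

// (a + 4) * 324 + 4 becomes a * 324 + 1300.
static_assert(folded_constant(4, 324, 4) == 1300, "madd by 324");
// (a + 4) * 162 + 4 becomes a * 162 + 652.
static_assert(folded_constant(4, 162, 4) == 652, "madd by 162");
// ((a + 4) << 2) + 4 becomes (a << 2) + 20.
static_assert(folded_constant(4, 4, 4) == 20, "lsl #2");
// (a + 1) * 3240 + 20 becomes a * 3240 + 3260, which ends up as the load offset.
static_assert(folded_constant(1, 3240, 20) == 3260, "addmuladd_gep2");

// Evaluates both shapes for one small input and checks that they agree.
// Intended for small values only, so that nothing overflows int64_t.
inline std::int64_t eval_folded(std::int64_t a, std::int64_t ca,
                                std::int64_t cm, std::int64_t cb) {
  const std::int64_t original = (a + ca) * cm + cb;
  const std::int64_t folded = a * cm + folded_constant(ca, cm, cb);
  assert(original == folded);
  return folded;
}
```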
@@ -351,11 +349,10 @@ define signext i32 @addmuladd_multiuse2_nsw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nsw:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: mov w8, #324 // =0x144
-; CHECK-NEXT: add w9, w0, #4
-; CHECK-NEXT: mov w11, #4 // =0x4
-; CHECK-NEXT: lsl w10, w9, #2
-; CHECK-NEXT: madd w8, w9, w8, w11
-; CHECK-NEXT: add w9, w10, #4
+; CHECK-NEXT: lsl w9, w0, #2
+; CHECK-NEXT: mov w10, #1300 // =0x514
+; CHECK-NEXT: madd w8, w0, w8, w10
+; CHECK-NEXT: add w9, w9, #20
 ; CHECK-NEXT: eor w0, w8, w9
 ; CHECK-NEXT: ret
   %tmp0 = add nsw i32 %a, 4
@@ -371,11 +368,10 @@ define signext i32 @addmuladd_multiuse2_nuw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nuw:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: mov w8, #324 // =0x144
-; CHECK-NEXT: add w9, w0, #4
-; CHECK-NEXT: mov w11, #4 // =0x4
-; CHECK-NEXT: lsl w10, w9, #2
-; CHECK-NEXT: madd w8, w9, w8, w11
-; CHECK-NEXT: add w9, w10, #4
+; CHECK-NEXT: lsl w9, w0, #2
+; CHECK-NEXT: mov w10, #1300 // =0x514
+; CHECK-NEXT: madd w8, w0, w8, w10
+; CHECK-NEXT: add w9, w9, #20
 ; CHECK-NEXT: eor w0, w8, w9
 ; CHECK-NEXT: ret
   %tmp0 = add nuw i32 %a, 4
@@ -391,11 +387,10 @@ define signext i32 @addmuladd_multiuse2_nswnuw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nswnuw:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: mov w8, #324 // =0x144
-; CHECK-NEXT: add w9, w0, #4
-; CHECK-NEXT: mov w11, #4 // =0x4
-; CHECK-NEXT: lsl w10, w9, #2
-; CHECK-NEXT: madd w8, w9, w8, w11
-; CHECK-NEXT: add w9, w10, #4
+; CHECK-NEXT: lsl w9, w0, #2
+; CHECK-NEXT: mov w10, #1300 // =0x514
+; CHECK-NEXT: madd w8, w0, w8, w10
+; CHECK-NEXT: add w9, w9, #20
 ; CHECK-NEXT: eor w0, w8, w9
 ; CHECK-NEXT: ret
   %tmp0 = add nsw nuw i32 %a, 4