[CodeGenPrepare][RISCV] Combine (X ^ Y) and (X == Y) where appropriate #130922

Merged
merged 2 commits into llvm:main from rbuchner/130510
Apr 2, 2025

Conversation

bababuck
Contributor

Fixes #130510.

In RISCV, modify the folding of (X ^ Y == 0) -> (X == Y) to account for cases where the (X ^ Y) will be re-used.

If a constant is being used for the XOR before a branch, ensure that it is small enough to fit within a 12-bit immediate field. Otherwise, the equality check is more efficient than the check against 0. Compare the current code generation:

# %bb.0:
        lui     a1, 5
        addiw   a1, a1, 1365
        xor     a0, a0, a1
        beqz    a0, .LBB0_2
# %bb.1:
        ret
.LBB0_2:

with the code generated after this change:

# %bb.0:
        lui     a1, 5
        addiw   a1, a1, 1365
        beq     a0, a1, .LBB0_2
# %bb.1:
        xor     a0, a0, a1
        ret
.LBB0_2:

Similarly, if the XOR is between 1 and a one-bit (i1) value, we should still fold away the XOR, since that comparison can be optimized into a comparison against 0. Current code generation:

# %bb.0:
        slt     a0, a0, a1
        xor     a0, a0, 1
        beqz    a0, .LBB0_2
# %bb.1:
        ret
.LBB0_2:

After this change:

# %bb.0:
        slt     a0, a0, a1
        bnez    a0, .LBB0_2
# %bb.1:
        xor     a0, a0, 1
        ret
.LBB0_2:

One question about my code: I used a hard-coded value for the width of a RISCV ALU immediate. Is there a way to obtain this from the context? I was unable to devise one.


Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "pinging" the PR with a comment saying "Ping". The common courtesy ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Ryan Buchner (bababuck)


Full diff: https://github.com/llvm/llvm-project/pull/130922.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+2-1)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+40-1)
  • (modified) llvm/test/CodeGen/RISCV/select-constant-xor.ll (+40)
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d5fbd4c380746..2acb7cb321d07 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -8578,7 +8578,8 @@ static bool optimizeBranch(BranchInst *Branch, const TargetLowering &TLI,
     }
     if (Cmp->isEquality() &&
         (match(UI, m_Add(m_Specific(X), m_SpecificInt(-CmpC))) ||
-         match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))))) {
+         match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))) ||
+         match(UI, m_Xor(m_Specific(X), m_SpecificInt(CmpC))))) {
       IRBuilder<> Builder(Branch);
       if (UI->getParent() != Branch->getParent())
         UI->moveBefore(Branch->getIterator());
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27a4bbce1f5fc..3abc835376f7b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17194,8 +17194,47 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
     return true;
   }
 
+  // If XOR is reused and has an immediate that will fit in XORI,
+  // do not fold
+  auto Is12BitConstant = [](const SDValue &Op) -> bool {
+    if (Op.getOpcode() == ISD::Constant) {
+      const int64_t RiscvAluImmBits = 12;
+      const int64_t RiscvAluImmUpperBound = (1 << RiscvAluImmBits) - 1;
+      const int64_t RiscvAluImmLowerBound = -(1 << RiscvAluImmBits);
+      const int64_t XorCnst =
+          llvm::dyn_cast<llvm::ConstantSDNode>(Op)->getSExtValue();
+      return (XorCnst >= RiscvAluImmLowerBound) &&
+             (XorCnst <= RiscvAluImmUpperBound);
+    }
+    return false;
+  };
+  // Fold (X(i1) ^ 1) == 0 -> X != 0
+  auto SingleBitOp = [&DAG](const SDValue &VarOp,
+                            const SDValue &ConstOp) -> bool {
+    if (ConstOp.getOpcode() == ISD::Constant) {
+      const int64_t XorCnst =
+          llvm::dyn_cast<llvm::ConstantSDNode>(ConstOp)->getSExtValue();
+      const APInt Mask = APInt::getBitsSetFrom(VarOp.getValueSizeInBits(), 1);
+      return (XorCnst == 1) && DAG.MaskedValueIsZero(VarOp, Mask);
+    }
+    return false;
+  };
+  auto OnlyUsedBySelectOrBR = [](const SDValue &Op) -> bool {
+    for (const SDUse &Use : Op->uses()) {
+      const SDNode *UseNode = Use.getUser();
+      const unsigned Opcode = UseNode->getOpcode();
+      if (Opcode != RISCVISD::SELECT_CC && Opcode != RISCVISD::BR_CC) {
+        return false;
+      }
+    }
+    return true;
+  };
+
   // Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
-  if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
+  if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS) &&
+      (!Is12BitConstant(LHS.getOperand(1)) ||
+       SingleBitOp(LHS.getOperand(0), LHS.getOperand(1))) &&
+      OnlyUsedBySelectOrBR(LHS)) {
     RHS = LHS.getOperand(1);
     LHS = LHS.getOperand(0);
     return true;
diff --git a/llvm/test/CodeGen/RISCV/select-constant-xor.ll b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
index 2e26ae78e2dd8..f24e03ecd7d67 100644
--- a/llvm/test/CodeGen/RISCV/select-constant-xor.ll
+++ b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
@@ -239,3 +239,43 @@ define i32 @oneusecmp(i32 %a, i32 %b, i32 %d) {
   %x = add i32 %s, %s2
   ret i32 %x
 }
+
+define i32 @xor_branch_ret(i32 %x) {
+; RV32-LABEL: xor_branch_ret:
+; RV32:       # %bb.0: # %entry
+; RV32-NEXT:    xori a0, a0, -1365
+; RV32-NEXT:    beqz a0, .LBB11_2
+; RV32-NEXT:  # %bb.1: # %if.then
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB11_2: # %if.end
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call abort
+;
+; RV64-LABEL: xor_branch_ret:
+; RV64:       # %bb.0: # %entry
+; RV64-NEXT:    xori a0, a0, -1365
+; RV64-NEXT:    sext.w a1, a0
+; RV64-NEXT:    beqz a1, .LBB11_2
+; RV64-NEXT:  # %bb.1: # %if.then
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB11_2: # %if.end
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call abort
+entry:
+  %cmp.not = icmp eq i32 %x, -1365
+  br i1 %cmp.not, label %if.end, label %if.then
+if.then:
+  %xor = xor i32 %x, -1365
+  ret i32 %xor
+if.end:
+    tail call void @abort() #2
+  unreachable
+}
+
+declare void @abort()

return false;
};
auto OnlyUsedBySelectOrBR = [](const SDValue &Op) -> bool {
for (const SDUse &Use : Op->uses()) {
Collaborator

You can use Op->users()

const SDValue &ConstOp) -> bool {
if (ConstOp.getOpcode() == ISD::Constant) {
const int64_t XorCnst =
llvm::dyn_cast<llvm::ConstantSDNode>(ConstOp)->getSExtValue();
Collaborator

The llvm:: is unnecessary. Move the dyn_cast into the if instead of checking the opcode.

const int64_t RiscvAluImmBits = 12;
const int64_t RiscvAluImmUpperBound = (1 << RiscvAluImmBits) - 1;
const int64_t RiscvAluImmLowerBound = -(1 << RiscvAluImmBits);
const int64_t XorCnst =
Collaborator

The llvm:: is unnecessary. Move the dyn_cast into the if instead of checking the opcode.

auto Is12BitConstant = [](const SDValue &Op) -> bool {
if (Op.getOpcode() == ISD::Constant) {
const int64_t RiscvAluImmBits = 12;
const int64_t RiscvAluImmUpperBound = (1 << RiscvAluImmBits) - 1;
Collaborator

Use isInt<12> like the rest of the RISCV backend.

@bababuck bababuck force-pushed the rbuchner/130510 branch 2 times, most recently from aa0d2fb to 9b6b6fa on March 12, 2025 21:32
@bababuck
Contributor Author

I addressed the above comments, but I came across one issue in my tests:

When working with riscv64 with an i32 input (a0 holds an i32), the following code generation occurs:

entry:
  %cmp.not = icmp eq i32 %x, 2048
  br i1 %cmp.not, label %if.end, label %if.then
if.then:
  %xor = xor i32 %x, 2048
  ret i32 %xor
if.end:
    tail call void @abort() #2
  unreachable

----->

# %bb.0:
   li a1, 1
   slli a1, a1, 11
   xor a0, a0, a1
   sext.w a1, a0
   beqz a1, .LBB0_2
# %bb.1: 
        ret
.LBB0_2: 

Because the result of the sign extension is being used for the comparison, it does not match the XOR folding pattern.

One thought is that this sign extension is not needed because we are comparing against 0, so we can fold away the sign extension. However, I'm not sure if that is legal (i.e., is an LLVM i32 guaranteed to have its upper bits zeroed when compiling for a 64-bit architecture?).

Any thoughts on how to proceed? I want to make sure I am not completely off track.

@topperc
Collaborator

topperc commented Mar 12, 2025

If %x is a function argument with the signext attribute we can assume the i32 is sign extended to the full register size. Likewise the zeroext attribute allows us to assume it is zero extended. If neither attribute is present, we can't make any assumptions. The clang frontend should always put signext on i32 function arguments to match the psABI.

// do not fold
auto IsXorImmediate = [](const SDValue &Op) -> bool {
if (const auto XorCnst = dyn_cast<ConstantSDNode>(Op)) {
auto isLegalXorImmediate = [](int64_t Imm) -> bool {
Collaborator

This lambda doesn't really improve the code. Just use the isInt<12> directly.

auto OnlyUsedBySelectOrBR = [](const SDValue &Op) -> bool {
for (const SDNode *UserNode : Op->users()) {
const unsigned Opcode = UserNode->getOpcode();
if (Opcode != RISCVISD::SELECT_CC && Opcode != RISCVISD::BR_CC) {
Collaborator

Drop curly braces

@@ -17194,8 +17194,41 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
return true;
}

// If XOR is reused and has an immediate that will fit in XORI,
// do not fold
Collaborator

Add a period to make this a complete sentence.

@bababuck
Contributor Author

Thanks for the detail.

Since we cannot make this assumption in the general case, would it be reasonable to explicitly check for the case where the LHS is the result of an SEXT of an XOR?

@topperc topperc changed the title Combine (X ^ Y) and (X == Y) where appropriate [CodeGenPrepare][RISCV] Combine (X ^ Y) and (X == Y) where appropriate Mar 12, 2025
@topperc
Collaborator

topperc commented Mar 12, 2025

Since we cannot make this assumption in the general case, would it be reasonable to explicitly check for the case where the LHS is the result of an SEXT of an XOR?

Yes. That's reasonable.

@bababuck bababuck force-pushed the rbuchner/130510 branch 3 times, most recently from e89b0ae to bdf57a5 on March 13, 2025 19:25
// Fold ((sext (xor X, C)), 0, eq/ne) -> (sext(X), C, eq/ne)
if (LHS.getOpcode() == ISD::SIGN_EXTEND_INREG) {
const SDValue LHS0 = LHS.getOperand(0);
if (IsFoldableXorEq(LHS0, RHS) && isa<ConstantSDNode>(LHS0.getOperand(1))) {
Contributor Author

I added an extra constraint for the sign-extension case: the XOR must have one operand that is a constant. This keeps the change limited in scope so that it only affects the specific case that results from the other changes in this patch.

@bababuck
Contributor Author

I have made changes according to the above conversation; this is ready for review again.

Member

@lenary lenary left a comment
You don't need unwind info in the tests

@bababuck
Contributor Author

Updated for latest suggestions.

@bababuck
Contributor Author

bababuck commented Apr 1, 2025

@topperc are there any more adjustments I should make before this can be merged?

// If XOR is reused and has an immediate that will fit in XORI,
// do not fold.
auto IsXorImmediate = [](const SDValue &Op) -> bool {
if (const auto XorCnst = dyn_cast<ConstantSDNode>(Op))
Collaborator

const auto *XorCnst. LLVM Coding Standards don't want auto to hide that something is a pointer.

// Fold (X(i1) ^ 1) == 0 -> X != 0
auto SingleBitOp = [&DAG](const SDValue &VarOp,
const SDValue &ConstOp) -> bool {
if (const auto XorCnst = dyn_cast<ConstantSDNode>(ConstOp)) {
Collaborator

Same comment as above

@@ -17194,12 +17194,56 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
return true;
}

// If XOR is reused and has an immediate that will fit in XORI,
// do not fold.
auto IsXorImmediate = [](const SDValue &Op) -> bool {
Member

@mshockwave mshockwave Apr 1, 2025

Lambda functions should follow the normal function naming convention, which starts with lower case. Ditto for other occurrences here. (nvm, this convention is not written down.)

bababuck added 2 commits April 1, 2025 15:33
In RISCV, modify the folding of (X ^ Y == 0) -> (X == Y) to account for
cases where the (X ^ Y) will be re-used.

Fixes llvm#130510.
@bababuck
Contributor Author

bababuck commented Apr 1, 2025

Addressed the prior comments, rebased onto latest main.

Collaborator

@topperc topperc left a comment

LGTM

@mshockwave mshockwave merged commit fa2a6d6 into llvm:main Apr 2, 2025
5 of 7 checks passed

github-actions bot commented Apr 2, 2025

@bababuck Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@bababuck
Contributor Author

bababuck commented Apr 2, 2025

Thanks for the help everyone!

Successfully merging this pull request may close these issues.

Optimization: (x != C) comparison can utilize (x ^ C) result
6 participants