-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[CodeGenPrepare][RISCV] Combine (X ^ Y) and (X == Y) where appropriate #130922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-backend-risc-v Author: Ryan Buchner (bababuck) ChangesFixes #130510. In RISCV, modify the folding of (X ^ Y == 0) -> (X == Y) to account for cases where the (X ^ Y) will be re-used. If a constant is being used for the XOR before a branch, ensure that it is small enough to fit within a 12-bit immediate field. Otherwise, the equality check is more efficient than the check against 0, see the following:
Similarly, if the XOR is between 1 and a size one integer, we should still fold away the XOR since that comparison can be optimized as a comparison against 0.
One question about my code is that I used a hard-coded value for the width of a RISCV ALU immediate. Do you know of a way that I can gather this from the Full diff: https://github.com/llvm/llvm-project/pull/130922.diff 3 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d5fbd4c380746..2acb7cb321d07 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -8578,7 +8578,8 @@ static bool optimizeBranch(BranchInst *Branch, const TargetLowering &TLI,
}
if (Cmp->isEquality() &&
(match(UI, m_Add(m_Specific(X), m_SpecificInt(-CmpC))) ||
- match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))))) {
+ match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))) ||
+ match(UI, m_Xor(m_Specific(X), m_SpecificInt(CmpC))))) {
IRBuilder<> Builder(Branch);
if (UI->getParent() != Branch->getParent())
UI->moveBefore(Branch->getIterator());
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27a4bbce1f5fc..3abc835376f7b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17194,8 +17194,47 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
return true;
}
+ // If XOR is reused and has an immediate that will fit in XORI,
+ // do not fold
+ auto Is12BitConstant = [](const SDValue &Op) -> bool {
+ if (Op.getOpcode() == ISD::Constant) {
+ const int64_t RiscvAluImmBits = 12;
+ const int64_t RiscvAluImmUpperBound = (1 << RiscvAluImmBits) - 1;
+ const int64_t RiscvAluImmLowerBound = -(1 << RiscvAluImmBits);
+ const int64_t XorCnst =
+ llvm::dyn_cast<llvm::ConstantSDNode>(Op)->getSExtValue();
+ return (XorCnst >= RiscvAluImmLowerBound) &&
+ (XorCnst <= RiscvAluImmUpperBound);
+ }
+ return false;
+ };
+ // Fold (X(i1) ^ 1) == 0 -> X != 0
+ auto SingleBitOp = [&DAG](const SDValue &VarOp,
+ const SDValue &ConstOp) -> bool {
+ if (ConstOp.getOpcode() == ISD::Constant) {
+ const int64_t XorCnst =
+ llvm::dyn_cast<llvm::ConstantSDNode>(ConstOp)->getSExtValue();
+ const APInt Mask = APInt::getBitsSetFrom(VarOp.getValueSizeInBits(), 1);
+ return (XorCnst == 1) && DAG.MaskedValueIsZero(VarOp, Mask);
+ }
+ return false;
+ };
+ auto OnlyUsedBySelectOrBR = [](const SDValue &Op) -> bool {
+ for (const SDUse &Use : Op->uses()) {
+ const SDNode *UseNode = Use.getUser();
+ const unsigned Opcode = UseNode->getOpcode();
+ if (Opcode != RISCVISD::SELECT_CC && Opcode != RISCVISD::BR_CC) {
+ return false;
+ }
+ }
+ return true;
+ };
+
// Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
- if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
+ if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS) &&
+ (!Is12BitConstant(LHS.getOperand(1)) ||
+ SingleBitOp(LHS.getOperand(0), LHS.getOperand(1))) &&
+ OnlyUsedBySelectOrBR(LHS)) {
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
return true;
diff --git a/llvm/test/CodeGen/RISCV/select-constant-xor.ll b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
index 2e26ae78e2dd8..f24e03ecd7d67 100644
--- a/llvm/test/CodeGen/RISCV/select-constant-xor.ll
+++ b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
@@ -239,3 +239,43 @@ define i32 @oneusecmp(i32 %a, i32 %b, i32 %d) {
%x = add i32 %s, %s2
ret i32 %x
}
+
+define i32 @xor_branch_ret(i32 %x) {
+; RV32-LABEL: xor_branch_ret:
+; RV32: # %bb.0: # %entry
+; RV32-NEXT: xori a0, a0, -1365
+; RV32-NEXT: beqz a0, .LBB11_2
+; RV32-NEXT: # %bb.1: # %if.then
+; RV32-NEXT: ret
+; RV32-NEXT: .LBB11_2: # %if.end
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT: .cfi_offset ra, -4
+; RV32-NEXT: call abort
+;
+; RV64-LABEL: xor_branch_ret:
+; RV64: # %bb.0: # %entry
+; RV64-NEXT: xori a0, a0, -1365
+; RV64-NEXT: sext.w a1, a0
+; RV64-NEXT: beqz a1, .LBB11_2
+; RV64-NEXT: # %bb.1: # %if.then
+; RV64-NEXT: ret
+; RV64-NEXT: .LBB11_2: # %if.end
+; RV64-NEXT: addi sp, sp, -16
+; RV64-NEXT: .cfi_def_cfa_offset 16
+; RV64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: .cfi_offset ra, -8
+; RV64-NEXT: call abort
+entry:
+ %cmp.not = icmp eq i32 %x, -1365
+ br i1 %cmp.not, label %if.end, label %if.then
+if.then:
+ %xor = xor i32 %x, -1365
+ ret i32 %xor
+if.end:
+ tail call void @abort() #2
+ unreachable
+}
+
+declare void @abort()
|
d84eb12
to
7b32882
Compare
return false; | ||
}; | ||
auto OnlyUsedBySelectOrBR = [](const SDValue &Op) -> bool { | ||
for (const SDUse &Use : Op->uses()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use Op->users()
const SDValue &ConstOp) -> bool { | ||
if (ConstOp.getOpcode() == ISD::Constant) { | ||
const int64_t XorCnst = | ||
llvm::dyn_cast<llvm::ConstantSDNode>(ConstOp)->getSExtValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The llvm::
is unnecessary. Move the dyn_cast into the if
instead of checking the opcode.
const int64_t RiscvAluImmBits = 12; | ||
const int64_t RiscvAluImmUpperBound = (1 << RiscvAluImmBits) - 1; | ||
const int64_t RiscvAluImmLowerBound = -(1 << RiscvAluImmBits); | ||
const int64_t XorCnst = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The llvm:: is unnecessary. Move the dyn_cast into the if instead of checking the opcode.
auto Is12BitConstant = [](const SDValue &Op) -> bool { | ||
if (Op.getOpcode() == ISD::Constant) { | ||
const int64_t RiscvAluImmBits = 12; | ||
const int64_t RiscvAluImmUpperBound = (1 << RiscvAluImmBits) - 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use isInt<12>
like the rest of the RISCV backend.
aa0d2fb
to
9b6b6fa
Compare
I addressed the above comments, but I came across one issue in my tests: When working with riscv64 with i32 an i32 input (a0 is i32), the following code generation occurs:
Because the result of the sign extension is being used for the comparison, it does not matching the XOR folding pattern. One though is that this sign extension is not needed because we are comparing against 0, so we can fold away the sign-extension. However, I'm not sure if that is legal or not (i.e. is an LLVM i32 guaranteed to have the upper bits 0'ed when compiling to a 64 bit architecture). Any thoughts on how to proceed? I want to make sure I am not completely off track. |
If %x is a function argument with the |
// do not fold | ||
auto IsXorImmediate = [](const SDValue &Op) -> bool { | ||
if (const auto XorCnst = dyn_cast<ConstantSDNode>(Op)) { | ||
auto isLegalXorImmediate = [](int64_t Imm) -> bool { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This lambda doesn't really improve the code. Just use the isInt<12> directly.
auto OnlyUsedBySelectOrBR = [](const SDValue &Op) -> bool { | ||
for (const SDNode *UserNode : Op->users()) { | ||
const unsigned Opcode = UserNode->getOpcode(); | ||
if (Opcode != RISCVISD::SELECT_CC && Opcode != RISCVISD::BR_CC) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop curly braces
@@ -17194,8 +17194,41 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL, | |||
return true; | |||
} | |||
|
|||
// If XOR is reused and has an immediate that will fit in XORI, | |||
// do not fold |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a period to make this a complete sentence.
Thanks for the detail. Since we cannot make this assumption for the general case, then would it be reasonable to explicitly check for the case where the LHS is the result of an SEXT of an XOR? |
Yes. That's reasonable. |
e89b0ae
to
bdf57a5
Compare
// Fold ((sext (xor X, C)), 0, eq/ne) -> ((sext(X), C, eq/ne) | ||
if (LHS.getOpcode() == ISD::SIGN_EXTEND_INREG) { | ||
const SDValue LHS0 = LHS.getOperand(0); | ||
if (IsFoldableXorEq(LHS0, RHS) && isa<ConstantSDNode>(LHS0.getOperand(1))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added an extra constraint for the sign extension case that the XOR have one operand that is a constant. This keeps this change limited in scope so that it will only effect this specific case that results from the other changes in this patch.
I have made changes according to the above conversation, this is ready for review again. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need unwind info in the tests
bdf57a5
to
fabb86c
Compare
Updated for latest suggestions. |
@topperc are there any more adjustments I should make before this can be merged? |
// If XOR is reused and has an immediate that will fit in XORI, | ||
// do not fold. | ||
auto IsXorImmediate = [](const SDValue &Op) -> bool { | ||
if (const auto XorCnst = dyn_cast<ConstantSDNode>(Op)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto *XorCnst
. LLVM Coding Standards don't want auto
to hide that something is a pointer.
// Fold (X(i1) ^ 1) == 0 -> X != 0 | ||
auto SingleBitOp = [&DAG](const SDValue &VarOp, | ||
const SDValue &ConstOp) -> bool { | ||
if (const auto XorCnst = dyn_cast<ConstantSDNode>(ConstOp)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as above
@@ -17194,12 +17194,56 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL, | |||
return true; | |||
} | |||
|
|||
// If XOR is reused and has an immediate that will fit in XORI, | |||
// do not fold. | |||
auto IsXorImmediate = [](const SDValue &Op) -> bool { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lambda functions should following the normal function naming convention, which starts with lower case. ditto for other occurrences here. nvm, this is not written.
In RISCV, modify the folding of (X ^ Y == 0) -> (X == Y) to account for cases where the (X ^ Y) will be re-used. Fixes llvm#130510.
fabb86c
to
7433a60
Compare
Addressed the prior comments, rebased onto latest |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@bababuck Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
Thanks for the help everyone! |
Fixes #130510.
In RISCV, modify the folding of (X ^ Y == 0) -> (X == Y) to account for cases where the (X ^ Y) will be re-used.
If a constant is being used for the XOR before a branch, ensure that it is small enough to fit within a 12-bit immediate field. Otherwise, the equality check is more efficient than the check against 0, see the following:
Similarly, if the XOR is between 1 and a size one integer, we should still fold away the XOR since that comparison can be optimized as a comparison against 0.
One question about my code is that I used a hard-coded value for the width of a RISCV ALU immediate. Do you know of a way that I can gather this from the
context
, I was unable to devise one.