
[RISCV] Add isel optimization for (and (sra y, c2), c1) to recover regression from #101751. #104114


Merged: 8 commits, Aug 20, 2024
52 changes: 50 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1451,8 +1451,6 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {

const uint64_t C1 = N1C->getZExtValue();

-// Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a mask
-// with c3 leading zeros and c2 is larger than c3.
if (N0.getOpcode() == ISD::SRA && isa<ConstantSDNode>(N0.getOperand(1)) &&
N0.hasOneUse()) {
unsigned C2 = N0.getConstantOperandVal(1);
@@ -1466,6 +1464,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
X.getOpcode() == ISD::SHL &&
isa<ConstantSDNode>(X.getOperand(1)) &&
X.getConstantOperandVal(1) == 32;
+// Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a
+// mask with c3 leading zeros and c2 is larger than c3.
if (isMask_64(C1) && !Skip) {
unsigned Leading = XLen - llvm::bit_width(C1);
if (C2 > Leading) {
@@ -1479,6 +1479,27 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
return;
}
}

+// Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
+// leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
+// use (slli (srli (srai y, c2 - c3), c3 + c4), c4).
+if (isShiftedMask_64(C1) && !Skip) {
+unsigned Leading = XLen - llvm::bit_width(C1);
+unsigned Trailing = llvm::countr_zero(C1);
+if (C2 > Leading && Leading > 0 && Trailing > 0) {
+SDNode *SRAI = CurDAG->getMachineNode(
+RISCV::SRAI, DL, VT, N0.getOperand(0),
+CurDAG->getTargetConstant(C2 - Leading, DL, VT));
+SDNode *SRLI = CurDAG->getMachineNode(
+RISCV::SRLI, DL, VT, SDValue(SRAI, 0),
+CurDAG->getTargetConstant(Leading + Trailing, DL, VT));
+SDNode *SLLI = CurDAG->getMachineNode(
+RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
+CurDAG->getTargetConstant(Trailing, DL, VT));
+ReplaceNode(Node, SLLI);
+return;
+}
+}
}

// If C1 masks off the upper bits only (but can't be formed as an
@@ -3019,6 +3040,33 @@ bool RISCVDAGToDAGISel::selectSHXADDOp(SDValue N, unsigned ShAmt,
return true;
}
}
+} else if (N0.getOpcode() == ISD::SRA && N0.hasOneUse() &&
+isa<ConstantSDNode>(N.getOperand(1))) {
[Inline review comment from @dtcxzyw (Member), Aug 23, 2024, on the check
above: this should be isa<ConstantSDNode>(N0.getOperand(1)), since the code
below reads N0.getConstantOperandVal(1).]
+uint64_t Mask = N.getConstantOperandVal(1);
+unsigned C2 = N0.getConstantOperandVal(1);
+
+// Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
+// leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
+// use (srli (srai y, c2 - c3), c3 + c4) followed by a SHXADD with c4 as
+// the X amount.
+if (isShiftedMask_64(Mask)) {
+unsigned XLen = Subtarget->getXLen();
+unsigned Leading = XLen - llvm::bit_width(Mask);
+unsigned Trailing = llvm::countr_zero(Mask);
+if (C2 > Leading && Leading > 0 && Trailing == ShAmt) {
+SDLoc DL(N);
+EVT VT = N.getValueType();
+Val = SDValue(CurDAG->getMachineNode(
+RISCV::SRAI, DL, VT, N0.getOperand(0),
+CurDAG->getTargetConstant(C2 - Leading, DL, VT)),
+0);
+Val = SDValue(CurDAG->getMachineNode(
+RISCV::SRLI, DL, VT, Val,
+CurDAG->getTargetConstant(Leading + ShAmt, DL, VT)),
+0);
+return true;
+}
+}
}
} else if (bool LeftShift = N.getOpcode() == ISD::SHL;
(LeftShift || N.getOpcode() == ISD::SRL) &&
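
Both rewrites added to Select() above are plain bit-manipulation identities, so they can be sanity-checked outside the compiler. The sketch below is not part of the patch: the constants are illustrative values chosen to satisfy the stated preconditions, and arithmetic right shift on int64_t is assumed (guaranteed since C++20).

// Standalone sanity check for the two (and (sra x, c2), c1) rewrites above.
#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  for (int64_t x : {int64_t(0x123456789abcdef0), int64_t(-1), int64_t(-42),
                    int64_t(0x7fffffffffffffff)}) {
    // Case 1: c1 = 0x00000000ffffffff is a mask with c3 = 32 leading zeros
    // and c2 = 40 > c3, so
    //   (and (sra x, 40), c1) == (srli (srai x, 40 - 32), 32).
    assert(((uint64_t)(x >> 40) & 0x00000000ffffffffULL) ==
           ((uint64_t)(x >> 8) >> 32));
    // Case 2: c1 = 0x0000000ffffffff0 is a shifted mask with c3 = 28 leading
    // zeros and c4 = 4 trailing zeros, and c2 = 32 > c3, so
    //   (and (sra x, 32), c1) == (slli (srli (srai x, 32 - 28), 28 + 4), 4).
    assert(((uint64_t)(x >> 32) & 0x0000000ffffffff0ULL) ==
           (((uint64_t)(x >> 4) >> 32) << 4));
  }
  return 0;
}

Note that the shifted-mask form ends in a plain slli; the selectSHXADDOp change above exists so that, under Zba, that final shift can instead be absorbed into a shNadd address scale when it matches the shift amount.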
66 changes: 66 additions & 0 deletions llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -2988,3 +2988,69 @@ entry:
%2 = and i64 %1, 34359738360
ret i64 %2
}

+define ptr @srai_srli_sh3add(ptr %0, i64 %1) nounwind {
+; RV64I-LABEL: srai_srli_sh3add:
+; RV64I: # %bb.0: # %entry
+; RV64I-NEXT: srai a1, a1, 32
+; RV64I-NEXT: srli a1, a1, 6
+; RV64I-NEXT: slli a1, a1, 3
+; RV64I-NEXT: add a0, a0, a1
+; RV64I-NEXT: ret
+;
+; RV64ZBA-LABEL: srai_srli_sh3add:
+; RV64ZBA: # %bb.0: # %entry
+; RV64ZBA-NEXT: srai a1, a1, 32
+; RV64ZBA-NEXT: srli a1, a1, 6
+; RV64ZBA-NEXT: sh3add a0, a1, a0
+; RV64ZBA-NEXT: ret
+entry:
+%2 = ashr i64 %1, 32
+%3 = lshr i64 %2, 6
+%4 = getelementptr i64, ptr %0, i64 %3
+ret ptr %4
+}
+
+define ptr @srai_srli_slli(ptr %0, i64 %1) nounwind {
+; CHECK-LABEL: srai_srli_slli:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: srai a1, a1, 32
+; CHECK-NEXT: srli a1, a1, 6
+; CHECK-NEXT: slli a1, a1, 4
+; CHECK-NEXT: add a0, a0, a1
+; CHECK-NEXT: ret
+entry:
+%2 = ashr i64 %1, 32
+%3 = lshr i64 %2, 6
+%4 = getelementptr i128, ptr %0, i64 %3
+ret ptr %4
+}
+
+; Negative to make sure the peephole added for srai_srli_slli and
+; srai_srli_sh3add doesn't break this.
+define i64 @srai_andi(i64 %x) nounwind {
+; CHECK-LABEL: srai_andi:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: srai a0, a0, 8
+; CHECK-NEXT: andi a0, a0, -8
+; CHECK-NEXT: ret
+entry:
+%y = ashr i64 %x, 8
+%z = and i64 %y, -8
+ret i64 %z
+}
+
+; Negative to make sure the peephole added for srai_srli_slli and
+; srai_srli_sh3add doesn't break this.
+define i64 @srai_lui_and(i64 %x) nounwind {
+; CHECK-LABEL: srai_lui_and:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: srai a0, a0, 8
+; CHECK-NEXT: lui a1, 1048574
+; CHECK-NEXT: and a0, a0, a1
+; CHECK-NEXT: ret
+entry:
+%y = ashr i64 %x, 8
+%z = and i64 %y, -8192
+ret i64 %z
+}
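
The srai_srli_sh3add test above exercises the selectSHXADDOp path: the mask's trailing zero count matches sh3add's shift amount (3), so the final shift-left is folded into the address scale rather than emitted as an slli. A standalone check in the same spirit (not part of the patch; sh3add is modeled directly, and the constants are illustrative values satisfying the preconditions):

// Standalone check for the SHXADD variant of the rewrite.
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Models Zba's sh3add: rd = rs2 + (rs1 << 3).
static uint64_t sh3add(uint64_t rs1, uint64_t rs2) { return rs2 + (rs1 << 3); }

int main() {
  // c1 = 0x0000000ffffffff8: shifted mask with c3 = 28 leading zeros and
  // c4 = 3 trailing zeros (matching sh3add's shift amount); c2 = 32 > c3.
  const uint64_t c1 = 0x0000000ffffffff8ULL;
  const uint64_t base = 0x10000000;
  for (int64_t y : {int64_t(0x123456789abcdef0), int64_t(-1), int64_t(-42)}) {
    // Reference: add base, (and (sra y, c2), c1)
    uint64_t ref = base + ((uint64_t)(y >> 32) & c1);
    // Selected form: srai t, y, c2 - c3; srli t, t, c3 + c4; sh3add t, base
    uint64_t t = (uint64_t)(y >> (32 - 28)) >> (28 + 3);
    assert(sh3add(t, base) == ref);
  }
  return 0;
}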