Commit 0381e01
Recommit "[RISCV] Add isel optimization for (and (sra y, c2), c1) to recover regression from #101751. (#104114)"
Fixed an incorrect cast. Original message: If c1 is a shifted mask with c3 leading zeros and c4 trailing zeros. If c2 is greater than c3, we can use (srli (srai y, c2 - c3), c3 + c4) followed by a SHXADD with c4 as the X amount. Without Zba we can use (slli (srli (srai y, c2 - c3), c3 + c4), c4). Alive2: https://alive2.llvm.org/ce/z/AwhheR
1 parent fd7904a commit 0381e01
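As a quick sanity check of the non-Zba rewrite described in the message above, here is a minimal standalone C++ sketch (not part of the patch). It verifies the identity for one illustrative choice of constants: XLen = 64, c2 = 35, c3 = c4 = 3; the sample values and constants are assumptions made for this example only.

// Checks: (and (sra x, c2), c1) == (slli (srli (srai x, c2 - c3), c3 + c4), c4)
// where c1 is a shifted mask with c3 leading and c4 trailing zeros and c2 > c3.
// Assumes arithmetic right shift for signed types (guaranteed since C++20).
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C3 = 3, C4 = 3;                  // leading/trailing zero counts of C1
  const uint64_t C1 = (~0ULL >> (C3 + C4)) << C4; // 0x1FFFFFFFFFFFFFF8, a shifted mask
  const unsigned C2 = 35;                         // must satisfy C2 > C3

  const int64_t Samples[] = {0, 1, -1, 0x0012345678abcdef, INT64_MIN, INT64_MAX};
  for (int64_t X : Samples) {
    // Original form: (and (sra x, c2), c1).
    uint64_t Before = (uint64_t)(X >> C2) & C1;
    // Rewritten form: (slli (srli (srai x, c2 - c3), c3 + c4), c4).
    uint64_t After = ((uint64_t)(X >> (C2 - C3)) >> (C3 + C4)) << C4;
    assert(Before == After);
  }
  return 0;
}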

2 files changed: +116 -2 lines changed


llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 50 additions & 2 deletions
@@ -1462,8 +1462,6 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
 
     const uint64_t C1 = N1C->getZExtValue();
 
-    // Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a mask
-    // with c3 leading zeros and c2 is larger than c3.
     if (N0.getOpcode() == ISD::SRA && isa<ConstantSDNode>(N0.getOperand(1)) &&
         N0.hasOneUse()) {
       unsigned C2 = N0.getConstantOperandVal(1);
@@ -1477,6 +1475,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
                   X.getOpcode() == ISD::SHL &&
                   isa<ConstantSDNode>(X.getOperand(1)) &&
                   X.getConstantOperandVal(1) == 32;
+      // Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a
+      // mask with c3 leading zeros and c2 is larger than c3.
       if (isMask_64(C1) && !Skip) {
         unsigned Leading = XLen - llvm::bit_width(C1);
         if (C2 > Leading) {
@@ -1490,6 +1490,27 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
           return;
         }
       }
+
+      // Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
+      // leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
+      // use (slli (srli (srai y, c2 - c3), c3 + c4), c4).
+      if (isShiftedMask_64(C1) && !Skip) {
+        unsigned Leading = XLen - llvm::bit_width(C1);
+        unsigned Trailing = llvm::countr_zero(C1);
+        if (C2 > Leading && Leading > 0 && Trailing > 0) {
+          SDNode *SRAI = CurDAG->getMachineNode(
+              RISCV::SRAI, DL, VT, N0.getOperand(0),
+              CurDAG->getTargetConstant(C2 - Leading, DL, VT));
+          SDNode *SRLI = CurDAG->getMachineNode(
+              RISCV::SRLI, DL, VT, SDValue(SRAI, 0),
+              CurDAG->getTargetConstant(Leading + Trailing, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
+          return;
+        }
+      }
     }
 
     // If C1 masks off the upper bits only (but can't be formed as an
@@ -3032,6 +3053,33 @@ bool RISCVDAGToDAGISel::selectSHXADDOp(SDValue N, unsigned ShAmt,
           return true;
         }
       }
+    } else if (N0.getOpcode() == ISD::SRA && N0.hasOneUse() &&
+               isa<ConstantSDNode>(N0.getOperand(1))) {
+      uint64_t Mask = N.getConstantOperandVal(1);
+      unsigned C2 = N0.getConstantOperandVal(1);
+
+      // Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
+      // leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
+      // use (srli (srai y, c2 - c3), c3 + c4) followed by a SHXADD with c4 as
+      // the X amount.
+      if (isShiftedMask_64(Mask)) {
+        unsigned XLen = Subtarget->getXLen();
+        unsigned Leading = XLen - llvm::bit_width(Mask);
+        unsigned Trailing = llvm::countr_zero(Mask);
+        if (C2 > Leading && Leading > 0 && Trailing == ShAmt) {
+          SDLoc DL(N);
+          EVT VT = N.getValueType();
+          Val = SDValue(CurDAG->getMachineNode(
+                            RISCV::SRAI, DL, VT, N0.getOperand(0),
+                            CurDAG->getTargetConstant(C2 - Leading, DL, VT)),
+                        0);
+          Val = SDValue(CurDAG->getMachineNode(
+                            RISCV::SRLI, DL, VT, Val,
+                            CurDAG->getTargetConstant(Leading + ShAmt, DL, VT)),
+                        0);
+          return true;
+        }
+      }
     }
   } else if (bool LeftShift = N.getOpcode() == ISD::SHL;
              (LeftShift || N.getOpcode() == ISD::SRL) &&
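The Zba path works because sh3add computes (rs1 << 3) + rs2, so the final slli-plus-add pair of the generic lowering collapses into one instruction; that is also why the selectSHXADDOp change requires Trailing == ShAmt. Below is a minimal standalone sketch (not LLVM code) that models sh3add per the Zba spec and checks that the two instruction sequences from the srai_srli_sh3add test below compute the same address; Base and Y are made-up sample values.

// Models sh3add and compares the RV64I and RV64ZBA sequences from the test.
#include <cassert>
#include <cstdint>

static uint64_t sh3add(uint64_t rs1, uint64_t rs2) { return (rs1 << 3) + rs2; }

int main() {
  uint64_t Base = 0x1000;  // a0: illustrative base pointer value
  int64_t Y = -123456789;  // a1: illustrative index source

  uint64_t Idx = (uint64_t)(Y >> 32) >> 6;  // srai a1, a1, 32 ; srli a1, a1, 6

  uint64_t RV64I = Base + (Idx << 3);       // slli a1, a1, 3 ; add a0, a0, a1
  uint64_t RV64ZBA = sh3add(Idx, Base);     // sh3add a0, a1, a0

  assert(RV64I == RV64ZBA);
  return 0;
}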

llvm/test/CodeGen/RISCV/rv64zba.ll

Lines changed: 66 additions & 0 deletions
@@ -2988,3 +2988,69 @@ entry:
   %2 = and i64 %1, 34359738360
   ret i64 %2
 }
+
+define ptr @srai_srli_sh3add(ptr %0, i64 %1) nounwind {
+; RV64I-LABEL: srai_srli_sh3add:
+; RV64I:       # %bb.0: # %entry
+; RV64I-NEXT:    srai a1, a1, 32
+; RV64I-NEXT:    srli a1, a1, 6
+; RV64I-NEXT:    slli a1, a1, 3
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: srai_srli_sh3add:
+; RV64ZBA:       # %bb.0: # %entry
+; RV64ZBA-NEXT:    srai a1, a1, 32
+; RV64ZBA-NEXT:    srli a1, a1, 6
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    ret
+entry:
+  %2 = ashr i64 %1, 32
+  %3 = lshr i64 %2, 6
+  %4 = getelementptr i64, ptr %0, i64 %3
+  ret ptr %4
+}
+
+define ptr @srai_srli_slli(ptr %0, i64 %1) nounwind {
+; CHECK-LABEL: srai_srli_slli:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    srai a1, a1, 32
+; CHECK-NEXT:    srli a1, a1, 6
+; CHECK-NEXT:    slli a1, a1, 4
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    ret
+entry:
+  %2 = ashr i64 %1, 32
+  %3 = lshr i64 %2, 6
+  %4 = getelementptr i128, ptr %0, i64 %3
+  ret ptr %4
+}
+
+; Negative to make sure the peephole added for srai_srli_slli and
+; srai_srli_sh3add doesn't break this.
+define i64 @srai_andi(i64 %x) nounwind {
+; CHECK-LABEL: srai_andi:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    srai a0, a0, 8
+; CHECK-NEXT:    andi a0, a0, -8
+; CHECK-NEXT:    ret
+entry:
+  %y = ashr i64 %x, 8
+  %z = and i64 %y, -8
+  ret i64 %z
+}
+
+; Negative to make sure the peephole added for srai_srli_slli and
+; srai_srli_sh3add doesn't break this.
+define i64 @srai_lui_and(i64 %x) nounwind {
+; CHECK-LABEL: srai_lui_and:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    srai a0, a0, 8
+; CHECK-NEXT:    lui a1, 1048574
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    ret
+entry:
+  %y = ashr i64 %x, 8
+  %z = and i64 %y, -8192
+  ret i64 %z
+}
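A note on the two negative tests above: both masks (-8 and -8192) are shifted masks but have no leading zero bits, so the rewrite does not fire, presumably because of the Leading > 0 guard in the new code, and the cheaper srai+andi / srai+lui+and sequences survive. A minimal standalone sketch of that property (assumes C++20 <bit>; the interpretation of the guard is an assumption, not a statement from the patch):

// Both negative-test masks fail the 'Leading > 0' requirement of the new fold.
#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t Masks[] = {(uint64_t)-8 /* srai_andi */,
                            (uint64_t)-8192 /* srai_lui_and */};
  for (uint64_t Mask : Masks) {
    unsigned Leading = std::countl_zero(Mask);   // number of leading zero bits
    unsigned Trailing = std::countr_zero(Mask);  // number of trailing zero bits
    assert(Leading == 0 && Trailing > 0);        // Leading > 0 does not hold
  }
  return 0;
}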
