@@ -1462,8 +1462,6 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
1462
1462
1463
1463
const uint64_t C1 = N1C->getZExtValue ();
1464
1464
1465
- // Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a mask
1466
- // with c3 leading zeros and c2 is larger than c3.
1467
1465
if (N0.getOpcode () == ISD::SRA && isa<ConstantSDNode>(N0.getOperand (1 )) &&
1468
1466
N0.hasOneUse ()) {
1469
1467
unsigned C2 = N0.getConstantOperandVal (1 );
@@ -1477,6 +1475,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
1477
1475
X.getOpcode () == ISD::SHL &&
1478
1476
isa<ConstantSDNode>(X.getOperand (1 )) &&
1479
1477
X.getConstantOperandVal (1 ) == 32 ;
1478
+ // Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a
1479
+ // mask with c3 leading zeros and c2 is larger than c3.
1480
1480
if (isMask_64 (C1) && !Skip) {
1481
1481
unsigned Leading = XLen - llvm::bit_width (C1);
1482
1482
if (C2 > Leading) {
@@ -1490,6 +1490,27 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
1490
1490
return ;
1491
1491
}
1492
1492
}
1493
+
1494
+ // Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
1495
+ // leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
1496
+ // use (slli (srli (srai y, c2 - c3), c3 + c4), c4).
1497
+ if (isShiftedMask_64 (C1) && !Skip) {
1498
+ unsigned Leading = XLen - llvm::bit_width (C1);
1499
+ unsigned Trailing = llvm::countr_zero (C1);
1500
+ if (C2 > Leading && Leading > 0 && Trailing > 0 ) {
1501
+ SDNode *SRAI = CurDAG->getMachineNode (
1502
+ RISCV::SRAI, DL, VT, N0.getOperand (0 ),
1503
+ CurDAG->getTargetConstant (C2 - Leading, DL, VT));
1504
+ SDNode *SRLI = CurDAG->getMachineNode (
1505
+ RISCV::SRLI, DL, VT, SDValue (SRAI, 0 ),
1506
+ CurDAG->getTargetConstant (Leading + Trailing, DL, VT));
1507
+ SDNode *SLLI = CurDAG->getMachineNode (
1508
+ RISCV::SLLI, DL, VT, SDValue (SRLI, 0 ),
1509
+ CurDAG->getTargetConstant (Trailing, DL, VT));
1510
+ ReplaceNode (Node, SLLI);
1511
+ return ;
1512
+ }
1513
+ }
1493
1514
}
1494
1515
1495
1516
// If C1 masks off the upper bits only (but can't be formed as an
@@ -3032,6 +3053,33 @@ bool RISCVDAGToDAGISel::selectSHXADDOp(SDValue N, unsigned ShAmt,
3032
3053
return true ;
3033
3054
}
3034
3055
}
3056
+ } else if (N0.getOpcode () == ISD::SRA && N0.hasOneUse () &&
3057
+ isa<ConstantSDNode>(N0.getOperand (1 ))) {
3058
+ uint64_t Mask = N.getConstantOperandVal (1 );
3059
+ unsigned C2 = N0.getConstantOperandVal (1 );
3060
+
3061
+ // Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
3062
+ // leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
3063
+ // use (srli (srai y, c2 - c3), c3 + c4) followed by a SHXADD with c4 as
3064
+ // the X amount.
3065
+ if (isShiftedMask_64 (Mask)) {
3066
+ unsigned XLen = Subtarget->getXLen ();
3067
+ unsigned Leading = XLen - llvm::bit_width (Mask);
3068
+ unsigned Trailing = llvm::countr_zero (Mask);
3069
+ if (C2 > Leading && Leading > 0 && Trailing == ShAmt) {
3070
+ SDLoc DL (N);
3071
+ EVT VT = N.getValueType ();
3072
+ Val = SDValue (CurDAG->getMachineNode (
3073
+ RISCV::SRAI, DL, VT, N0.getOperand (0 ),
3074
+ CurDAG->getTargetConstant (C2 - Leading, DL, VT)),
3075
+ 0 );
3076
+ Val = SDValue (CurDAG->getMachineNode (
3077
+ RISCV::SRLI, DL, VT, Val,
3078
+ CurDAG->getTargetConstant (Leading + ShAmt, DL, VT)),
3079
+ 0 );
3080
+ return true ;
3081
+ }
3082
+ }
3035
3083
}
3036
3084
} else if (bool LeftShift = N.getOpcode () == ISD::SHL;
3037
3085
(LeftShift || N.getOpcode () == ISD::SRL) &&
0 commit comments