@@ -17759,6 +17759,83 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
17759
17759
return DAG.getZExtOrTrunc(Pop, DL, VT);
17760
17760
}
17761
17761
17762
+ static SDValue performSHLCombine(SDNode *N,
17763
+ TargetLowering::DAGCombinerInfo &DCI,
17764
+ const RISCVSubtarget &Subtarget) {
17765
+ // (shl (zext x), y) -> (vwsll x, y)
17766
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
17767
+ return V;
17768
+
17769
+ // (shl (sext x), C) -> (vwmulsu x, 1u << C)
17770
+ // (shl (zext x), C) -> (vwmulu x, 1u << C)
17771
+
17772
+ if (!DCI.isAfterLegalizeDAG())
17773
+ return SDValue();
17774
+
17775
+ SDValue LHS = N->getOperand(0);
17776
+ if (!LHS.hasOneUse())
17777
+ return SDValue();
17778
+ unsigned Opcode;
17779
+ switch (LHS.getOpcode()) {
17780
+ case ISD::SIGN_EXTEND:
17781
+ case RISCVISD::VSEXT_VL:
17782
+ Opcode = RISCVISD::VWMULSU_VL;
17783
+ break;
17784
+ case ISD::ZERO_EXTEND:
17785
+ case RISCVISD::VZEXT_VL:
17786
+ Opcode = RISCVISD::VWMULU_VL;
17787
+ break;
17788
+ default:
17789
+ return SDValue();
17790
+ }
17791
+
17792
+ SDValue RHS = N->getOperand(1);
17793
+ APInt ShAmt;
17794
+ uint64_t ShAmtInt;
17795
+ if (ISD::isConstantSplatVector(RHS.getNode(), ShAmt))
17796
+ ShAmtInt = ShAmt.getZExtValue();
17797
+ else if (RHS.getOpcode() == RISCVISD::VMV_V_X_VL &&
17798
+ RHS.getOperand(1).getOpcode() == ISD::Constant)
17799
+ ShAmtInt = RHS.getConstantOperandVal(1);
17800
+ else
17801
+ return SDValue();
17802
+
17803
+ // Better foldings:
17804
+ // (shl (sext x), 1) -> (vwadd x, x)
17805
+ // (shl (zext x), 1) -> (vwaddu x, x)
17806
+ if (ShAmtInt <= 1)
17807
+ return SDValue();
17808
+
17809
+ SDValue NarrowOp = LHS.getOperand(0);
17810
+ MVT NarrowVT = NarrowOp.getSimpleValueType();
17811
+ uint64_t NarrowBits = NarrowVT.getScalarSizeInBits();
17812
+ if (ShAmtInt >= NarrowBits)
17813
+ return SDValue();
17814
+ MVT VT = N->getSimpleValueType(0);
17815
+ if (NarrowBits * 2 != VT.getScalarSizeInBits())
17816
+ return SDValue();
17817
+
17818
+ SelectionDAG &DAG = DCI.DAG;
17819
+ SDLoc DL(N);
17820
+ SDValue Passthru, Mask, VL;
17821
+ switch (N->getOpcode()) {
17822
+ case ISD::SHL:
17823
+ Passthru = DAG.getUNDEF(VT);
17824
+ std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
17825
+ break;
17826
+ case RISCVISD::SHL_VL:
17827
+ Passthru = N->getOperand(2);
17828
+ Mask = N->getOperand(3);
17829
+ VL = N->getOperand(4);
17830
+ break;
17831
+ default:
17832
+ llvm_unreachable("Expected SHL");
17833
+ }
17834
+ return DAG.getNode(Opcode, DL, VT, NarrowOp,
17835
+ DAG.getConstant(1ULL << ShAmtInt, SDLoc(RHS), NarrowVT),
17836
+ Passthru, Mask, VL);
17837
+ }
17838
+
17762
17839
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
17763
17840
DAGCombinerInfo &DCI) const {
17764
17841
SelectionDAG &DAG = DCI.DAG;
@@ -18392,7 +18469,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
18392
18469
break;
18393
18470
}
18394
18471
case RISCVISD::SHL_VL:
18395
- if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
18472
+ if (SDValue V = performSHLCombine (N, DCI, Subtarget))
18396
18473
return V;
18397
18474
[[fallthrough]];
18398
18475
case RISCVISD::SRA_VL:
@@ -18417,7 +18494,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
18417
18494
case ISD::SRL:
18418
18495
case ISD::SHL: {
18419
18496
if (N->getOpcode() == ISD::SHL) {
18420
- if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
18497
+ if (SDValue V = performSHLCombine (N, DCI, Subtarget))
18421
18498
return V;
18422
18499
}
18423
18500
SDValue ShAmt = N->getOperand(1);
0 commit comments