@@ -16185,24 +16185,24 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
16185
16185
16186
16186
// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
16187
16187
// value for the truncated type.
16188
- static SDValue combineTruncToVnclipu (SDNode *N, SelectionDAG &DAG,
16189
- const RISCVSubtarget &Subtarget) {
16188
+ static SDValue combineTruncToVnclip (SDNode *N, SelectionDAG &DAG,
16189
+ const RISCVSubtarget &Subtarget) {
16190
16190
assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
16191
16191
16192
16192
MVT VT = N->getSimpleValueType(0);
16193
16193
16194
16194
SDValue Mask = N->getOperand(1);
16195
16195
SDValue VL = N->getOperand(2);
16196
16196
16197
- SDValue Src = N->getOperand(0);
16197
+ auto MatchMinMax = [&VL, &Mask](SDValue V, unsigned Opc, unsigned OpcVL,
16198
+ APInt &SplatVal) {
16199
+ if (V.getOpcode() != Opc &&
16200
+ !(V.getOpcode() == OpcVL && V.getOperand(2).isUndef() &&
16201
+ V.getOperand(3) == Mask && V.getOperand(4) == VL))
16202
+ return SDValue();
16198
16203
16199
- // Src must be a UMIN or UMIN_VL.
16200
- if (Src.getOpcode() != ISD::UMIN &&
16201
- !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
16202
- Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
16203
- return SDValue();
16204
+ SDValue Op = V.getOperand(1);
16204
16205
16205
- auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
16206
16206
// Peek through conversion between fixed and scalable vectors.
16207
16207
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
16208
16208
isNullConstant(Op.getOperand(2)) &&
@@ -16213,32 +16213,45 @@ static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
16213
16213
Op = Op.getOperand(1).getOperand(0);
16214
16214
16215
16215
if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
16216
- return true ;
16216
+ return V.getOperand(0) ;
16217
16217
16218
16218
if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
16219
16219
Op.getOperand(2) == VL) {
16220
16220
if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
16221
16221
SplatVal =
16222
16222
Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
16223
- return true ;
16223
+ return V.getOperand(0) ;
16224
16224
}
16225
16225
}
16226
16226
16227
- return false ;
16227
+ return SDValue() ;
16228
16228
};
16229
16229
16230
- APInt C;
16231
- if (!IsSplat(Src.getOperand(1), C))
16232
- return SDValue();
16230
+ auto DetectUSatPattern = [&](SDValue V) {
16231
+ // Src must be a UMIN or UMIN_VL.
16232
+ APInt C;
16233
+ SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, C);
16234
+ if (!UMin)
16235
+ return SDValue();
16236
+
16237
+ if (!C.isMask(VT.getScalarSizeInBits()))
16238
+ return SDValue();
16233
16239
16234
- if (!C.isMask(VT.getScalarSizeInBits()))
16240
+ return UMin;
16241
+ };
16242
+
16243
+ SDValue Val;
16244
+ unsigned ClipOpc;
16245
+ if ((Val = DetectUSatPattern(N->getOperand(0))))
16246
+ ClipOpc = RISCVISD::VNCLIPU_VL;
16247
+ else
16235
16248
return SDValue();
16236
16249
16237
16250
SDLoc DL(N);
16238
16251
// Rounding mode here is arbitrary since we aren't shifting out any bits.
16239
16252
return DAG.getNode(
16240
- RISCVISD::VNCLIPU_VL , DL, VT,
16241
- {Src.getOperand(0) , DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
16253
+ ClipOpc , DL, VT,
16254
+ {Val , DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
16242
16255
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
16243
16256
VL});
16244
16257
}
@@ -16462,7 +16475,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
16462
16475
case RISCVISD::TRUNCATE_VECTOR_VL:
16463
16476
if (SDValue V = combineTruncOfSraSext(N, DAG))
16464
16477
return V;
16465
- return combineTruncToVnclipu (N, DAG, Subtarget);
16478
+ return combineTruncToVnclip (N, DAG, Subtarget);
16466
16479
case ISD::TRUNCATE:
16467
16480
return performTRUNCATECombine(N, DAG, Subtarget);
16468
16481
case ISD::SELECT:
0 commit comments