@@ -3240,40 +3240,53 @@ static SDValue performBitcastCombine(SDNode *N,
3240
3240
return SDValue ();
3241
3241
}
3242
3242
3243
- static SDValue performAnyTrueCombine (SDNode *N, SelectionDAG &DAG) {
3244
- // any_true (setcc <X>, 0, eq)
3245
- // => not (all_true X )
3246
-
3247
- SDLoc DL (N);
3243
+ static SDValue performAnyAllCombine (SDNode *N, SelectionDAG &DAG) {
3244
+ // any_true (setcc <X>, 0, eq) => (not (all_true X))
3245
+ // all_true (setcc <X>, 0, eq) => ( not (any_true X) )
3246
+ // any_true (setcc <X>, 0, ne) => (any_true X)
3247
+ // all_true (setcc <X>, 0, ne) => (all_true X)
3248
3248
assert (N->getOpcode () == ISD::INTRINSIC_WO_CHAIN);
3249
- if (N->getConstantOperandVal (0 ) != Intrinsic::wasm_anytrue)
3250
- return SDValue ();
3249
+ using namespace llvm ::SDPatternMatch;
3250
+ SDLoc DL (N);
3251
+ static auto SimdCombiner =
3252
+ [&](Intrinsic::WASMIntrinsics InPre, ISD::CondCode SetType,
3253
+ Intrinsic::WASMIntrinsics InPost, bool ShouldInvert) -> SDValue {
3254
+ if (N->getConstantOperandVal (0 ) != InPre)
3255
+ return SDValue ();
3251
3256
3252
- SDValue SetCC = N->getOperand (1 );
3253
- if (SetCC.getOpcode () != ISD::SETCC)
3254
- return SDValue ();
3257
+ SDValue LHS;
3258
+ if (!sd_match (N->getOperand (1 ), m_c_SetCC (m_Value (LHS), m_Zero (),
3259
+ m_SpecificCondCode (SetType))))
3260
+ return SDValue ();
3255
3261
3256
- SDValue LHS = SetCC->getOperand (0 );
3257
- SDValue RHS = SetCC->getOperand (1 );
3258
- ISD::CondCode Cond = cast<CondCodeSDNode>(SetCC->getOperand (2 ))->get ();
3259
- EVT LT = LHS.getValueType ();
3260
- unsigned NumElts = LT.getVectorNumElements ();
3261
- if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16 )
3262
- return SDValue ();
3262
+ EVT LT = LHS.getValueType ();
3263
+ unsigned NumElts = LT.getVectorNumElements ();
3264
+ if (LT.getScalarSizeInBits () > 128 / NumElts)
3265
+ return SDValue ();
3263
3266
3264
- EVT Width = MVT::getIntegerVT (128 / NumElts);
3267
+ SDValue Ret = DAG.getZExtOrTrunc (
3268
+ DAG.getNode (ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32 ,
3269
+ {DAG.getConstant (InPost, DL, MVT::i32 ), LHS}),
3270
+ DL, MVT::i1);
3271
+ if (ShouldInvert)
3272
+ Ret = DAG.getNOT (DL, Ret, MVT::i1);
3273
+ return DAG.getZExtOrTrunc (Ret, DL, N->getValueType (0 ));
3274
+ };
3265
3275
3266
- if (!isNullOrNullSplat (RHS) || Cond != ISD::SETEQ)
3267
- return SDValue ();
3276
+ if (SDValue AnyTrueEQ = SimdCombiner (Intrinsic::wasm_anytrue, ISD::SETEQ,
3277
+ Intrinsic::wasm_alltrue, true ))
3278
+ return AnyTrueEQ;
3279
+ if (SDValue AllTrueEQ = SimdCombiner (Intrinsic::wasm_alltrue, ISD::SETEQ,
3280
+ Intrinsic::wasm_anytrue, true ))
3281
+ return AllTrueEQ;
3282
+ if (SDValue AnyTrueNE = SimdCombiner (Intrinsic::wasm_anytrue, ISD::SETNE,
3283
+ Intrinsic::wasm_anytrue, false ))
3284
+ return AnyTrueNE;
3285
+ if (SDValue AllTrueNE = SimdCombiner (Intrinsic::wasm_alltrue, ISD::SETNE,
3286
+ Intrinsic::wasm_alltrue, false ))
3287
+ return AllTrueNE;
3268
3288
3269
- SDValue Ret = DAG.getZExtOrTrunc (
3270
- DAG.getNode (
3271
- ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32 ,
3272
- {DAG.getConstant (Intrinsic::wasm_alltrue, DL, MVT::i32 ),
3273
- DAG.getSExtOrTrunc (LHS, DL, LT.changeVectorElementType (Width))}),
3274
- DL, MVT::i1);
3275
- Ret = DAG.getNOT (DL, Ret, MVT::i1);
3276
- return DAG.getZExtOrTrunc (Ret, DL, N->getValueType (0 ));
3289
+ return SDValue ();
3277
3290
}
3278
3291
3279
3292
template <int MatchRHS, ISD::CondCode MatchCond, bool RequiresNegate,
@@ -3465,8 +3478,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
3465
3478
case ISD::TRUNCATE:
3466
3479
return performTruncateCombine (N, DCI);
3467
3480
case ISD::INTRINSIC_WO_CHAIN: {
3468
- if (auto AnyTrueCombine = performAnyTrueCombine (N, DCI.DAG ))
3469
- return AnyTrueCombine ;
3481
+ if (auto AnyAllCombine = performAnyAllCombine (N, DCI.DAG ))
3482
+ return AnyAllCombine ;
3470
3483
return performLowerPartialReduction (N, DCI.DAG );
3471
3484
}
3472
3485
case ISD::MUL:
0 commit comments