@@ -41346,6 +41346,156 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
41346
41346
return SDValue();
41347
41347
}
41348
41348
41349
+ // Simplify a decomposed (sext (setcc)). Assumes prior check that
41350
+ // bitwidth(sext)==bitwidth(setcc operands).
41351
+ static SDValue simplifySExtOfDecomposedSetCCImpl(
41352
+ SelectionDAG &DAG, const SDLoc &DL, ISD::CondCode CC, SDValue Op0,
41353
+ SDValue Op1, const APInt &OriginalDemandedBits,
41354
+ const APInt &OriginalDemandedElts, bool AllowNOT, unsigned Depth) {
41355
+ // Possible TODO: We could handle any power of two demanded bit + unsigned
41356
+ // comparison. There are no x86 specific comparisons that are unsigned so its
41357
+ // unneeded.
41358
+ if (!OriginalDemandedBits.isSignMask())
41359
+ return SDValue();
41360
+
41361
+ EVT OpVT = Op0.getValueType();
41362
+ // We need need nofpclass(nan inf nzero) to handle floats.
41363
+ auto hasOkayFPFlags = [](SDValue Op) {
41364
+ return Op.getOpcode() == ISD::SINT_TO_FP ||
41365
+ Op.getOpcode() == ISD::UINT_TO_FP ||
41366
+ (Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
41367
+ Op->getFlags().hasNoSignedZeros());
41368
+ };
41369
+
41370
+ if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
41371
+ return SDValue();
41372
+
41373
+ auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
41374
+ if (OpVT.isFloatingPoint()) {
41375
+ const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41376
+ return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
41377
+ }
41378
+ return V0.eq(V1);
41379
+ };
41380
+
41381
+ // Assume we canonicalized constants to Op1. That isn't always true but we
41382
+ // call this function twice with inverted CC/Operands so its fine either way.
41383
+ APInt Op1C;
41384
+ unsigned ValWidth = OriginalDemandedBits.getBitWidth();
41385
+ if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
41386
+ Op1C = APInt::getZero(ValWidth);
41387
+ } else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
41388
+ Op1C = APInt::getAllOnes(ValWidth);
41389
+ } else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
41390
+ Op1C = C->getValueAPF().bitcastToAPInt();
41391
+ } else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
41392
+ Op1C = C->getAPIntValue();
41393
+ } else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
41394
+ // isConstantSplatVector sets `Op1C`.
41395
+ } else {
41396
+ return SDValue();
41397
+ }
41398
+
41399
+ bool Not = false;
41400
+ bool Okay = false;
41401
+ assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
41402
+ "Invalid constant operand");
41403
+
41404
+ switch (CC) {
41405
+ case ISD::SETGE:
41406
+ case ISD::SETOGE:
41407
+ Not = true;
41408
+ [[fallthrough]];
41409
+ case ISD::SETLT:
41410
+ case ISD::SETOLT:
41411
+ // signbit(sext(x s< 0)) == signbit(x)
41412
+ // signbit(sext(x s>= 0)) == signbit(~x)
41413
+ Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
41414
+ // For float ops we need to ensure Op0 is de-norm. Otherwise DAZ can break
41415
+ // this fold.
41416
+ // NB: We only need de-norm check here, for the rest of the constants any
41417
+ // relationship with a de-norm value and zero will be identical.
41418
+ if (Okay && OpVT.isFloatingPoint()) {
41419
+ // Values from integers are always normal.
41420
+ if (Op0.getOpcode() == ISD::SINT_TO_FP ||
41421
+ Op0.getOpcode() == ISD::UINT_TO_FP)
41422
+ break;
41423
+
41424
+ // See if we can prove normal with known bits.
41425
+ KnownBits Op0Known =
41426
+ DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth);
41427
+ // Negative/positive doesn't matter.
41428
+ Op0Known.One.clearSignBit();
41429
+ Op0Known.Zero.clearSignBit();
41430
+
41431
+ // Get min normal value.
41432
+ const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41433
+ KnownBits MinNormal = KnownBits::makeConstant(
41434
+ APFloat::getSmallestNormalized(Sem).bitcastToAPInt());
41435
+ // Are we above de-norm range?
41436
+ std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal);
41437
+ Okay = Op0Normal.value_or(false);
41438
+ }
41439
+ break;
41440
+ case ISD::SETGT:
41441
+ case ISD::SETOGT:
41442
+ Not = true;
41443
+ [[fallthrough]];
41444
+ case ISD::SETLE:
41445
+ case ISD::SETOLE:
41446
+ // signbit(sext(x s<= -1)) == signbit(x)
41447
+ // signbit(sext(x s> -1)) == signbit(~x)
41448
+ Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
41449
+ break;
41450
+ case ISD::SETULT:
41451
+ Not = true;
41452
+ [[fallthrough]];
41453
+ case ISD::SETUGE:
41454
+ // signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
41455
+ // signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
41456
+ Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits);
41457
+ break;
41458
+ case ISD::SETULE:
41459
+ Not = true;
41460
+ [[fallthrough]];
41461
+ case ISD::SETUGT:
41462
+ // signbit(sext(x u> SIGNED_MAX)) == signbit(x)
41463
+ // signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
41464
+ Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1);
41465
+ break;
41466
+ default:
41467
+ break;
41468
+ }
41469
+
41470
+ Okay &= Not ? AllowNOT : true;
41471
+ if (!Okay)
41472
+ return SDValue();
41473
+
41474
+ if (!Not)
41475
+ return Op0;
41476
+
41477
+ if (!OpVT.isFloatingPoint())
41478
+ return DAG.getNOT(DL, Op0, OpVT);
41479
+
41480
+ // Possible TODO: We could use `fneg` to do not.
41481
+ return SDValue();
41482
+ }
41483
+
41484
+ static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, const SDLoc &DL,
41485
+ ISD::CondCode CC, SDValue Op0,
41486
+ SDValue Op1,
41487
+ const APInt &OriginalDemandedBits,
41488
+ const APInt &OriginalDemandedElts,
41489
+ bool AllowNOT, unsigned Depth) {
41490
+ if (SDValue R = simplifySExtOfDecomposedSetCCImpl(
41491
+ DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts,
41492
+ AllowNOT, Depth))
41493
+ return R;
41494
+ return simplifySExtOfDecomposedSetCCImpl(
41495
+ DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits,
41496
+ OriginalDemandedElts, AllowNOT, Depth);
41497
+ }
41498
+
41349
41499
// Simplify variable target shuffle masks based on the demanded elements.
41350
41500
// TODO: Handle DemandedBits in mask indices as well?
41351
41501
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
@@ -42525,13 +42675,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
42525
42675
}
42526
42676
break;
42527
42677
}
42528
- case X86ISD::PCMPGT:
42529
- // icmp sgt(0, R) == ashr(R, BitWidth-1).
42530
- // iff we only need the sign bit then we can use R directly.
42531
- if (OriginalDemandedBits.isSignMask() &&
42532
- ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42533
- return TLO.CombineTo(Op, Op.getOperand(1));
42678
+ case X86ISD::PCMPGT: {
42679
+ SDLoc DL(Op);
42680
+ if (SDValue R = simplifySExtOfDecomposedSetCC(
42681
+ TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42682
+ OriginalDemandedBits, OriginalDemandedElts,
42683
+ /*AllowNOT*/ true, Depth))
42684
+ return TLO.CombineTo(Op, R);
42685
+ break;
42686
+ }
42687
+ case X86ISD::CMPP: {
42688
+ SDLoc DL(Op);
42689
+ ISD::CondCode CC = X86::getCondForCMPPImm(
42690
+ cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42691
+ if (SDValue R = simplifySExtOfDecomposedSetCC(
42692
+ TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
42693
+ OriginalDemandedBits, OriginalDemandedElts,
42694
+ !(TLO.LegalOperations() && TLO.LegalTypes()), Depth))
42695
+ return TLO.CombineTo(Op, R);
42534
42696
break;
42697
+ }
42535
42698
case X86ISD::MOVMSK: {
42536
42699
SDValue Src = Op.getOperand(0);
42537
42700
MVT SrcVT = Src.getSimpleValueType();
@@ -42715,13 +42878,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
42715
42878
if (DemandedBits.isSignMask())
42716
42879
return Op.getOperand(0);
42717
42880
break;
42718
- case X86ISD::PCMPGT:
42719
- // icmp sgt(0, R) == ashr(R, BitWidth-1).
42720
- // iff we only need the sign bit then we can use R directly.
42721
- if (DemandedBits.isSignMask() &&
42722
- ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42723
- return Op.getOperand(1);
42881
+ case X86ISD::PCMPGT: {
42882
+ SDLoc DL(Op);
42883
+ if (SDValue R = simplifySExtOfDecomposedSetCC(
42884
+ DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42885
+ DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth))
42886
+ return R;
42887
+ break;
42888
+ }
42889
+ case X86ISD::CMPP: {
42890
+ SDLoc DL(Op);
42891
+ ISD::CondCode CC = X86::getCondForCMPPImm(
42892
+ cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42893
+ if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0),
42894
+ Op.getOperand(1),
42895
+ DemandedBits, DemandedElts,
42896
+ /*AllowNOT*/ false, Depth))
42897
+ return R;
42724
42898
break;
42899
+ }
42725
42900
case X86ISD::BLENDV: {
42726
42901
// BLENDV: Cond (MSB) ? LHS : RHS
42727
42902
SDValue Cond = Op.getOperand(0);
@@ -48397,7 +48572,7 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG,
48397
48572
48398
48573
// We do not split for SSE at all, but we need to split vectors for AVX1 and
48399
48574
// AVX2.
48400
- if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
48575
+ if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
48401
48576
TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) {
48402
48577
SDValue LoX, HiX;
48403
48578
std::tie(LoX, HiX) = splitVector(X, DAG, DL);
0 commit comments