@@ -23299,6 +23299,126 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
23299
23299
}
23300
23300
}
23301
23301
23302
+ // We get bad codegen for v8i32 compares on avx targets (without avx2) so if
23303
+ // possible convert to a v8f32 compare.
23304
+ if (VTOp0 == MVT::v8i32 && Subtarget.hasAVX() && !Subtarget.hasAVX2()) {
23305
+ std::optional<KnownBits> KnownOps[2];
23306
+ // Check if an op is known to be in a certain range.
23307
+ auto OpInRange = [&DAG, Op, &KnownOps](unsigned OpNo, bool CmpLT,
23308
+ const APInt Bound) {
23309
+ if (!KnownOps[OpNo].has_value())
23310
+ KnownOps[OpNo] = DAG.computeKnownBits(Op.getOperand(OpNo));
23311
+
23312
+ if (KnownOps[OpNo]->isUnknown())
23313
+ return false;
23314
+
23315
+ std::optional<bool> Res;
23316
+ if (CmpLT)
23317
+ Res = KnownBits::ult(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
23318
+ else
23319
+ Res = KnownBits::ugt(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
23320
+ return Res.has_value() && *Res;
23321
+ };
23322
+
23323
+ bool OkayCvt = false;
23324
+ bool OkayBitcast = false;
23325
+
23326
+ const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(MVT::f32);
23327
+
23328
+ // For cvt up to 1 << (Significand Precision), (1 << 24 for ieee float)
23329
+ const APInt MaxConvertableCvt =
23330
+ APInt(32, (1U << APFloat::semanticsPrecision(Sem)));
23331
+ // For bitcast up to (and including) first inf representation (0x7f800000 +
23332
+ // 1 for ieee float)
23333
+ const APInt MaxConvertableBitcast =
23334
+ APFloat::getInf(Sem).bitcastToAPInt() + 1;
23335
+
23336
+ assert(
23337
+ MaxConvertableBitcast.getBitWidth() == 32 &&
23338
+ MaxConvertableCvt == (1U << 24) &&
23339
+ MaxConvertableBitcast == 0x7f800001 &&
23340
+ "This transform has only been verified to IEEE Single Precision Float");
23341
+
23342
+ // For bitcast we need both lhs/op1 u< MaxConvertableBitcast
23343
+ // NB: It might be worth it to enable to bitcast version for unsigned avx2
23344
+ // comparisons as they typically require multiple instructions to lower
23345
+ // (they don't fit `vpcmpeq`/`vpcmpgt` well).
23346
+ if (OpInRange(1, /*CmpLT*/ true, MaxConvertableBitcast) &&
23347
+ OpInRange(0, /*CmpLT*/ true, MaxConvertableBitcast)) {
23348
+ OkayBitcast = true;
23349
+ }
23350
+ // We want to convert icmp -> fcmp using `sitofp` iff one of the converts
23351
+ // will be constant folded.
23352
+ else if ((DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op1)) ||
23353
+ DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op0)))) {
23354
+ if (isUnsignedIntSetCC(Cond)) {
23355
+ // For cvt + unsigned compare we need both lhs/rhs >= 0 and either lhs
23356
+ // or rhs < MaxConvertableCvt
23357
+
23358
+ if (OpInRange(1, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
23359
+ OpInRange(0, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
23360
+ (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
23361
+ OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt)))
23362
+ OkayCvt = true;
23363
+ } else {
23364
+ // For cvt + signed compare we need abs(lhs) or abs(rhs) <
23365
+ // MaxConvertableCvt
23366
+ if (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
23367
+ OpInRange(1, /*CmpLT*/ false, -MaxConvertableCvt) ||
23368
+ OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt) ||
23369
+ OpInRange(0, /*CmpLT*/ false, -MaxConvertableCvt))
23370
+ OkayCvt = true;
23371
+ }
23372
+ }
23373
+
23374
+ if (OkayBitcast || OkayCvt) {
23375
+ switch (Cond) {
23376
+ default:
23377
+ llvm_unreachable("Unexpected SETCC condition");
23378
+ // Get the new FP condition. Note for the unsigned conditions we have
23379
+ // verified its okay to convert to the signed version.
23380
+ case ISD::SETULT:
23381
+ case ISD::SETLT:
23382
+ Cond = ISD::SETOLT;
23383
+ break;
23384
+ case ISD::SETUGT:
23385
+ case ISD::SETGT:
23386
+ Cond = ISD::SETOGT;
23387
+ break;
23388
+ case ISD::SETULE:
23389
+ case ISD::SETLE:
23390
+ Cond = ISD::SETOLE;
23391
+ break;
23392
+ case ISD::SETUGE:
23393
+ case ISD::SETGE:
23394
+ Cond = ISD::SETOGE;
23395
+ break;
23396
+ case ISD::SETEQ:
23397
+ Cond = ISD::SETOEQ;
23398
+ break;
23399
+ case ISD::SETNE:
23400
+ Cond = ISD::SETONE;
23401
+ break;
23402
+ }
23403
+
23404
+ MVT FpVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
23405
+ SDNodeFlags Flags;
23406
+ Flags.setNoNaNs(true);
23407
+ Flags.setNoInfs(true);
23408
+ Flags.setNoSignedZeros(true);
23409
+ if (OkayBitcast) {
23410
+ Op0 = DAG.getBitcast(FpVT, Op0);
23411
+ Op1 = DAG.getBitcast(FpVT, Op1);
23412
+ } else {
23413
+ Op0 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op0);
23414
+ Op1 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op1);
23415
+ }
23416
+ Op0->setFlags(Flags);
23417
+ Op1->setFlags(Flags);
23418
+ return DAG.getSetCC(dl, VT, Op0, Op1, Cond);
23419
+ }
23420
+ }
23421
+
23302
23422
// Break 256-bit integer vector compare into smaller ones.
23303
23423
if (VT.is256BitVector() && !Subtarget.hasInt256())
23304
23424
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
@@ -41037,6 +41157,126 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
41037
41157
return SDValue();
41038
41158
}
41039
41159
41160
+ // Simplify a decomposed (sext (setcc)). Assumes prior check that
41161
+ // bitwidth(sext)==bitwidth(setcc operands).
41162
+ static SDValue simplifySExtOfDecomposedSetCCImpl(
41163
+ SelectionDAG &DAG, SDLoc &DL, ISD::CondCode CC, SDValue Op0, SDValue Op1,
41164
+ const APInt &OriginalDemandedBits, bool AllowNOT) {
41165
+ // Possible TODO: We could handle any power of two demanded bit + unsigned
41166
+ // comparison. There are no x86 specific comparisons that are unsigned so its
41167
+ // unneeded.
41168
+ if (!OriginalDemandedBits.isSignMask())
41169
+ return SDValue();
41170
+
41171
+ EVT OpVT = Op0.getValueType();
41172
+ // We need need nofpclass(nan inf nzero) to handle floats.
41173
+ auto hasOkayFPFlags = [](SDValue Op) {
41174
+ return Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
41175
+ Op->getFlags().hasNoSignedZeros();
41176
+ };
41177
+
41178
+ if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
41179
+ return SDValue();
41180
+
41181
+ auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
41182
+ if (OpVT.isFloatingPoint()) {
41183
+ const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41184
+ return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
41185
+ }
41186
+ return V0.eq(V1);
41187
+ };
41188
+
41189
+ // Assume we canonicalized constants to Op1. That isn't always true but we
41190
+ // call this function twice with inverted CC/Operands so its fine either way.
41191
+ APInt Op1C;
41192
+ unsigned ValWidth = OriginalDemandedBits.getBitWidth();
41193
+ if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
41194
+ Op1C = APInt::getZero(ValWidth);
41195
+ } else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
41196
+ Op1C = APInt::getAllOnes(ValWidth);
41197
+ } else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
41198
+ Op1C = C->getValueAPF().bitcastToAPInt();
41199
+ } else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
41200
+ Op1C = C->getAPIntValue();
41201
+ } else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
41202
+ // Pass
41203
+ } else {
41204
+ return SDValue();
41205
+ }
41206
+
41207
+ bool Not = false;
41208
+ bool Okay = false;
41209
+ assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
41210
+ "Invalid constant operand");
41211
+
41212
+ switch (CC) {
41213
+ case ISD::SETGE:
41214
+ case ISD::SETOGE:
41215
+ Not = true;
41216
+ [[fallthrough]];
41217
+ case ISD::SETLT:
41218
+ case ISD::SETOLT:
41219
+ // signbit(sext(x s< 0)) == signbit(x)
41220
+ // signbit(sext(x s>= 0)) == signbit(~x)
41221
+ Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
41222
+ break;
41223
+ case ISD::SETGT:
41224
+ case ISD::SETOGT:
41225
+ Not = true;
41226
+ [[fallthrough]];
41227
+ case ISD::SETLE:
41228
+ case ISD::SETOLE:
41229
+ // signbit(sext(x s<= -1)) == signbit(x)
41230
+ // signbit(sext(x s> -1)) == signbit(~x)
41231
+ Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
41232
+ break;
41233
+ case ISD::SETULT:
41234
+ Not = true;
41235
+ [[fallthrough]];
41236
+ case ISD::SETUGE:
41237
+ // signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
41238
+ // signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
41239
+ Okay = ValsEq(Op1C, OriginalDemandedBits);
41240
+ break;
41241
+ case ISD::SETULE:
41242
+ Not = true;
41243
+ [[fallthrough]];
41244
+ case ISD::SETUGT:
41245
+ // signbit(sext(x u> SIGNED_MAX)) == signbit(x)
41246
+ // signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
41247
+ Okay = ValsEq(Op1C, OriginalDemandedBits - 1);
41248
+ break;
41249
+ default:
41250
+ break;
41251
+ }
41252
+
41253
+ Okay = Not ? AllowNOT : Okay;
41254
+ if (!Okay)
41255
+ return SDValue();
41256
+
41257
+ if (!Not)
41258
+ return Op0;
41259
+
41260
+ if (!OpVT.isFloatingPoint())
41261
+ return DAG.getNOT(DL, Op0, OpVT);
41262
+
41263
+ // Possible TODO: We could use `fneg` to do not.
41264
+ return SDValue();
41265
+ }
41266
+
41267
+ static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, SDLoc &DL,
41268
+ ISD::CondCode CC, SDValue Op0,
41269
+ SDValue Op1,
41270
+ const APInt &OriginalDemandedBits,
41271
+ bool AllowNOT) {
41272
+ if (SDValue R = simplifySExtOfDecomposedSetCCImpl(
41273
+ DAG, DL, CC, Op0, Op1, OriginalDemandedBits, AllowNOT))
41274
+ return R;
41275
+ return simplifySExtOfDecomposedSetCCImpl(
41276
+ DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits,
41277
+ AllowNOT);
41278
+ }
41279
+
41040
41280
// Simplify variable target shuffle masks based on the demanded elements.
41041
41281
// TODO: Handle DemandedBits in mask indices as well?
41042
41282
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
@@ -42200,13 +42440,24 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
42200
42440
}
42201
42441
break;
42202
42442
}
42203
- case X86ISD::PCMPGT:
42204
- // icmp sgt(0, R) == ashr(R, BitWidth-1).
42205
- // iff we only need the sign bit then we can use R directly.
42206
- if (OriginalDemandedBits.isSignMask() &&
42207
- ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42208
- return TLO.CombineTo(Op, Op.getOperand(1));
42443
+ case X86ISD::PCMPGT: {
42444
+ SDLoc DL(Op);
42445
+ if (SDValue R = simplifySExtOfDecomposedSetCC(
42446
+ TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42447
+ OriginalDemandedBits, !(TLO.LegalOperations() && TLO.LegalTypes())))
42448
+ return TLO.CombineTo(Op, R);
42449
+ break;
42450
+ }
42451
+ case X86ISD::CMPP: {
42452
+ SDLoc DL(Op);
42453
+ ISD::CondCode CC = X86::getCondForCMPPImm(
42454
+ cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42455
+ if (SDValue R = simplifySExtOfDecomposedSetCC(
42456
+ TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
42457
+ OriginalDemandedBits, !(TLO.LegalOperations() && TLO.LegalTypes())))
42458
+ return TLO.CombineTo(Op, R);
42209
42459
break;
42460
+ }
42210
42461
case X86ISD::MOVMSK: {
42211
42462
SDValue Src = Op.getOperand(0);
42212
42463
MVT SrcVT = Src.getSimpleValueType();
@@ -42390,13 +42641,24 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
42390
42641
if (DemandedBits.isSignMask())
42391
42642
return Op.getOperand(0);
42392
42643
break;
42393
- case X86ISD::PCMPGT:
42394
- // icmp sgt(0, R) == ashr(R, BitWidth-1).
42395
- // iff we only need the sign bit then we can use R directly.
42396
- if (DemandedBits.isSignMask() &&
42397
- ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42398
- return Op.getOperand(1);
42644
+ case X86ISD::PCMPGT: {
42645
+ SDLoc DL(Op);
42646
+ if (SDValue R = simplifySExtOfDecomposedSetCC(
42647
+ DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42648
+ DemandedBits, /*AllowNOT*/ false))
42649
+ return R;
42650
+ break;
42651
+ }
42652
+ case X86ISD::CMPP: {
42653
+ SDLoc DL(Op);
42654
+ ISD::CondCode CC = X86::getCondForCMPPImm(
42655
+ cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42656
+ if (SDValue R = simplifySExtOfDecomposedSetCC(
42657
+ DAG, DL, CC, Op.getOperand(0), Op.getOperand(1), DemandedBits,
42658
+ /*AllowNOT*/ false))
42659
+ return R;
42399
42660
break;
42661
+ }
42400
42662
case X86ISD::BLENDV: {
42401
42663
// BLENDV: Cond (MSB) ? LHS : RHS
42402
42664
SDValue Cond = Op.getOperand(0);
0 commit comments