@@ -43413,7 +43413,11 @@ static SDValue createMMXBuildVector(BuildVectorSDNode *BV, SelectionDAG &DAG,
43413
43413
// the chain.
43414
43414
static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
43415
43415
SelectionDAG &DAG,
43416
- const X86Subtarget &Subtarget) {
43416
+ const X86Subtarget &Subtarget,
43417
+ unsigned Depth = 0) {
43418
+ if (Depth >= SelectionDAG::MaxRecursionDepth)
43419
+ return SDValue(); // Limit search depth.
43420
+
43417
43421
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
43418
43422
unsigned Opc = V.getOpcode();
43419
43423
switch (Opc) {
@@ -43425,14 +43429,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
43425
43429
return DAG.getBitcast(VT, Src);
43426
43430
break;
43427
43431
}
43432
+ case ISD::Constant: {
43433
+ auto *C = cast<ConstantSDNode>(V);
43434
+ if (C->isZero())
43435
+ return DAG.getConstant(0, DL, VT);
43436
+ if (C->isAllOnes())
43437
+ return DAG.getAllOnesConstant(DL, VT);
43438
+ break;
43439
+ }
43428
43440
case ISD::TRUNCATE: {
43429
43441
// If we find a suitable source, a truncated scalar becomes a subvector.
43430
43442
SDValue Src = V.getOperand(0);
43431
43443
EVT NewSrcVT =
43432
43444
EVT::getVectorVT(*DAG.getContext(), MVT::i1, Src.getValueSizeInBits());
43433
43445
if (TLI.isTypeLegal(NewSrcVT))
43434
- if (SDValue N0 =
43435
- combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
43446
+ if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
43447
+ Subtarget, Depth + 1 ))
43436
43448
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N0,
43437
43449
DAG.getIntPtrConstant(0, DL));
43438
43450
break;
@@ -43444,20 +43456,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
43444
43456
EVT NewSrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
43445
43457
Src.getScalarValueSizeInBits());
43446
43458
if (TLI.isTypeLegal(NewSrcVT))
43447
- if (SDValue N0 =
43448
- combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
43459
+ if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
43460
+ Subtarget, Depth + 1 ))
43449
43461
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
43450
43462
Opc == ISD::ANY_EXTEND ? DAG.getUNDEF(VT)
43451
43463
: DAG.getConstant(0, DL, VT),
43452
43464
N0, DAG.getIntPtrConstant(0, DL));
43453
43465
break;
43454
43466
}
43455
- case ISD::OR: {
43456
- // If we find suitable sources, we can just move an OR to the vector domain.
43457
- SDValue Src0 = V.getOperand(0);
43458
- SDValue Src1 = V.getOperand(1);
43459
- if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
43460
- if (SDValue N1 = combineBitcastToBoolVector(VT, Src1, DL, DAG, Subtarget))
43467
+ case ISD::OR:
43468
+ case ISD::XOR: {
43469
+ // If we find suitable sources, we can just move the op to the vector
43470
+ // domain.
43471
+ if (SDValue N0 = combineBitcastToBoolVector(VT, V.getOperand(0), DL, DAG,
43472
+ Subtarget, Depth + 1))
43473
+ if (SDValue N1 = combineBitcastToBoolVector(VT, V.getOperand(1), DL, DAG,
43474
+ Subtarget, Depth + 1))
43461
43475
return DAG.getNode(Opc, DL, VT, N0, N1);
43462
43476
break;
43463
43477
}
@@ -43469,13 +43483,20 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
43469
43483
break;
43470
43484
43471
43485
if (auto *Amt = dyn_cast<ConstantSDNode>(V.getOperand(1)))
43472
- if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
43486
+ if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget,
43487
+ Depth + 1))
43473
43488
return DAG.getNode(
43474
43489
X86ISD::KSHIFTL, DL, VT, N0,
43475
43490
DAG.getTargetConstant(Amt->getZExtValue(), DL, MVT::i8));
43476
43491
break;
43477
43492
}
43478
43493
}
43494
+
43495
+ // Does the inner bitcast already exist?
43496
+ if (Depth > 0)
43497
+ if (SDNode *Alt = DAG.getNodeIfExists(ISD::BITCAST, DAG.getVTList(VT), {V}))
43498
+ return SDValue(Alt, 0);
43499
+
43479
43500
return SDValue();
43480
43501
}
43481
43502
0 commit comments