Skip to content

Commit 89172e9

Browse files
committed
[X86] combineBitcastToBoolVector - add XOR + Constant handling, match existing BITCASTs and limit recursion depth
Add XOR and Constant handling so that NOT patterns can be detected. If a recursive combineBitcastToBoolVector call finds that a BITCAST node for the operand already exists, reuse that node. Because combineBitcastToBoolVector is recursive, also limit the maximum recursion depth. Fixes #93000.
1 parent fab234a commit 89172e9

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 33 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -43413,7 +43413,11 @@ static SDValue createMMXBuildVector(BuildVectorSDNode *BV, SelectionDAG &DAG,
4341343413
// the chain.
4341443414
static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
4341543415
SelectionDAG &DAG,
43416-
const X86Subtarget &Subtarget) {
43416+
const X86Subtarget &Subtarget,
43417+
unsigned Depth = 0) {
43418+
if (Depth >= SelectionDAG::MaxRecursionDepth)
43419+
return SDValue(); // Limit search depth.
43420+
4341743421
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
4341843422
unsigned Opc = V.getOpcode();
4341943423
switch (Opc) {
@@ -43425,14 +43429,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
4342543429
return DAG.getBitcast(VT, Src);
4342643430
break;
4342743431
}
43432+
case ISD::Constant: {
43433+
auto *C = cast<ConstantSDNode>(V);
43434+
if (C->isZero())
43435+
return DAG.getConstant(0, DL, VT);
43436+
if (C->isAllOnes())
43437+
return DAG.getAllOnesConstant(DL, VT);
43438+
break;
43439+
}
4342843440
case ISD::TRUNCATE: {
4342943441
// If we find a suitable source, a truncated scalar becomes a subvector.
4343043442
SDValue Src = V.getOperand(0);
4343143443
EVT NewSrcVT =
4343243444
EVT::getVectorVT(*DAG.getContext(), MVT::i1, Src.getValueSizeInBits());
4343343445
if (TLI.isTypeLegal(NewSrcVT))
43434-
if (SDValue N0 =
43435-
combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
43446+
if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
43447+
Subtarget, Depth + 1))
4343643448
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N0,
4343743449
DAG.getIntPtrConstant(0, DL));
4343843450
break;
@@ -43444,20 +43456,22 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
4344443456
EVT NewSrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
4344543457
Src.getScalarValueSizeInBits());
4344643458
if (TLI.isTypeLegal(NewSrcVT))
43447-
if (SDValue N0 =
43448-
combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
43459+
if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
43460+
Subtarget, Depth + 1))
4344943461
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
4345043462
Opc == ISD::ANY_EXTEND ? DAG.getUNDEF(VT)
4345143463
: DAG.getConstant(0, DL, VT),
4345243464
N0, DAG.getIntPtrConstant(0, DL));
4345343465
break;
4345443466
}
43455-
case ISD::OR: {
43456-
// If we find suitable sources, we can just move an OR to the vector domain.
43457-
SDValue Src0 = V.getOperand(0);
43458-
SDValue Src1 = V.getOperand(1);
43459-
if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
43460-
if (SDValue N1 = combineBitcastToBoolVector(VT, Src1, DL, DAG, Subtarget))
43467+
case ISD::OR:
43468+
case ISD::XOR: {
43469+
// If we find suitable sources, we can just move the op to the vector
43470+
// domain.
43471+
if (SDValue N0 = combineBitcastToBoolVector(VT, V.getOperand(0), DL, DAG,
43472+
Subtarget, Depth + 1))
43473+
if (SDValue N1 = combineBitcastToBoolVector(VT, V.getOperand(1), DL, DAG,
43474+
Subtarget, Depth + 1))
4346143475
return DAG.getNode(Opc, DL, VT, N0, N1);
4346243476
break;
4346343477
}
@@ -43469,13 +43483,20 @@ static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
4346943483
break;
4347043484

4347143485
if (auto *Amt = dyn_cast<ConstantSDNode>(V.getOperand(1)))
43472-
if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
43486+
if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget,
43487+
Depth + 1))
4347343488
return DAG.getNode(
4347443489
X86ISD::KSHIFTL, DL, VT, N0,
4347543490
DAG.getTargetConstant(Amt->getZExtValue(), DL, MVT::i8));
4347643491
break;
4347743492
}
4347843493
}
43494+
43495+
// Does the inner bitcast already exist?
43496+
if (Depth > 0)
43497+
if (SDNode *Alt = DAG.getNodeIfExists(ISD::BITCAST, DAG.getVTList(VT), {V}))
43498+
return SDValue(Alt, 0);
43499+
4347943500
return SDValue();
4348043501
}
4348143502

llvm/test/CodeGen/X86/pr93000.ll

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@ define void @PR93000(ptr %a0, ptr %a1, ptr %a2, <32 x i16> %a3) {
1010
; CHECK-NEXT: .LBB0_1: # %Loop
1111
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
1212
; CHECK-NEXT: kmovd %eax, %k1
13-
; CHECK-NEXT: notl %eax
14-
; CHECK-NEXT: kmovd %eax, %k2
13+
; CHECK-NEXT: knotd %k1, %k2
1514
; CHECK-NEXT: vpblendmw (%rsi), %zmm0, %zmm1 {%k1}
1615
; CHECK-NEXT: vmovdqu16 (%rdx), %zmm1 {%k2}
1716
; CHECK-NEXT: vmovdqu64 %zmm1, (%rsi)

0 commit comments

Comments
 (0)