Skip to content

[X86] Generate kmov for masking integers #120593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
822ae48
Generate `kmov` for masking integers
abhishek-kaushik22 Dec 19, 2024
95c8864
Merge branch 'main' into kmov
abhishek-kaushik22 Dec 19, 2024
3f39f65
Review Changes
abhishek-kaushik22 Dec 20, 2024
47e9a51
Merge branch 'main' into kmov
abhishek-kaushik22 Dec 20, 2024
57c4aa0
Update tests
abhishek-kaushik22 Dec 23, 2024
85b9945
Combine to KMOV
abhishek-kaushik22 Dec 24, 2024
389c871
Merge branch 'main' into kmov
abhishek-kaushik22 Dec 24, 2024
1ae4114
Use getTargetConstantBitsFromNode
abhishek-kaushik22 Jan 22, 2025
76e904b
Merge branch 'main' into kmov
abhishek-kaushik22 Jan 22, 2025
ca6c246
Update test
abhishek-kaushik22 Jan 22, 2025
606d7f6
Update X86ISelLowering.cpp
abhishek-kaushik22 Jan 22, 2025
a51f6cb
Merge branch 'main' into kmov
abhishek-kaushik22 Jan 23, 2025
abd6a4d
Update X86ISelLowering.cpp
abhishek-kaushik22 Jan 23, 2025
1fef3b3
fix reviews
abhishek-kaushik22 Feb 5, 2025
cb51265
fix reviews
abhishek-kaushik22 Feb 6, 2025
11f9dbb
Update X86ISelLowering.cpp
abhishek-kaushik22 Feb 7, 2025
bfc963b
Update X86ISelLowering.cpp
abhishek-kaushik22 Feb 7, 2025
e1a9e35
Address review comments
abhishek-kaushik22 Feb 22, 2025
e209af1
Merge branch 'main' into kmov
abhishek-kaushik22 Feb 22, 2025
b76c86d
Fix tests
abhishek-kaushik22 Feb 22, 2025
0feb0a1
Remove basic block name from tests
abhishek-kaushik22 Feb 25, 2025
250fc9c
Use getVectorIdxConstant instead of getConstant
abhishek-kaushik22 Feb 25, 2025
a8ba133
Merge branch 'main' into kmov
abhishek-kaushik22 Feb 25, 2025
60a5c70
Use DAG.getNOT
abhishek-kaushik22 Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55447,6 +55447,95 @@ static SDValue truncateAVX512SetCCNoBWI(EVT VT, EVT OpVT, SDValue LHS,
return SDValue();
}

// The pattern (setcc (and (broadcast x), (2^n, 2^{n+1}, ...)), (0, 0, ...),
// eq/ne) is generated when using an integer as a mask. Instead of generating a
// broadcast + vptest, we can directly move the integer to a mask register.
static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
const SDLoc &DL, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
if (CC != ISD::SETNE && CC != ISD::SETEQ)
return SDValue();

if (!Subtarget.hasAVX512())
return SDValue();

if (Op0.getOpcode() != ISD::AND)
return SDValue();

SDValue Broadcast = Op0.getOperand(0);
if (Broadcast.getOpcode() != X86ISD::VBROADCAST &&
Broadcast.getOpcode() != X86ISD::VBROADCAST_LOAD)
return SDValue();

SDValue Load = Op0.getOperand(1);
EVT LoadVT = Load.getSimpleValueType();

APInt UndefElts;
SmallVector<APInt, 32> EltBits;
if (!getTargetConstantBitsFromNode(Load, LoadVT.getScalarSizeInBits(),
UndefElts, EltBits,
/*AllowWholeUndefs*/ true,
/*AllowPartialUndefs*/ false) ||
UndefElts[0] || !EltBits[0].isPowerOf2())
return SDValue();

// Check if the constant pool contains only powers of 2 starting from some
// 2^N. The table may also contain undefs because of widening of vector
// operands.
unsigned N = EltBits[0].logBase2();
unsigned Len = UndefElts.getBitWidth();
for (unsigned I = 1; I != Len; ++I) {
if (UndefElts[I]) {
if (!UndefElts.extractBits(Len - (I + 1), I + 1).isAllOnes())
return SDValue();
break;
}

if (EltBits[I].getBitWidth() <= N + I || !EltBits[I].isOneBitSet(N + I))
return SDValue();
}

const TargetLowering &TLI = DAG.getTargetLoweringInfo();
const DataLayout &DataLayout = DAG.getDataLayout();
MVT VecIdxTy = TLI.getVectorIdxTy(DataLayout);
MVT BroadcastOpVT = Broadcast.getSimpleValueType().getVectorElementType();
SDValue BroadcastOp;
if (Broadcast.getOpcode() != X86ISD::VBROADCAST) {
BroadcastOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, BroadcastOpVT,
Broadcast, DAG.getConstant(0, DL, VecIdxTy));
} else {
BroadcastOp = Broadcast.getOperand(0);
if (BroadcastOp.getValueType().isVector())
return SDValue();
}

SDValue Masked = BroadcastOp;
if (N != 0) {
unsigned Mask = (1ULL << Len) - 1;
SDValue ShiftedValue = DAG.getNode(ISD::SRL, DL, BroadcastOpVT, BroadcastOp,
DAG.getConstant(N, DL, BroadcastOpVT));
Masked = DAG.getNode(ISD::AND, DL, BroadcastOpVT, ShiftedValue,
DAG.getConstant(Mask, DL, BroadcastOpVT));
}
// We can't extract more than 16 bits using this pattern, because 2^{17} will
// not fit in an i16 and a vXi32 where X > 16 is more than 512 bits.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the check that VT is not greater than v16i1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the check here: UndefElts.getBitWidth() > 16

SDValue Trunc = DAG.getAnyExtOrTrunc(Masked, DL, MVT::i16);
SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1, Trunc);
MVT PtrTy = TLI.getPointerTy(DataLayout);

if (CC == ISD::SETEQ)
Bitcast =
DAG.getNode(ISD::XOR, DL, MVT::v16i1, Bitcast,
DAG.getSplatBuildVector(MVT::v16i1, DL,
DAG.getAllOnesConstant(DL, PtrTy)));

if (VT != MVT::v16i1)
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Bitcast,
DAG.getConstant(0, DL, PtrTy));

return Bitcast;
}

static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
Expand Down Expand Up @@ -55579,6 +55668,11 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
"Unexpected condition code!");
return Op0.getOperand(0);
}

if (IsVZero1)
if (SDValue V =
combineAVX512SetCCToKMOV(VT, Op0, TmpCC, DL, DAG, Subtarget))
return V;
}

// Try and make unsigned vector comparison signed. On pre AVX512 targets there
Expand Down
Loading