Skip to content

Commit 17857d9

Browse files
[X86] Generate kmov for masking integers (llvm#120593)
When we have an integer used as a bit mask the llvm ir looks something like this ``` %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768> %cmp1 = icmp ne <16 x i32> %1, zeroinitializer ``` where `.splat` is vector containing the mask in all lanes. The assembly generated for this looks like ``` vpbroadcastd %ecx, %zmm0 vptestmd .LCPI0_0(%rip), %zmm0, %k1 ``` where we have a constant table of powers of 2. Instead of doing this we could just move the relevant bits directly to `k` registers using a `kmov` instruction. ``` kmovw %ecx, %k1 ``` This is faster and also reduces code size.
1 parent dddfd77 commit 17857d9

File tree

3 files changed

+735
-3
lines changed

3 files changed

+735
-3
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55751,6 +55751,88 @@ static SDValue truncateAVX512SetCCNoBWI(EVT VT, EVT OpVT, SDValue LHS,
5575155751
return SDValue();
5575255752
}
5575355753

55754+
// The pattern (setcc (and (broadcast x), (2^n, 2^{n+1}, ...)), (0, 0, ...),
55755+
// eq/ne) is generated when using an integer as a mask. Instead of generating a
55756+
// broadcast + vptest, we can directly move the integer to a mask register.
55757+
static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
55758+
const SDLoc &DL, SelectionDAG &DAG,
55759+
const X86Subtarget &Subtarget) {
55760+
if (CC != ISD::SETNE && CC != ISD::SETEQ)
55761+
return SDValue();
55762+
55763+
if (!Subtarget.hasAVX512())
55764+
return SDValue();
55765+
55766+
if (Op0.getOpcode() != ISD::AND)
55767+
return SDValue();
55768+
55769+
SDValue Broadcast = Op0.getOperand(0);
55770+
if (Broadcast.getOpcode() != X86ISD::VBROADCAST &&
55771+
Broadcast.getOpcode() != X86ISD::VBROADCAST_LOAD)
55772+
return SDValue();
55773+
55774+
SDValue Load = Op0.getOperand(1);
55775+
EVT LoadVT = Load.getSimpleValueType();
55776+
55777+
APInt UndefElts;
55778+
SmallVector<APInt, 32> EltBits;
55779+
if (!getTargetConstantBitsFromNode(Load, LoadVT.getScalarSizeInBits(),
55780+
UndefElts, EltBits,
55781+
/*AllowWholeUndefs*/ true,
55782+
/*AllowPartialUndefs*/ false) ||
55783+
UndefElts[0] || !EltBits[0].isPowerOf2() || UndefElts.getBitWidth() > 16)
55784+
return SDValue();
55785+
55786+
// Check if the constant pool contains only powers of 2 starting from some
55787+
// 2^N. The table may also contain undefs because of widening of vector
55788+
// operands.
55789+
unsigned N = EltBits[0].logBase2();
55790+
unsigned Len = UndefElts.getBitWidth();
55791+
for (unsigned I = 1; I != Len; ++I) {
55792+
if (UndefElts[I]) {
55793+
if (!UndefElts.extractBits(Len - (I + 1), I + 1).isAllOnes())
55794+
return SDValue();
55795+
break;
55796+
}
55797+
55798+
if (EltBits[I].getBitWidth() <= N + I || !EltBits[I].isOneBitSet(N + I))
55799+
return SDValue();
55800+
}
55801+
55802+
MVT BroadcastOpVT = Broadcast.getSimpleValueType().getVectorElementType();
55803+
SDValue BroadcastOp;
55804+
if (Broadcast.getOpcode() != X86ISD::VBROADCAST) {
55805+
BroadcastOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, BroadcastOpVT,
55806+
Broadcast, DAG.getVectorIdxConstant(0, DL));
55807+
} else {
55808+
BroadcastOp = Broadcast.getOperand(0);
55809+
if (BroadcastOp.getValueType().isVector())
55810+
return SDValue();
55811+
}
55812+
55813+
SDValue Masked = BroadcastOp;
55814+
if (N != 0) {
55815+
APInt Mask = APInt::getLowBitsSet(BroadcastOpVT.getSizeInBits(), Len);
55816+
SDValue ShiftedValue = DAG.getNode(ISD::SRL, DL, BroadcastOpVT, BroadcastOp,
55817+
DAG.getConstant(N, DL, BroadcastOpVT));
55818+
Masked = DAG.getNode(ISD::AND, DL, BroadcastOpVT, ShiftedValue,
55819+
DAG.getConstant(Mask, DL, BroadcastOpVT));
55820+
}
55821+
// We can't extract more than 16 bits using this pattern, because 2^{17} will
55822+
// not fit in an i16 and a vXi32 where X > 16 is more than 512 bits.
55823+
SDValue Trunc = DAG.getAnyExtOrTrunc(Masked, DL, MVT::i16);
55824+
SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1, Trunc);
55825+
55826+
if (CC == ISD::SETEQ)
55827+
Bitcast = DAG.getNOT(DL, Bitcast, MVT::v16i1);
55828+
55829+
if (VT != MVT::v16i1)
55830+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Bitcast,
55831+
DAG.getVectorIdxConstant(0, DL));
55832+
55833+
return Bitcast;
55834+
}
55835+
5575455836
static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
5575555837
TargetLowering::DAGCombinerInfo &DCI,
5575655838
const X86Subtarget &Subtarget) {
@@ -55883,6 +55965,11 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
5588355965
"Unexpected condition code!");
5588455966
return Op0.getOperand(0);
5588555967
}
55968+
55969+
if (IsVZero1)
55970+
if (SDValue V =
55971+
combineAVX512SetCCToKMOV(VT, Op0, TmpCC, DL, DAG, Subtarget))
55972+
return V;
5588655973
}
5588755974

5588855975
// Try and make unsigned vector comparison signed. On pre AVX512 targets there

0 commit comments

Comments
 (0)