Skip to content

Commit ff64740

Browse files
committed
[ValueTracking][X86] Compute KnownBits for phadd/phsub
Add KnownBits computations to ValueTracking and X86 DAG lowering. These instructions add/subtract adjacent vector elements in their operands. Example: phadd [X1, X2] [Y1, Y2] = [X1 + X2, Y1 + Y2]. This means that, in this example, we can compute the KnownBits of the operation by combining the KnownBits of the even-position elements with the KnownBits of the odd-position elements of each operand (i.e. KnownBits(X1) + KnownBits(X2) and KnownBits(Y1) + KnownBits(Y2)) and intersecting the two results, since every result element is produced by one of these two combinations. This approach also generalizes to all x86 vector types. There are also the operations phadd.sw and phsub.sw, which perform saturating addition/subtraction. Use sadd_sat and ssub_sat to compute the KnownBits of these operations. Also adjust the existing test case pr53247.ll because it can be transformed to a constant using the new KnownBits computation. Fixes #82516.
1 parent 3caccd8 commit ff64740

File tree

8 files changed

+288
-153
lines changed

8 files changed

+288
-153
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,23 @@ void processShuffleMasks(
246246
function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
247247
function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
248248

249+
/// Compute the demanded elements mask of horizontal binary operations. A
250+
/// horizontal operation combines two adjacent elements in a vector operand.
251+
/// This function returns a mask for the elements that correspond to the first
252+
/// operand of this horizontal combination. For example, for two vectors
253+
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
254+
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
255+
/// result of this function to the left by 1.
256+
///
257+
/// \param VectorBitWidth the total bit width of the vector
258+
/// \param DemandedElts the demanded elements mask for the operation
259+
/// \param DemandedLHS the demanded elements mask for the left operand
260+
/// \param DemandedRHS the demanded elements mask for the right operand
261+
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
262+
const APInt &DemandedElts,
263+
APInt &DemandedLHS,
264+
APInt &DemandedRHS);
265+
249266
/// Compute a map of integer instructions to their minimum legal type
250267
/// size.
251268
///

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,43 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
950950
return KnownOut;
951951
}
952952

953+
static KnownBits computeKnownBitsForHorizontalOperation(
954+
const Operator *I, const APInt &DemandedElts, unsigned Depth,
955+
const SimplifyQuery &Q,
956+
const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
957+
KnownBitsFunc) {
958+
APInt DemandedEltsLHS, DemandedEltsRHS;
959+
getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
960+
DemandedElts, DemandedEltsLHS,
961+
DemandedEltsRHS);
962+
963+
std::array<KnownBits, 2> KnownLHS;
964+
for (unsigned Index = 0; Index < KnownLHS.size(); ++Index) {
965+
if (!DemandedEltsLHS.isZero()) {
966+
KnownLHS[Index] =
967+
computeKnownBits(I->getOperand(0), DemandedEltsLHS, Depth + 1, Q);
968+
} else {
969+
KnownLHS[Index] = KnownBits(I->getType()->getScalarSizeInBits());
970+
KnownLHS[Index].setAllZero();
971+
}
972+
DemandedEltsLHS <<= 1;
973+
}
974+
std::array<KnownBits, 2> KnownRHS;
975+
for (unsigned Index = 0; Index < KnownRHS.size(); ++Index) {
976+
if (!DemandedEltsRHS.isZero()) {
977+
KnownRHS[Index] =
978+
computeKnownBits(I->getOperand(1), DemandedEltsRHS, Depth + 1, Q);
979+
} else {
980+
KnownRHS[Index] = KnownBits(I->getType()->getScalarSizeInBits());
981+
KnownRHS[Index].setAllZero();
982+
}
983+
DemandedEltsRHS <<= 1;
984+
}
985+
986+
return KnownBitsFunc(KnownLHS[0], KnownLHS[1])
987+
.intersectWith(KnownBitsFunc(KnownRHS[0], KnownRHS[1]));
988+
}
989+
953990
// Public so this can be used in `SimplifyDemandedUseBits`.
954991
KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
955992
const KnownBits &KnownLHS,
@@ -1725,6 +1762,56 @@ static void computeKnownBitsFromOperator(const Operator *I,
17251762
case Intrinsic::x86_sse42_crc32_64_64:
17261763
Known.Zero.setBitsFrom(32);
17271764
break;
1765+
case Intrinsic::x86_ssse3_phadd_d:
1766+
case Intrinsic::x86_ssse3_phadd_w:
1767+
case Intrinsic::x86_ssse3_phadd_d_128:
1768+
case Intrinsic::x86_ssse3_phadd_w_128:
1769+
case Intrinsic::x86_avx2_phadd_d:
1770+
case Intrinsic::x86_avx2_phadd_w: {
1771+
Known = computeKnownBitsForHorizontalOperation(
1772+
I, DemandedElts, Depth, Q,
1773+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1774+
return KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
1775+
/*NUW=*/false, KnownLHS,
1776+
KnownRHS);
1777+
});
1778+
break;
1779+
}
1780+
case Intrinsic::x86_ssse3_phadd_sw:
1781+
case Intrinsic::x86_ssse3_phadd_sw_128:
1782+
case Intrinsic::x86_avx2_phadd_sw: {
1783+
Known = computeKnownBitsForHorizontalOperation(
1784+
I, DemandedElts, Depth, Q,
1785+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1786+
return KnownBits::sadd_sat(KnownLHS, KnownRHS);
1787+
});
1788+
break;
1789+
}
1790+
case Intrinsic::x86_ssse3_phsub_d:
1791+
case Intrinsic::x86_ssse3_phsub_w:
1792+
case Intrinsic::x86_ssse3_phsub_d_128:
1793+
case Intrinsic::x86_ssse3_phsub_w_128:
1794+
case Intrinsic::x86_avx2_phsub_d:
1795+
case Intrinsic::x86_avx2_phsub_w: {
1796+
Known = computeKnownBitsForHorizontalOperation(
1797+
I, DemandedElts, Depth, Q,
1798+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1799+
return KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
1800+
/*NUW=*/false, KnownLHS,
1801+
KnownRHS);
1802+
});
1803+
break;
1804+
}
1805+
case Intrinsic::x86_ssse3_phsub_sw:
1806+
case Intrinsic::x86_ssse3_phsub_sw_128:
1807+
case Intrinsic::x86_avx2_phsub_sw: {
1808+
Known = computeKnownBitsForHorizontalOperation(
1809+
I, DemandedElts, Depth, Q,
1810+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1811+
return KnownBits::ssub_sat(KnownLHS, KnownRHS);
1812+
});
1813+
break;
1814+
}
17281815
case Intrinsic::riscv_vsetvli:
17291816
case Intrinsic::riscv_vsetvlimax: {
17301817
bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,33 @@ void llvm::processShuffleMasks(
541541
}
542542
}
543543

544+
void llvm::getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
545+
const APInt &DemandedElts,
546+
APInt &DemandedLHS,
547+
APInt &DemandedRHS) {
548+
int NumLanes = std::max<int>(1, VectorBitWidth / 128);
549+
int NumElts = DemandedElts.getBitWidth();
550+
int NumEltsPerLane = NumElts / NumLanes;
551+
int HalfEltsPerLane = NumEltsPerLane / 2;
552+
553+
DemandedLHS = APInt::getZero(NumElts);
554+
DemandedRHS = APInt::getZero(NumElts);
555+
556+
// Map DemandedElts to the horizontal operands.
557+
for (int Idx = 0; Idx != NumElts; ++Idx) {
558+
if (!DemandedElts[Idx])
559+
continue;
560+
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
561+
int LocalIdx = Idx % NumEltsPerLane;
562+
if (LocalIdx < HalfEltsPerLane) {
563+
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx);
564+
} else {
565+
LocalIdx -= HalfEltsPerLane;
566+
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx);
567+
}
568+
}
569+
}
570+
544571
MapVector<Instruction *, uint64_t>
545572
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
546573
const TargetTransformInfo *TTI) {

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 101 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5180,29 +5180,10 @@ static void getPackDemandedElts(EVT VT, const APInt &DemandedElts,
51805180
// Split the demanded elts of a HADD/HSUB node between its operands.
51815181
static void getHorizDemandedElts(EVT VT, const APInt &DemandedElts,
51825182
APInt &DemandedLHS, APInt &DemandedRHS) {
5183-
int NumLanes = VT.getSizeInBits() / 128;
5184-
int NumElts = DemandedElts.getBitWidth();
5185-
int NumEltsPerLane = NumElts / NumLanes;
5186-
int HalfEltsPerLane = NumEltsPerLane / 2;
5187-
5188-
DemandedLHS = APInt::getZero(NumElts);
5189-
DemandedRHS = APInt::getZero(NumElts);
5190-
5191-
// Map DemandedElts to the horizontal operands.
5192-
for (int Idx = 0; Idx != NumElts; ++Idx) {
5193-
if (!DemandedElts[Idx])
5194-
continue;
5195-
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
5196-
int LocalIdx = Idx % NumEltsPerLane;
5197-
if (LocalIdx < HalfEltsPerLane) {
5198-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5199-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5200-
} else {
5201-
LocalIdx -= HalfEltsPerLane;
5202-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5203-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5204-
}
5205-
}
5183+
getHorizDemandedEltsForFirstOperand(VT.getSizeInBits(), DemandedElts,
5184+
DemandedLHS, DemandedRHS);
5185+
DemandedLHS |= DemandedLHS << 1;
5186+
DemandedRHS |= DemandedRHS << 1;
52065187
}
52075188

52085189
/// Calculates the shuffle mask corresponding to the target-specific opcode.
@@ -36953,6 +36934,43 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
3695336934
Known = Known.zext(64);
3695436935
}
3695536936

36937+
static KnownBits computeKnownBitsForHorizontalOperation(
36938+
const SDValue Op, const APInt &DemandedElts, unsigned Depth,
36939+
unsigned OpIndexStart, const SelectionDAG &DAG,
36940+
const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
36941+
KnownBitsFunc) {
36942+
APInt DemandedEltsLHS, DemandedEltsRHS;
36943+
getHorizDemandedEltsForFirstOperand(Op.getValueType().getSizeInBits(),
36944+
DemandedElts, DemandedEltsLHS,
36945+
DemandedEltsRHS);
36946+
36947+
std::array<KnownBits, 2> KnownLHS;
36948+
for (unsigned Index = 0; Index < KnownLHS.size(); ++Index) {
36949+
if (!DemandedEltsLHS.isZero()) {
36950+
KnownLHS[Index] = DAG.computeKnownBits(Op.getOperand(OpIndexStart),
36951+
DemandedEltsLHS, Depth + 1);
36952+
} else {
36953+
KnownLHS[Index] = KnownBits(Op.getScalarValueSizeInBits());
36954+
KnownLHS[Index].setAllZero();
36955+
}
36956+
DemandedEltsLHS <<= 1;
36957+
}
36958+
std::array<KnownBits, 2> KnownRHS;
36959+
for (unsigned Index = 0; Index < KnownRHS.size(); ++Index) {
36960+
if (!DemandedEltsRHS.isZero()) {
36961+
KnownRHS[Index] = DAG.computeKnownBits(Op.getOperand(OpIndexStart + 1),
36962+
DemandedEltsRHS, Depth + 1);
36963+
} else {
36964+
KnownRHS[Index] = KnownBits(Op.getScalarValueSizeInBits());
36965+
KnownRHS[Index].setAllZero();
36966+
}
36967+
DemandedEltsRHS <<= 1;
36968+
}
36969+
36970+
return KnownBitsFunc(KnownLHS[0], KnownLHS[1])
36971+
.intersectWith(KnownBitsFunc(KnownRHS[0], KnownRHS[1]));
36972+
}
36973+
3695636974
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3695736975
KnownBits &Known,
3695836976
const APInt &DemandedElts,
@@ -37262,6 +37280,17 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3726237280
}
3726337281
break;
3726437282
}
37283+
case X86ISD::HADD:
37284+
case X86ISD::HSUB: {
37285+
Known = computeKnownBitsForHorizontalOperation(
37286+
Op, DemandedElts, Depth, /*OpIndexStart=*/0, DAG,
37287+
[Opc](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37288+
return KnownBits::computeForAddSub(
37289+
/*Add=*/Opc == X86ISD::HADD, /*NSW=*/false, /*NUW=*/false,
37290+
KnownLHS, KnownRHS);
37291+
});
37292+
break;
37293+
}
3726537294
case ISD::INTRINSIC_WO_CHAIN: {
3726637295
switch (Op->getConstantOperandVal(0)) {
3726737296
case Intrinsic::x86_sse2_psad_bw:
@@ -37276,6 +37305,55 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3727637305
computeKnownBitsForPSADBW(LHS, RHS, Known, DemandedElts, DAG, Depth);
3727737306
break;
3727837307
}
37308+
case Intrinsic::x86_ssse3_phadd_d:
37309+
case Intrinsic::x86_ssse3_phadd_w:
37310+
case Intrinsic::x86_ssse3_phadd_d_128:
37311+
case Intrinsic::x86_ssse3_phadd_w_128:
37312+
case Intrinsic::x86_avx2_phadd_d:
37313+
case Intrinsic::x86_avx2_phadd_w: {
37314+
Known = computeKnownBitsForHorizontalOperation(
37315+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37316+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37317+
return KnownBits::computeForAddSub(
37318+
/*Add=*/true, /*NSW=*/false, /*NUW=*/false, KnownLHS, KnownRHS);
37319+
});
37320+
break;
37321+
}
37322+
case Intrinsic::x86_ssse3_phadd_sw:
37323+
case Intrinsic::x86_ssse3_phadd_sw_128:
37324+
case Intrinsic::x86_avx2_phadd_sw: {
37325+
Known = computeKnownBitsForHorizontalOperation(
37326+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37327+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37328+
return KnownBits::sadd_sat(KnownLHS, KnownRHS);
37329+
});
37330+
break;
37331+
}
37332+
case Intrinsic::x86_ssse3_phsub_d:
37333+
case Intrinsic::x86_ssse3_phsub_w:
37334+
case Intrinsic::x86_ssse3_phsub_d_128:
37335+
case Intrinsic::x86_ssse3_phsub_w_128:
37336+
case Intrinsic::x86_avx2_phsub_d:
37337+
case Intrinsic::x86_avx2_phsub_w: {
37338+
Known = computeKnownBitsForHorizontalOperation(
37339+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37340+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37341+
return KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
37342+
/*NUW=*/false, KnownLHS,
37343+
KnownRHS);
37344+
});
37345+
break;
37346+
}
37347+
case Intrinsic::x86_ssse3_phsub_sw:
37348+
case Intrinsic::x86_ssse3_phsub_sw_128:
37349+
case Intrinsic::x86_avx2_phsub_sw: {
37350+
Known = computeKnownBitsForHorizontalOperation(
37351+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37352+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37353+
return KnownBits::ssub_sat(KnownLHS, KnownRHS);
37354+
});
37355+
break;
37356+
}
3727937357
}
3728037358
break;
3728137359
}

0 commit comments

Comments
 (0)