Skip to content

Commit 8c7dc00

Browse files
committed
[ValueTracking][X86] Compute KnownBits for phadd/phsub
Add KnownBits computations to ValueTracking and X86 DAG lowering. These instructions add/subtract adjacent vector elements in their operands. Example: phadd [X1, X2] [Y1, Y2] = [X1 + X2, Y1 + Y2]. This means that, in this example, we can compute the KnownBits of the operation by computing KnownBits(X1) + KnownBits(X2) (the KnownBits of an operand's even elements added to the KnownBits of its odd elements) and likewise KnownBits(Y1) + KnownBits(Y2), then intersecting the two results, since any demanded result element may hold either sum. This approach also generalizes to all x86 vector types. There are also the operations phadd.sw and phsub.sw, which perform saturating addition/subtraction. Use sadd_sat and ssub_sat to compute the KnownBits of these operations. Also adjust the existing test case pr53247.ll because it can be transformed to a constant using the new KnownBits computation. Fixes #82516.
1 parent 3caccd8 commit 8c7dc00

File tree

8 files changed

+269
-153
lines changed

8 files changed

+269
-153
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,23 @@ void processShuffleMasks(
246246
function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
247247
function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
248248

249+
/// Compute the demanded elements mask of horizontal binary operations. A
250+
/// horizontal operation combines two adjacent elements in a vector operand.
251+
/// This function returns a mask for the elements that correspond to the first
252+
/// operand of this horizontal combination. For example, for two vectors
253+
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
254+
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
255+
/// result of this function to the left by 1.
256+
///
257+
/// \param VectorBitWidth the total bit width of the vector
258+
/// \param DemandedElts the demanded elements mask for the operation
259+
/// \param DemandedLHS the demanded elements mask for the left operand
260+
/// \param DemandedRHS the demanded elements mask for the right operand
261+
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
262+
const APInt &DemandedElts,
263+
APInt &DemandedLHS,
264+
APInt &DemandedRHS);
265+
249266
/// Compute a map of integer instructions to their minimum legal type
250267
/// size.
251268
///

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,33 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
950950
return KnownOut;
951951
}
952952

953+
static KnownBits computeKnownBitsForHorizontalOperation(
954+
const Operator *I, const APInt &DemandedElts, unsigned Depth,
955+
const SimplifyQuery &Q,
956+
const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
957+
KnownBitsFunc) {
958+
APInt DemandedEltsLHS, DemandedEltsRHS;
959+
getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
960+
DemandedElts, DemandedEltsLHS,
961+
DemandedEltsRHS);
962+
963+
const auto ComputeForSingleOpFunc =
964+
[Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
965+
return KnownBitsFunc(
966+
computeKnownBits(Op, DemandedEltsOp, Depth + 1, Q),
967+
computeKnownBits(Op, DemandedEltsOp << 1, Depth + 1, Q));
968+
};
969+
970+
if (!DemandedEltsLHS.isZero() && !DemandedEltsRHS.isZero()) {
971+
return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS)
972+
.intersectWith(ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS));
973+
}
974+
if (!DemandedEltsLHS.isZero()) {
975+
return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS);
976+
}
977+
return ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS);
978+
}
979+
953980
// Public so this can be used in `SimplifyDemandedUseBits`.
954981
KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
955982
const KnownBits &KnownLHS,
@@ -1725,6 +1752,56 @@ static void computeKnownBitsFromOperator(const Operator *I,
17251752
case Intrinsic::x86_sse42_crc32_64_64:
17261753
Known.Zero.setBitsFrom(32);
17271754
break;
1755+
case Intrinsic::x86_ssse3_phadd_d:
1756+
case Intrinsic::x86_ssse3_phadd_w:
1757+
case Intrinsic::x86_ssse3_phadd_d_128:
1758+
case Intrinsic::x86_ssse3_phadd_w_128:
1759+
case Intrinsic::x86_avx2_phadd_d:
1760+
case Intrinsic::x86_avx2_phadd_w: {
1761+
Known = computeKnownBitsForHorizontalOperation(
1762+
I, DemandedElts, Depth, Q,
1763+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1764+
return KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
1765+
/*NUW=*/false, KnownLHS,
1766+
KnownRHS);
1767+
});
1768+
break;
1769+
}
1770+
case Intrinsic::x86_ssse3_phadd_sw:
1771+
case Intrinsic::x86_ssse3_phadd_sw_128:
1772+
case Intrinsic::x86_avx2_phadd_sw: {
1773+
Known = computeKnownBitsForHorizontalOperation(
1774+
I, DemandedElts, Depth, Q,
1775+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1776+
return KnownBits::sadd_sat(KnownLHS, KnownRHS);
1777+
});
1778+
break;
1779+
}
1780+
case Intrinsic::x86_ssse3_phsub_d:
1781+
case Intrinsic::x86_ssse3_phsub_w:
1782+
case Intrinsic::x86_ssse3_phsub_d_128:
1783+
case Intrinsic::x86_ssse3_phsub_w_128:
1784+
case Intrinsic::x86_avx2_phsub_d:
1785+
case Intrinsic::x86_avx2_phsub_w: {
1786+
Known = computeKnownBitsForHorizontalOperation(
1787+
I, DemandedElts, Depth, Q,
1788+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1789+
return KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
1790+
/*NUW=*/false, KnownLHS,
1791+
KnownRHS);
1792+
});
1793+
break;
1794+
}
1795+
case Intrinsic::x86_ssse3_phsub_sw:
1796+
case Intrinsic::x86_ssse3_phsub_sw_128:
1797+
case Intrinsic::x86_avx2_phsub_sw: {
1798+
Known = computeKnownBitsForHorizontalOperation(
1799+
I, DemandedElts, Depth, Q,
1800+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1801+
return KnownBits::ssub_sat(KnownLHS, KnownRHS);
1802+
});
1803+
break;
1804+
}
17281805
case Intrinsic::riscv_vsetvli:
17291806
case Intrinsic::riscv_vsetvlimax: {
17301807
bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,33 @@ void llvm::processShuffleMasks(
541541
}
542542
}
543543

544+
void llvm::getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
545+
const APInt &DemandedElts,
546+
APInt &DemandedLHS,
547+
APInt &DemandedRHS) {
548+
int NumLanes = std::max<int>(1, VectorBitWidth / 128);
549+
int NumElts = DemandedElts.getBitWidth();
550+
int NumEltsPerLane = NumElts / NumLanes;
551+
int HalfEltsPerLane = NumEltsPerLane / 2;
552+
553+
DemandedLHS = APInt::getZero(NumElts);
554+
DemandedRHS = APInt::getZero(NumElts);
555+
556+
// Map DemandedElts to the horizontal operands.
557+
for (int Idx = 0; Idx != NumElts; ++Idx) {
558+
if (!DemandedElts[Idx])
559+
continue;
560+
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
561+
int LocalIdx = Idx % NumEltsPerLane;
562+
if (LocalIdx < HalfEltsPerLane) {
563+
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx);
564+
} else {
565+
LocalIdx -= HalfEltsPerLane;
566+
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx);
567+
}
568+
}
569+
}
570+
544571
MapVector<Instruction *, uint64_t>
545572
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
546573
const TargetTransformInfo *TTI) {

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5180,29 +5180,10 @@ static void getPackDemandedElts(EVT VT, const APInt &DemandedElts,
51805180
// Split the demanded elts of a HADD/HSUB node between its operands.
51815181
static void getHorizDemandedElts(EVT VT, const APInt &DemandedElts,
51825182
APInt &DemandedLHS, APInt &DemandedRHS) {
5183-
int NumLanes = VT.getSizeInBits() / 128;
5184-
int NumElts = DemandedElts.getBitWidth();
5185-
int NumEltsPerLane = NumElts / NumLanes;
5186-
int HalfEltsPerLane = NumEltsPerLane / 2;
5187-
5188-
DemandedLHS = APInt::getZero(NumElts);
5189-
DemandedRHS = APInt::getZero(NumElts);
5190-
5191-
// Map DemandedElts to the horizontal operands.
5192-
for (int Idx = 0; Idx != NumElts; ++Idx) {
5193-
if (!DemandedElts[Idx])
5194-
continue;
5195-
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
5196-
int LocalIdx = Idx % NumEltsPerLane;
5197-
if (LocalIdx < HalfEltsPerLane) {
5198-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5199-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5200-
} else {
5201-
LocalIdx -= HalfEltsPerLane;
5202-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5203-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5204-
}
5205-
}
5183+
getHorizDemandedEltsForFirstOperand(VT.getSizeInBits(), DemandedElts,
5184+
DemandedLHS, DemandedRHS);
5185+
DemandedLHS |= DemandedLHS << 1;
5186+
DemandedRHS |= DemandedRHS << 1;
52065187
}
52075188

52085189
/// Calculates the shuffle mask corresponding to the target-specific opcode.
@@ -36953,6 +36934,34 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
3695336934
Known = Known.zext(64);
3695436935
}
3695536936

36937+
static KnownBits computeKnownBitsForHorizontalOperation(
36938+
const SDValue Op, const APInt &DemandedElts, unsigned Depth,
36939+
unsigned OpIndexStart, const SelectionDAG &DAG,
36940+
const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
36941+
KnownBitsFunc) {
36942+
APInt DemandedEltsLHS, DemandedEltsRHS;
36943+
getHorizDemandedEltsForFirstOperand(Op.getValueType().getSizeInBits(),
36944+
DemandedElts, DemandedEltsLHS,
36945+
DemandedEltsRHS);
36946+
36947+
const auto ComputeForSingleOpFunc =
36948+
[&DAG, Depth, KnownBitsFunc](const SDValue &Op, APInt &DemandedEltsOp) {
36949+
return KnownBitsFunc(
36950+
DAG.computeKnownBits(Op, DemandedEltsOp, Depth + 1),
36951+
DAG.computeKnownBits(Op, DemandedEltsOp << 1, Depth + 1));
36952+
};
36953+
36954+
if (!DemandedEltsLHS.isZero() && !DemandedEltsRHS.isZero()) {
36955+
return ComputeForSingleOpFunc(Op.getOperand(OpIndexStart), DemandedEltsLHS)
36956+
.intersectWith(ComputeForSingleOpFunc(Op.getOperand(OpIndexStart + 1),
36957+
DemandedEltsRHS));
36958+
}
36959+
if (!DemandedEltsLHS.isZero()) {
36960+
return ComputeForSingleOpFunc(Op.getOperand(OpIndexStart), DemandedEltsLHS);
36961+
}
36962+
return ComputeForSingleOpFunc(Op.getOperand(OpIndexStart + 1), DemandedEltsRHS);
36963+
}
36964+
3695636965
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3695736966
KnownBits &Known,
3695836967
const APInt &DemandedElts,
@@ -37262,6 +37271,17 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3726237271
}
3726337272
break;
3726437273
}
37274+
case X86ISD::HADD:
37275+
case X86ISD::HSUB: {
37276+
Known = computeKnownBitsForHorizontalOperation(
37277+
Op, DemandedElts, Depth, /*OpIndexStart=*/0, DAG,
37278+
[Opc](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37279+
return KnownBits::computeForAddSub(
37280+
/*Add=*/Opc == X86ISD::HADD, /*NSW=*/false, /*NUW=*/false,
37281+
KnownLHS, KnownRHS);
37282+
});
37283+
break;
37284+
}
3726537285
case ISD::INTRINSIC_WO_CHAIN: {
3726637286
switch (Op->getConstantOperandVal(0)) {
3726737287
case Intrinsic::x86_sse2_psad_bw:
@@ -37276,6 +37296,55 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3727637296
computeKnownBitsForPSADBW(LHS, RHS, Known, DemandedElts, DAG, Depth);
3727737297
break;
3727837298
}
37299+
case Intrinsic::x86_ssse3_phadd_d:
37300+
case Intrinsic::x86_ssse3_phadd_w:
37301+
case Intrinsic::x86_ssse3_phadd_d_128:
37302+
case Intrinsic::x86_ssse3_phadd_w_128:
37303+
case Intrinsic::x86_avx2_phadd_d:
37304+
case Intrinsic::x86_avx2_phadd_w: {
37305+
Known = computeKnownBitsForHorizontalOperation(
37306+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37307+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37308+
return KnownBits::computeForAddSub(
37309+
/*Add=*/true, /*NSW=*/false, /*NUW=*/false, KnownLHS, KnownRHS);
37310+
});
37311+
break;
37312+
}
37313+
case Intrinsic::x86_ssse3_phadd_sw:
37314+
case Intrinsic::x86_ssse3_phadd_sw_128:
37315+
case Intrinsic::x86_avx2_phadd_sw: {
37316+
Known = computeKnownBitsForHorizontalOperation(
37317+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37318+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37319+
return KnownBits::sadd_sat(KnownLHS, KnownRHS);
37320+
});
37321+
break;
37322+
}
37323+
case Intrinsic::x86_ssse3_phsub_d:
37324+
case Intrinsic::x86_ssse3_phsub_w:
37325+
case Intrinsic::x86_ssse3_phsub_d_128:
37326+
case Intrinsic::x86_ssse3_phsub_w_128:
37327+
case Intrinsic::x86_avx2_phsub_d:
37328+
case Intrinsic::x86_avx2_phsub_w: {
37329+
Known = computeKnownBitsForHorizontalOperation(
37330+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37331+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37332+
return KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
37333+
/*NUW=*/false, KnownLHS,
37334+
KnownRHS);
37335+
});
37336+
break;
37337+
}
37338+
case Intrinsic::x86_ssse3_phsub_sw:
37339+
case Intrinsic::x86_ssse3_phsub_sw_128:
37340+
case Intrinsic::x86_avx2_phsub_sw: {
37341+
Known = computeKnownBitsForHorizontalOperation(
37342+
Op, DemandedElts, Depth, /*OpIndexStart=*/1, DAG,
37343+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37344+
return KnownBits::ssub_sat(KnownLHS, KnownRHS);
37345+
});
37346+
break;
37347+
}
3727937348
}
3728037349
break;
3728137350
}

0 commit comments

Comments
 (0)