Skip to content

Commit b64f2ea

Browse files
committed
[ValueTracking][X86] Compute KnownBits for phadd/phsub
Add KnownBits computations for phadd/phsub to ValueTracking and X86 DAG lowering. These instructions add/subtract adjacent vector elements in their operands. Example: phadd [X1, X2] [Y1, Y2] = [X1 + X2, Y1 + Y2]. This means that, in this example, we can compute the KnownBits of the operation by computing the KnownBits of X1 + X2 (adjacent elements of the first operand) and of Y1 + Y2 (adjacent elements of the second operand) and intersecting the results. This approach also generalizes to all x86 vector types. There are also the operations phadd.sw and phsub.sw, which perform saturating addition/subtraction. Use sadd_sat and ssub_sat to compute the KnownBits of these operations in ValueTracking. Also adjust the existing test case pr53247.ll because it can be transformed to a constant using the new KnownBits computation. Fixes #82516.
1 parent 95724e9 commit b64f2ea

File tree

8 files changed

+198
-158
lines changed

8 files changed

+198
-158
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,23 @@ void processShuffleMasks(
246246
function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
247247
function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
248248

249+
/// Compute the demanded elements mask of horizontal binary operations. A
250+
/// horizontal operation combines two adjacent elements in a vector operand.
251+
/// This function computes masks for the elements that correspond to the first
252+
/// operand of this horizontal combination. For example, for two vectors
253+
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting masks can include the
254+
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
255+
/// resulting masks to the left by 1.
256+
///
/// Nothing is returned directly: the masks are written to \p DemandedLHS and
/// \p DemandedRHS, and any previous contents of those two are overwritten.
/// The vector is treated as a sequence of 128-bit lanes; within each lane the
/// lower half of the result elements maps to the left operand and the upper
/// half to the right operand.
///
257+
/// \param VectorBitWidth the total bit width of the vector; must be at least
///        128 (asserted by the definition)
258+
/// \param DemandedElts the demanded elements mask for the operation; its bit
///        width is used as the number of vector elements
259+
/// \param[out] DemandedLHS the demanded elements mask for the left operand
260+
/// \param[out] DemandedRHS the demanded elements mask for the right operand
261+
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
262+
const APInt &DemandedElts,
263+
APInt &DemandedLHS,
264+
APInt &DemandedRHS);
265+
249266
/// Compute a map of integer instructions to their minimum legal type
250267
/// size.
251268
///

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,32 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
950950
return KnownOut;
951951
}
952952

953+
/// Compute the known bits of the horizontal binary operation \p I for the
/// result elements selected by \p DemandedElts.
///
/// For each operand of \p I that contributes demanded result elements, the
/// known bits of the first element of every demanded adjacent pair and the
/// known bits of the second element (same mask shifted left by one) are
/// combined with \p KnownBitsFunc. When both operands contribute, the two
/// per-operand results are merged with KnownBits::intersectWith.
///
/// \param I the horizontal operation; operands 0 and 1 are its vector inputs
/// \param DemandedElts the demanded elements mask of the result
/// \param Depth the current recursion depth; operand queries use Depth + 1
/// \param Q the query context forwarded to computeKnownBits
/// \param KnownBitsFunc the element-wise combiner, called as
///        KnownBitsFunc(first-element bits, second-element bits)
static KnownBits computeKnownBitsForHorizontalOperation(
954+
const Operator *I, const APInt &DemandedElts, unsigned Depth,
955+
const SimplifyQuery &Q,
956+
const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
957+
KnownBitsFunc) {
958+
APInt DemandedEltsLHS, DemandedEltsRHS;
959+
// Translate the demanded result elements into, per operand, the mask of the
// first element of each adjacent pair that feeds a demanded result element.
getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
960+
DemandedElts, DemandedEltsLHS,
961+
DemandedEltsRHS);
962+
963+
// Combines, for a single operand, the known bits of the first elements of
// the demanded pairs with those of the second elements (mask << 1).
const auto ComputeForSingleOpFunc =
964+
[Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
965+
return KnownBitsFunc(
966+
computeKnownBits(Op, DemandedEltsOp, Depth + 1, Q),
967+
computeKnownBits(Op, DemandedEltsOp << 1, Depth + 1, Q));
968+
};
969+
970+
// Nothing is demanded from the right operand: only the left one matters.
if (DemandedEltsRHS.isZero())
971+
return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS);
972+
// Nothing is demanded from the left operand: only the right one matters.
if (DemandedEltsLHS.isZero())
973+
return ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS);
974+
975+
// Both operands feed demanded result elements, so merge both results.
return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS)
976+
.intersectWith(ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS));
977+
}
978+
953979
// Public so this can be used in `SimplifyDemandedUseBits`.
954980
KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
955981
const KnownBits &KnownLHS,
@@ -1725,6 +1751,44 @@ static void computeKnownBitsFromOperator(const Operator *I,
17251751
case Intrinsic::x86_sse42_crc32_64_64:
17261752
Known.Zero.setBitsFrom(32);
17271753
break;
1754+
case Intrinsic::x86_ssse3_phadd_d_128:
1755+
case Intrinsic::x86_ssse3_phadd_w_128:
1756+
case Intrinsic::x86_avx2_phadd_d:
1757+
case Intrinsic::x86_avx2_phadd_w: {
1758+
Known = computeKnownBitsForHorizontalOperation(
1759+
I, DemandedElts, Depth, Q,
1760+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1761+
return KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
1762+
/*NUW=*/false, KnownLHS,
1763+
KnownRHS);
1764+
});
1765+
break;
1766+
}
1767+
case Intrinsic::x86_ssse3_phadd_sw_128:
1768+
case Intrinsic::x86_avx2_phadd_sw: {
1769+
Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
1770+
Q, KnownBits::sadd_sat);
1771+
break;
1772+
}
1773+
case Intrinsic::x86_ssse3_phsub_d_128:
1774+
case Intrinsic::x86_ssse3_phsub_w_128:
1775+
case Intrinsic::x86_avx2_phsub_d:
1776+
case Intrinsic::x86_avx2_phsub_w: {
1777+
Known = computeKnownBitsForHorizontalOperation(
1778+
I, DemandedElts, Depth, Q,
1779+
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1780+
return KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
1781+
/*NUW=*/false, KnownLHS,
1782+
KnownRHS);
1783+
});
1784+
break;
1785+
}
1786+
case Intrinsic::x86_ssse3_phsub_sw_128:
1787+
case Intrinsic::x86_avx2_phsub_sw: {
1788+
Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
1789+
Q, KnownBits::ssub_sat);
1790+
break;
1791+
}
17281792
case Intrinsic::riscv_vsetvli:
17291793
case Intrinsic::riscv_vsetvlimax: {
17301794
bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,34 @@ void llvm::processShuffleMasks(
541541
}
542542
}
543543

544+
// Maps the demanded result elements of a horizontal binary operation onto its
// two operands. For every demanded result element, the bit of the FIRST
// element of the adjacent operand pair that produces it is set in either
// DemandedLHS or DemandedRHS. Both output masks are reset to zero first, so
// any previous contents are discarded. Callers that also need the second
// element of each pair can shift the resulting masks left by one.
void llvm::getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
545+
const APInt &DemandedElts,
546+
APInt &DemandedLHS,
547+
APInt &DemandedRHS) {
548+
assert(VectorBitWidth >= 128 && "Vectors smaller than 128 bit not supported");
549+
// The vector is processed as independent 128-bit lanes.
int NumLanes = VectorBitWidth / 128;
550+
// One mask bit per vector element.
int NumElts = DemandedElts.getBitWidth();
551+
int NumEltsPerLane = NumElts / NumLanes;
552+
// Within a lane, the lower half of the result elements comes from the left
// operand and the upper half from the right operand.
int HalfEltsPerLane = NumEltsPerLane / 2;
553+
554+
DemandedLHS = APInt::getZero(NumElts);
555+
DemandedRHS = APInt::getZero(NumElts);
556+
557+
// Map DemandedElts to the horizontal operands.
558+
for (int Idx = 0; Idx != NumElts; ++Idx) {
559+
if (!DemandedElts[Idx])
560+
continue;
561+
// Index of the first element of the lane containing Idx (not a lane number).
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
562+
// Position of Idx inside its lane.
int LocalIdx = Idx % NumEltsPerLane;
563+
if (LocalIdx < HalfEltsPerLane) {
564+
// Result element LocalIdx is built from the left operand's pair starting at
// 2 * LocalIdx within the same lane.
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx);
565+
} else {
566+
// Upper half of the lane: rebase the index and take the pair from the right
// operand instead.
LocalIdx -= HalfEltsPerLane;
567+
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx);
568+
}
569+
}
570+
}
571+
544572
MapVector<Instruction *, uint64_t>
545573
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
546574
const TargetTransformInfo *TTI) {

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5180,29 +5180,10 @@ static void getPackDemandedElts(EVT VT, const APInt &DemandedElts,
51805180
// Split the demanded elts of a HADD/HSUB node between its operands.
// DemandedLHS and DemandedRHS are outputs; for every demanded result element
// they receive both elements of the adjacent operand pair that produces it.
51815181
static void getHorizDemandedElts(EVT VT, const APInt &DemandedElts,
51825182
APInt &DemandedLHS, APInt &DemandedRHS) {
5183-
int NumLanes = VT.getSizeInBits() / 128;
5184-
int NumElts = DemandedElts.getBitWidth();
5185-
int NumEltsPerLane = NumElts / NumLanes;
5186-
int HalfEltsPerLane = NumEltsPerLane / 2;
5187-
5188-
DemandedLHS = APInt::getZero(NumElts);
5189-
DemandedRHS = APInt::getZero(NumElts);
5190-
5191-
// Map DemandedElts to the horizontal operands.
5192-
for (int Idx = 0; Idx != NumElts; ++Idx) {
5193-
if (!DemandedElts[Idx])
5194-
continue;
5195-
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
5196-
int LocalIdx = Idx % NumEltsPerLane;
5197-
if (LocalIdx < HalfEltsPerLane) {
5198-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5199-
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5200-
} else {
5201-
LocalIdx -= HalfEltsPerLane;
5202-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 0);
5203-
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 1);
5204-
}
5205-
}
5183+
// The shared helper only marks the first element of each demanded pair.
getHorizDemandedEltsForFirstOperand(VT.getSizeInBits(), DemandedElts,
5184+
DemandedLHS, DemandedRHS);
5185+
// OR in the mask shifted left by one so the second element of each pair is
// demanded as well.
DemandedLHS |= DemandedLHS << 1;
5186+
DemandedRHS |= DemandedRHS << 1;
52065187
}
52075188

52085189
/// Calculates the shuffle mask corresponding to the target-specific opcode.
@@ -36953,6 +36934,32 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
3695336934
Known = Known.zext(64);
3695436935
}
3695536936

36937+
/// Compute the known bits of the horizontal binary node \p Op for the result
/// elements selected by \p DemandedElts.
///
/// For each operand of \p Op that contributes demanded result elements, the
/// known bits of the first element of every demanded adjacent pair and the
/// known bits of the second element (same mask shifted left by one) are
/// combined with \p KnownBitsFunc. When both operands contribute, the two
/// per-operand results are merged with KnownBits::intersectWith.
///
/// \param Op the horizontal node; operands 0 and 1 are its vector inputs
/// \param DemandedElts the demanded elements mask of the result
/// \param Depth the current recursion depth; operand queries use Depth + 1
/// \param DAG the SelectionDAG used to query the operands' known bits
/// \param KnownBitsFunc the element-wise combiner, called as
///        KnownBitsFunc(first-element bits, second-element bits)
static KnownBits computeKnownBitsForHorizontalOperation(
36938+
const SDValue Op, const APInt &DemandedElts, unsigned Depth,
36939+
const SelectionDAG &DAG,
36940+
const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
36941+
KnownBitsFunc) {
36942+
APInt DemandedEltsLHS, DemandedEltsRHS;
36943+
// Translate the demanded result elements into, per operand, the mask of the
// first element of each adjacent pair that feeds a demanded result element.
getHorizDemandedEltsForFirstOperand(Op.getValueType().getSizeInBits(),
36944+
DemandedElts, DemandedEltsLHS,
36945+
DemandedEltsRHS);
36946+
36947+
// Combines, for a single operand, the known bits of the first elements of
// the demanded pairs with those of the second elements (mask << 1).
const auto ComputeForSingleOpFunc =
36948+
[&DAG, Depth, KnownBitsFunc](SDValue Op, APInt &DemandedEltsOp) {
36949+
return KnownBitsFunc(
36950+
DAG.computeKnownBits(Op, DemandedEltsOp, Depth + 1),
36951+
DAG.computeKnownBits(Op, DemandedEltsOp << 1, Depth + 1));
36952+
};
36953+
36954+
// Nothing is demanded from the right operand: only the left one matters.
if (DemandedEltsRHS.isZero())
36955+
return ComputeForSingleOpFunc(Op.getOperand(0), DemandedEltsLHS);
36956+
// Nothing is demanded from the left operand: only the right one matters.
if (DemandedEltsLHS.isZero())
36957+
return ComputeForSingleOpFunc(Op.getOperand(1), DemandedEltsRHS);
36958+
36959+
// Both operands feed demanded result elements, so merge both results.
return ComputeForSingleOpFunc(Op.getOperand(0), DemandedEltsLHS)
36960+
.intersectWith(ComputeForSingleOpFunc(Op.getOperand(1), DemandedEltsRHS));
36961+
}
36962+
3695636963
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3695736964
KnownBits &Known,
3695836965
const APInt &DemandedElts,
@@ -37262,6 +37269,17 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3726237269
}
3726337270
break;
3726437271
}
37272+
case X86ISD::HADD:
37273+
case X86ISD::HSUB: {
37274+
Known = computeKnownBitsForHorizontalOperation(
37275+
Op, DemandedElts, Depth, DAG,
37276+
[Opc](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37277+
return KnownBits::computeForAddSub(
37278+
/*Add=*/Opc == X86ISD::HADD, /*NSW=*/false, /*NUW=*/false,
37279+
KnownLHS, KnownRHS);
37280+
});
37281+
break;
37282+
}
3726537283
case ISD::INTRINSIC_WO_CHAIN: {
3726637284
switch (Op->getConstantOperandVal(0)) {
3726737285
case Intrinsic::x86_sse2_psad_bw:

0 commit comments

Comments
 (0)