Skip to content

[ValueTracking][X86] Compute KnownBits for phadd/phsub #92429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions llvm/include/llvm/Analysis/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,23 @@ void processShuffleMasks(
function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);

/// Compute the demanded elements mask of horizontal binary operations. A
/// horizontal operation combines two adjacent elements in a vector operand.
/// This function returns a mask for the elements that correspond to the first
/// operand of this horizontal combination. For example, for two vectors
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
/// result of this function to the left by 1.
///
/// \param VectorBitWidth the total bit width of the vector (must be at least
///        128 bits, i.e. one full lane)
/// \param DemandedElts the demanded elements mask for the operation
/// \param DemandedLHS [out] the demanded elements mask for the left operand
/// \param DemandedRHS [out] the demanded elements mask for the right operand
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
                                         const APInt &DemandedElts,
                                         APInt &DemandedLHS,
                                         APInt &DemandedRHS);

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
Expand Down
64 changes: 64 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,32 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
return KnownOut;
}

/// Compute known bits for a horizontal binary operation (e.g. x86 phadd/phsub)
/// by combining, via \p KnownBitsFunc, the known bits of each pair of adjacent
/// source elements. When all demanded result elements come from a single
/// source operand, the other operand is not analyzed at all.
static KnownBits computeKnownBitsForHorizontalOperation(
    const Operator *I, const APInt &DemandedElts, unsigned Depth,
    const SimplifyQuery &Q,
    const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
        KnownBitsFunc) {
  APInt DemandedEltsLHS, DemandedEltsRHS;
  getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
                                      DemandedElts, DemandedEltsLHS,
                                      DemandedEltsRHS);

  // Fold together the known bits of the first (even-offset) and second
  // (odd-offset) elements of each demanded pair within a single operand.
  const auto KnownForOperand = [Depth, &Q, KnownBitsFunc](
                                   const Value *V, const APInt &FirstElts) {
    KnownBits First = computeKnownBits(V, FirstElts, Depth + 1, Q);
    KnownBits Second = computeKnownBits(V, FirstElts << 1, Depth + 1, Q);
    return KnownBitsFunc(First, Second);
  };

  // Only one operand contributes demanded elements: analyze just that one.
  if (DemandedEltsRHS.isZero())
    return KnownForOperand(I->getOperand(0), DemandedEltsLHS);
  if (DemandedEltsLHS.isZero())
    return KnownForOperand(I->getOperand(1), DemandedEltsRHS);

  // Both operands contribute: take what is known in common.
  KnownBits KnownLHS = KnownForOperand(I->getOperand(0), DemandedEltsLHS);
  KnownBits KnownRHS = KnownForOperand(I->getOperand(1), DemandedEltsRHS);
  return KnownLHS.intersectWith(KnownRHS);
}

// Public so this can be used in `SimplifyDemandedUseBits`.
KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
const KnownBits &KnownLHS,
Expand Down Expand Up @@ -1756,6 +1782,44 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Intrinsic::x86_sse42_crc32_64_64:
Known.Zero.setBitsFrom(32);
break;
case Intrinsic::x86_ssse3_phadd_d_128:
case Intrinsic::x86_ssse3_phadd_w_128:
case Intrinsic::x86_avx2_phadd_d:
case Intrinsic::x86_avx2_phadd_w: {
Known = computeKnownBitsForHorizontalOperation(
I, DemandedElts, Depth, Q,
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
return KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
/*NUW=*/false, KnownLHS,
KnownRHS);
});
break;
}
case Intrinsic::x86_ssse3_phadd_sw_128:
case Intrinsic::x86_avx2_phadd_sw: {
Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
Q, KnownBits::sadd_sat);
break;
}
case Intrinsic::x86_ssse3_phsub_d_128:
case Intrinsic::x86_ssse3_phsub_w_128:
case Intrinsic::x86_avx2_phsub_d:
case Intrinsic::x86_avx2_phsub_w: {
Known = computeKnownBitsForHorizontalOperation(
I, DemandedElts, Depth, Q,
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
return KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
/*NUW=*/false, KnownLHS,
KnownRHS);
});
break;
}
case Intrinsic::x86_ssse3_phsub_sw_128:
case Intrinsic::x86_avx2_phsub_sw: {
Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Depth,
Q, KnownBits::ssub_sat);
break;
}
case Intrinsic::riscv_vsetvli:
case Intrinsic::riscv_vsetvlimax: {
bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;
Expand Down
28 changes: 28 additions & 0 deletions llvm/lib/Analysis/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,34 @@ void llvm::processShuffleMasks(
}
}

void llvm::getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
                                               const APInt &DemandedElts,
                                               APInt &DemandedLHS,
                                               APInt &DemandedRHS) {
  assert(VectorBitWidth >= 128 && "Vectors smaller than 128 bit not supported");
  const int NumElts = DemandedElts.getBitWidth();
  const int NumLanes = VectorBitWidth / 128;
  const int NumEltsPerLane = NumElts / NumLanes;
  const int HalfEltsPerLane = NumEltsPerLane / 2;

  DemandedLHS = APInt::getZero(NumElts);
  DemandedRHS = APInt::getZero(NumElts);

  // Walk each 128-bit lane separately: horizontal ops never cross lanes.
  for (int Lane = 0; Lane != NumLanes; ++Lane) {
    const int LaneBase = Lane * NumEltsPerLane;
    for (int Off = 0; Off != NumEltsPerLane; ++Off) {
      if (!DemandedElts[LaneBase + Off])
        continue;
      // Within a lane, the first half of the results reads from the LHS
      // operand and the second half from the RHS operand; each result's
      // first source element sits at an even offset inside that operand.
      if (Off < HalfEltsPerLane)
        DemandedLHS.setBit(LaneBase + 2 * Off);
      else
        DemandedRHS.setBit(LaneBase + 2 * (Off - HalfEltsPerLane));
    }
  }
}

MapVector<Instruction *, uint64_t>
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
const TargetTransformInfo *TTI) {
Expand Down
64 changes: 41 additions & 23 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5204,29 +5204,10 @@ static void getPackDemandedElts(EVT VT, const APInt &DemandedElts,
// Split the demanded elts of a HADD/HSUB node between its operands.
static void getHorizDemandedElts(EVT VT, const APInt &DemandedElts,
APInt &DemandedLHS, APInt &DemandedRHS) {
int NumLanes = VT.getSizeInBits() / 128;
int NumElts = DemandedElts.getBitWidth();
int NumEltsPerLane = NumElts / NumLanes;
int HalfEltsPerLane = NumEltsPerLane / 2;

DemandedLHS = APInt::getZero(NumElts);
DemandedRHS = APInt::getZero(NumElts);

// Map DemandedElts to the horizontal operands.
for (int Idx = 0; Idx != NumElts; ++Idx) {
if (!DemandedElts[Idx])
continue;
int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
int LocalIdx = Idx % NumEltsPerLane;
if (LocalIdx < HalfEltsPerLane) {
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 0);
DemandedLHS.setBit(LaneIdx + 2 * LocalIdx + 1);
} else {
LocalIdx -= HalfEltsPerLane;
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 0);
DemandedRHS.setBit(LaneIdx + 2 * LocalIdx + 1);
}
}
getHorizDemandedEltsForFirstOperand(VT.getSizeInBits(), DemandedElts,
DemandedLHS, DemandedRHS);
DemandedLHS |= DemandedLHS << 1;
DemandedRHS |= DemandedRHS << 1;
}

/// Calculates the shuffle mask corresponding to the target-specific opcode.
Expand Down Expand Up @@ -37174,6 +37155,32 @@ static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
Known = KnownBits::sadd_sat(Lo, Hi);
}

/// Compute known bits for an X86 horizontal add/sub node by combining, via
/// \p KnownBitsFunc, the known bits of the two adjacent source elements that
/// feed each demanded result element. If only one source operand is demanded,
/// the other operand is skipped entirely.
static KnownBits computeKnownBitsForHorizontalOperation(
    const SDValue Op, const APInt &DemandedElts, unsigned Depth,
    const SelectionDAG &DAG,
    const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
        KnownBitsFunc) {
  APInt DemandedEltsLHS, DemandedEltsRHS;
  getHorizDemandedEltsForFirstOperand(Op.getValueType().getSizeInBits(),
                                      DemandedElts, DemandedEltsLHS,
                                      DemandedEltsRHS);

  // Combine the known bits of the even-offset (first) and odd-offset (second)
  // source elements of a single operand.
  const auto KnownForOp = [&DAG, Depth, KnownBitsFunc](
                              SDValue V, const APInt &FirstElts) {
    KnownBits First = DAG.computeKnownBits(V, FirstElts, Depth + 1);
    KnownBits Second = DAG.computeKnownBits(V, FirstElts << 1, Depth + 1);
    return KnownBitsFunc(First, Second);
  };

  // Only one operand contributes demanded elements: analyze just that one.
  if (DemandedEltsRHS.isZero())
    return KnownForOp(Op.getOperand(0), DemandedEltsLHS);
  if (DemandedEltsLHS.isZero())
    return KnownForOp(Op.getOperand(1), DemandedEltsRHS);

  // Both operands contribute: keep only the bits known in common.
  KnownBits Known0 = KnownForOp(Op.getOperand(0), DemandedEltsLHS);
  KnownBits Known1 = KnownForOp(Op.getOperand(1), DemandedEltsRHS);
  return Known0.intersectWith(Known1);
}

void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
KnownBits &Known,
const APInt &DemandedElts,
Expand Down Expand Up @@ -37503,6 +37510,17 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
}
break;
}
case X86ISD::HADD:
case X86ISD::HSUB: {
Known = computeKnownBitsForHorizontalOperation(
Op, DemandedElts, Depth, DAG,
[Opc](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
return KnownBits::computeForAddSub(
/*Add=*/Opc == X86ISD::HADD, /*NSW=*/false, /*NUW=*/false,
KnownLHS, KnownRHS);
});
break;
}
case ISD::INTRINSIC_WO_CHAIN: {
switch (Op->getConstantOperandVal(0)) {
case Intrinsic::x86_sse2_pmadd_wd:
Expand Down
Loading
Loading