
Commit a398aae

initial commit - for vxi8 shifts, try permute vector to widen shift
1 parent: 9f69da3

1 file changed: +200 -0 lines


llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 200 additions & 0 deletions
@@ -29766,6 +29766,102 @@ static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl,
   return SDValue();
 }

+// Given a vector of values, find a permutation such that every adjacent even-
+// odd pair has the same value. ~0 is reserved as a special value for wildcard,
+// which can be paired with any value. Returns true if a permutation is found.
+template <typename InputTy, typename PermutationTy,
+          typename MapTy = std::unordered_map<
+              typename InputTy::value_type,
+              std::pair<typename InputTy::value_type,
+                        typename PermutationTy::value_type>>>
+static bool PermuteAndPairVector(const InputTy &Inputs,
+                                 PermutationTy &Permutation) {
+  const auto Wildcard = ~typename InputTy::value_type();
+
+  // List of values to be paired, mapping an unpaired value to its current
+  // neighbor's value and index.
+  MapTy UnpairedInputs;
+  SmallVector<typename PermutationTy::value_type, 16> WildcardPairs;
+
+  Permutation.clear();
+  typename PermutationTy::value_type I = 0;
+  for (auto InputIt = Inputs.begin(), InputEnd = Inputs.end();
+       InputIt != InputEnd;) {
+    Permutation.push_back(I);
+    Permutation.push_back(I + 1);
+
+    auto Even = *InputIt++;
+    assert(InputIt != InputEnd && "Expected even number of elements");
+    auto Odd = *InputIt++;
+
+    // If both are wildcards, note it for later use by unpairable values.
+    if (Even == Wildcard && Odd == Wildcard) {
+      WildcardPairs.push_back(I);
+    }
+
+    // If both are equal, they are already in a good position.
+    if (Even != Odd) {
+      auto DoWork = [&](auto &This, auto ThisIndex, auto Other,
+                        auto OtherIndex) {
+        if (This != Wildcard) {
+          // For a non-wildcard value, check if it can pair with an existing
+          // unpaired value from UnpairedInputs. If so, swap it with the
+          // unpaired value's neighbor, otherwise add the current value to
+          // the map.
+          if (auto [MapIt, Inserted] = UnpairedInputs.try_emplace(
+                  This, std::make_pair(Other, OtherIndex));
+              !Inserted) {
+            auto [SwapValue, SwapIndex] = MapIt->second;
+            std::swap(Permutation[SwapIndex], Permutation[ThisIndex]);
+            This = SwapValue;
+            UnpairedInputs.erase(MapIt);
+
+            if (This == Other) {
+              if (This == Wildcard) {
+                // We freed up a wildcard pair by pairing two non-adjacent
+                // values, note it for later use by unpairable values.
+                WildcardPairs.push_back(I);
+              } else {
+                // The swapped element also forms a pair with Other, so it can
+                // be removed from the map.
+                assert(UnpairedInputs.count(This));
+                UnpairedInputs.erase(This);
+              }
+            } else {
+              // Swapped in an unpaired value, update its info.
+              if (This != Wildcard) {
+                assert(UnpairedInputs.count(This));
+                UnpairedInputs[This] = std::make_pair(Other, OtherIndex);
+              }
+              // If its neighbor is also in UnpairedInputs, update its info
+              // too.
+              if (auto OtherMapIt = UnpairedInputs.find(Other);
+                  OtherMapIt != UnpairedInputs.end() &&
+                  OtherMapIt->second.second == ThisIndex) {
+                OtherMapIt->second.first = This;
+              }
+            }
+          }
+        }
+      };
+      DoWork(Even, I, Odd, I + 1);
+      if (Even != Odd) {
+        DoWork(Odd, I + 1, Even, I);
+      }
+    }
+    I += 2;
+  }
+
+  // Now check if each remaining unpaired value can be paired by swapping its
+  // neighbor with a wildcard pair, forming two paired values.
+  for (auto &[Unpaired, V] : UnpairedInputs) {
+    auto [Neighbor, NeighborIndex] = V;
+    if (Neighbor != Wildcard) {
+      assert(UnpairedInputs.count(Neighbor));
+      if (WildcardPairs.size()) {
+        std::swap(Permutation[WildcardPairs.back()],
+                  Permutation[NeighborIndex]);
+        WildcardPairs.pop_back();
+        // Mark the neighbor as processed.
+        UnpairedInputs[Neighbor].first = Wildcard;
+      } else {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
                           SelectionDAG &DAG) {
   MVT VT = Op.getSimpleValueType();
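
The pairing contract of PermuteAndPairVector can be illustrated with a small standalone sketch (not part of the commit; the values and the chosen permutation are made up for illustration, and the shuffle-mask convention result[i] = input[mask[i]] matches the use below). For the amounts <1, 2, 2, 1>, one valid permutation is <0, 3, 2, 1>, which reorders them into <1, 1, 2, 2> so that every even/odd pair holds the same value:

// Standalone illustration of the postcondition of PermuteAndPairVector.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  std::vector<uint16_t> Amounts = {1, 2, 2, 1};
  std::vector<int> Permutation = {0, 3, 2, 1}; // one valid pairing permutation
  std::vector<uint16_t> Reordered(Amounts.size());
  // Apply the permutation with shuffle semantics: result[i] = input[mask[i]].
  for (size_t I = 0; I < Permutation.size(); ++I)
    Reordered[I] = Amounts[Permutation[I]];
  // Every adjacent even/odd pair now holds the same value: <1, 1, 2, 2>.
  for (size_t I = 0; I < Reordered.size(); I += 2)
    assert(Reordered[I] == Reordered[I + 1]);
  return 0;
}
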
@@ -30044,6 +30140,110 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     }
   }

+  // ISD::SRA/SRL/SHL on vXi8 can be widened to vYi16 (Y = X/2) if the constant
+  // amounts can be shuffled such that every pair of adjacent elements has the
+  // same value. This introduces an extra shuffle before and after the shift,
+  // and it is profitable if the operand is already a shuffle so that both can
+  // be merged, or if the extra shuffle is fast (can use VPSHUFB).
+  // (shift (shuffle X P1) S1) ->
+  // (shuffle (shift (shuffle X (shuffle P2 P1)) S2) P2^-1) where S2 can be
+  // widened, and P2^-1 is the inverse shuffle of P2.
+  if (ConstantAmt &&
+      (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) &&
+      R.hasOneUse() && Subtarget.hasSSE3()) {
+    bool Profitable = true;
+    // VPAND ymm only available on AVX2.
+    if (VT == MVT::v32i8 || VT == MVT::v64i8) {
+      Profitable = Subtarget.hasAVX2();
+    }
+
+    SmallVector<int, 64> Permutation;
+    SmallVector<uint16_t, 64> ShiftAmt;
+    for (size_t I = 0; I < Amt.getNumOperands(); ++I) {
+      if (Amt.getOperand(I).isUndef())
+        ShiftAmt.push_back(~0);
+      else
+        ShiftAmt.push_back(Amt.getConstantOperandVal(I));
+    }
+
+    if (Profitable && (VT == MVT::v32i8 || VT == MVT::v64i8)) {
+      Profitable = false;
+      constexpr size_t LaneBytes = 16;
+      const size_t NumLanes = VT.getVectorNumElements() / LaneBytes;
+
+      // For v32i8 or v64i8, we should check if we can generate a shuffle that
+      // may be lowered to VPSHUFB, because it is faster than VPERMB. This is
+      // possible if we can apply the same shuffle mask to each v16i8 lane.
+      // For example (assuming a lane has 4 elements for simplicity),
+      // <1, 2, 2, 1, 4, 3, 3, 4> is handled as <14, 23, 23, 14>, which can
+      // be shuffled to adjacent pairs <14, 14, 23, 23> with the VPSHUFB mask
+      // <0, 3, 2, 1> (or high level mask <0, 3, 2, 1, 4, 7, 6, 5>).
+      // Limitation: if there are undefs among the shift amounts, this
+      // algorithm may not find a solution even if one exists, since here we
+      // only treat a VPSHUFB index as undef if the shift amounts at the same
+      // index modulo the lane size are all undef.
+      // Since a byte can only be shifted by 7 bits without being UB, 4 bits
+      // are enough to represent the shift amount or undef (0xF).
+      std::array<uint16_t, LaneBytes> VPSHUFBShiftAmt = {};
+      for (size_t I = 0; I < LaneBytes; ++I)
+        for (size_t J = 0; J < NumLanes; ++J)
+          VPSHUFBShiftAmt[I] |= (ShiftAmt[I + J * LaneBytes] & 0xF) << (J * 4);
+      if (VT == MVT::v32i8) {
+        for (size_t I = 0; I < LaneBytes; ++I)
+          VPSHUFBShiftAmt[I] |= 0xFF00;
+      }
+      if (PermuteAndPairVector(VPSHUFBShiftAmt, Permutation)) {
+        // Found a VPSHUFB solution, offset the shuffle amount to other lanes.
+        Permutation.resize(VT.getVectorNumElements());
+        for (size_t I = 0; I < LaneBytes; ++I)
+          for (size_t J = 1; J < NumLanes; ++J)
+            Permutation[I + J * LaneBytes] = Permutation[I] + J * LaneBytes;
+        Profitable = true;
+      } else if (R.getOpcode() == ISD::VECTOR_SHUFFLE) {
+        // A slower shuffle is profitable if the operand is also a slow
+        // shuffle, such that they can be merged.
+        // TODO: Use TargetTransformInfo to systematically determine whether
+        // the inner shuffle is slow. Currently we only check if it contains
+        // a cross-lane shuffle.
+        if (ShuffleVectorSDNode *InnerShuffle =
+                dyn_cast<ShuffleVectorSDNode>(R.getNode())) {
+          if (InnerShuffle->getMask().size() == VT.getVectorNumElements() &&
+              is128BitLaneCrossingShuffleMask(VT, InnerShuffle->getMask()))
+            Profitable = true;
+        }
+      }
+    }
+
+    // If it is still profitable at this point, and has not found a permutation
+    // yet, try again with any shuffle index.
+    if (Profitable && Permutation.empty()) {
+      PermuteAndPairVector<decltype(ShiftAmt), decltype(Permutation),
+                           SmallMapVector<uint16_t, std::pair<uint16_t, int>,
+                                          8>>(ShiftAmt, Permutation);
+    }
+
+    // Found a permutation P that can rearrange the shift amounts into adjacent
+    // pairs of the same value. Rewrite the shift S1(x) into P^-1(S2(P(x))).
+    if (!Permutation.empty()) {
+      SDValue InnerShuffle =
+          DAG.getVectorShuffle(VT, dl, R, DAG.getUNDEF(VT), Permutation);
+      SmallVector<SDValue, 64> NewShiftAmt;
+      for (int Index : Permutation) {
+        NewShiftAmt.push_back(Amt.getOperand(Index));
+      }
+#ifndef NDEBUG
+      for (size_t I = 0; I < NewShiftAmt.size(); I += 2) {
+        SDValue Even = NewShiftAmt[I];
+        SDValue Odd = NewShiftAmt[I + 1];
+        assert(Even.isUndef() || Odd.isUndef() ||
+               Even->getAsZExtVal() == Odd->getAsZExtVal());
+      }
+#endif
+      SDValue NewShiftVector = DAG.getBuildVector(VT, dl, NewShiftAmt);
+      SDValue NewShift = DAG.getNode(Opc, dl, VT, InnerShuffle, NewShiftVector);
+      SmallVector<int, 64> InversePermutation(Permutation.size());
+      for (size_t I = 0; I < Permutation.size(); ++I) {
+        InversePermutation[Permutation[I]] = I;
+      }
+      SDValue OuterShuffle = DAG.getVectorShuffle(
+          VT, dl, NewShift, DAG.getUNDEF(VT), InversePermutation);
+      return OuterShuffle;
+    }
+  }
+
   // If possible, lower this packed shift into a vector multiply instead of
   // expanding it into a sequence of scalar shifts.
   // For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
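
The per-lane packing in the hunk above builds one key per in-lane byte index, so a single in-lane (VPSHUFB-style) permutation can be reused for every 128-bit lane. The following standalone sketch is not part of the commit: LaneBytes is shrunk to 4 to match the 4-element example in the comment (the real code uses 16-byte lanes), and the printed values are only illustrative.

// Standalone illustration of packing per-lane shift amounts into 4-bit
// nibbles of a shared per-index key.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr size_t LaneBytes = 4, NumLanes = 2;
  // Shift amounts <1, 2, 2, 1 | 4, 3, 3, 4> across two lanes.
  uint16_t ShiftAmt[LaneBytes * NumLanes] = {1, 2, 2, 1, 4, 3, 3, 4};
  std::array<uint16_t, LaneBytes> Keys = {};
  // Collect the amount for byte index I of every lane into one key.
  for (size_t I = 0; I < LaneBytes; ++I)
    for (size_t J = 0; J < NumLanes; ++J)
      Keys[I] |= (ShiftAmt[I + J * LaneBytes] & 0xF) << (J * 4);
  // Keys == {0x41, 0x32, 0x32, 0x41}; the in-lane permutation <0, 3, 2, 1>
  // pairs them as <0x41, 0x41, 0x32, 0x32>, pairing both lanes at once.
  for (size_t I = 0; I < LaneBytes; ++I)
    std::printf("Keys[%zu] = 0x%02X\n", I, (unsigned)Keys[I]);
  return 0;
}
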

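Finally, here is a standalone sketch of the overall rewrite P^-1(S2(P(x))) on a toy 4-byte vector with a left shift (not part of the commit; the byte values, the permutation, and the mask-based widened shift are assumptions made for illustration, with the AND mask merely standing in for whatever the later vXi16 lowering actually emits):

// Standalone illustration: permute so adjacent amounts match, shift each byte
// pair as one 16-bit element, mask off bits carried from the low byte, then
// undo the permutation; the result equals the original per-byte shifts.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t X[4] = {0x81, 0x7E, 0x33, 0xC4};
  uint8_t Amt[4] = {1, 2, 2, 1}; // original per-byte shift amounts
  int P[4] = {0, 3, 2, 1};       // rearranges the amounts as <1, 1, 2, 2>
  int PInv[4];
  for (int I = 0; I < 4; ++I)
    PInv[P[I]] = I;

  uint8_t Out[4];
  for (int I = 0; I < 4; I += 2) {
    unsigned S = Amt[P[I]]; // equals Amt[P[I + 1]] by construction
    // Pack the permuted byte pair into a little-endian 16-bit element.
    uint16_t Wide = uint16_t(X[P[I]] | (X[P[I + 1]] << 8));
    Wide = uint16_t(Wide << S);
    uint8_t Mask = uint8_t(0xFF << S); // clears bits spilled from the low byte
    Out[I] = uint8_t(Wide & Mask);
    Out[I + 1] = uint8_t((Wide >> 8) & Mask);
  }

  // Apply the inverse permutation and compare against per-byte shifts.
  uint8_t Result[4];
  for (int I = 0; I < 4; ++I)
    Result[I] = Out[PInv[I]];
  for (int I = 0; I < 4; ++I)
    assert(Result[I] == uint8_t((X[I] << Amt[I]) & 0xFF));
  return 0;
}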