Skip to content

Commit 77f0488

Browse files
authored
[AArch64] Combine zext of deinterleaving shuffle. (#107201)
This is part 1 of a few patches that are intended to take deinterleaving shuffles with masks like `[0,4,8,12]`, where the shuffle is zero-extended to a larger size, and optimize away the deinterleave. In this case it converts them to `and(uzp1, mask)`, where the `uzp1` acts upon the elements in the larger type size to get the lanes into the correct positions, and the `and` performs the zext. It performs the combine fairly late, on the legalized type, so that uitofp nodes that are converted to uitofp(zext(..)) will also be handled.
1 parent 0c1500e commit 77f0488

File tree

2 files changed

+163
-179
lines changed

2 files changed

+163
-179
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22187,6 +22187,59 @@ performSignExtendSetCCCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
2218722187
return SDValue();
2218822188
}
2218922189

22190+
// Convert zext(extract(shuffle a, b, [0,4,8,12])) -> and(uzp1(a, b), 255)
22191+
// This comes from interleaved vectorization. It is performed late to capture
22192+
// uitofp converts too.
22193+
static SDValue performZExtDeinterleaveShuffleCombine(SDNode *N,
22194+
SelectionDAG &DAG) {
22195+
EVT VT = N->getValueType(0);
22196+
if ((VT != MVT::v4i32 && VT != MVT::v8i16) ||
22197+
N->getOpcode() != ISD::ZERO_EXTEND ||
22198+
N->getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
22199+
return SDValue();
22200+
22201+
unsigned ExtOffset = N->getOperand(0).getConstantOperandVal(1);
22202+
if (ExtOffset != 0 && ExtOffset != VT.getVectorNumElements())
22203+
return SDValue();
22204+
22205+
EVT InVT = N->getOperand(0).getOperand(0).getValueType();
22206+
auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0).getOperand(0));
22207+
if (!Shuffle ||
22208+
InVT.getVectorNumElements() != VT.getVectorNumElements() * 2 ||
22209+
InVT.getScalarSizeInBits() * 2 != VT.getScalarSizeInBits())
22210+
return SDValue();
22211+
22212+
unsigned Idx;
22213+
bool IsDeInterleave = ShuffleVectorInst::isDeInterleaveMaskOfFactor(
22214+
Shuffle->getMask().slice(ExtOffset, VT.getVectorNumElements()), 4, Idx);
22215+
// An undef interleave shuffle can come up after other canonicalizations,
22216+
// where the shuffle has been converted to
22217+
// zext(extract(shuffle b, undef, [u,u,0,4]))
22218+
bool IsUndefDeInterleave = false;
22219+
if (!IsDeInterleave)
22220+
IsUndefDeInterleave =
22221+
Shuffle->getOperand(1).isUndef() &&
22222+
ShuffleVectorInst::isDeInterleaveMaskOfFactor(
22223+
Shuffle->getMask().slice(ExtOffset + VT.getVectorNumElements() / 2,
22224+
VT.getVectorNumElements() / 2),
22225+
4, Idx);
22226+
if ((!IsDeInterleave && !IsUndefDeInterleave) || Idx >= 4)
22227+
return SDValue();
22228+
SDLoc DL(N);
22229+
SDValue BC1 = DAG.getNode(AArch64ISD::NVCAST, DL, VT,
22230+
Shuffle->getOperand(IsUndefDeInterleave ? 1 : 0));
22231+
SDValue BC2 = DAG.getNode(AArch64ISD::NVCAST, DL, VT,
22232+
Shuffle->getOperand(IsUndefDeInterleave ? 0 : 1));
22233+
SDValue UZP = DAG.getNode(Idx < 2 ? AArch64ISD::UZP1 : AArch64ISD::UZP2, DL,
22234+
VT, BC1, BC2);
22235+
if ((Idx & 1) == 1)
22236+
UZP = DAG.getNode(ISD::SRL, DL, VT, UZP,
22237+
DAG.getConstant(InVT.getScalarSizeInBits(), DL, VT));
22238+
return DAG.getNode(
22239+
ISD::AND, DL, VT, UZP,
22240+
DAG.getConstant((1 << InVT.getScalarSizeInBits()) - 1, DL, VT));
22241+
}
22242+
2219022243
static SDValue performExtendCombine(SDNode *N,
2219122244
TargetLowering::DAGCombinerInfo &DCI,
2219222245
SelectionDAG &DAG) {
@@ -22207,6 +22260,9 @@ static SDValue performExtendCombine(SDNode *N,
2220722260
return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD);
2220822261
}
2220922262

22263+
if (SDValue R = performZExtDeinterleaveShuffleCombine(N, DAG))
22264+
return R;
22265+
2221022266
if (N->getValueType(0).isFixedLengthVector() &&
2221122267
N->getOpcode() == ISD::SIGN_EXTEND &&
2221222268
N->getOperand(0)->getOpcode() == ISD::SETCC)

0 commit comments

Comments
 (0)