Skip to content

Commit cc97653

Browse files
authored
AMDGPU: Custom lower 32-bit element shuffles (#123711)
This is so we can try to make use of v_pk_mov_b32 when available. Note that this currently has little observable effect: the combiner will undo the common extract-of-shuffle pattern. The lack of test changes should demonstrate that this change is minimally correct. We should probably try to make better use of wider extracts in even-aligned cases, but I'm trying to avoid some really ugly regalloc regressions in some MFMA tests. The DAG scheduler ends up doing a worse job if we use vector extracts, resulting in failure to do 3-address conversion of MFMAs.
1 parent 334a1cd commit cc97653

File tree

1 file changed

+80
-5
lines changed

1 file changed

+80
-5
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
419419
}
420420

421421
setOperationAction(ISD::VECTOR_SHUFFLE,
422-
{MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
423-
Expand);
422+
{MVT::v4i32, MVT::v4f32, MVT::v8i32, MVT::v8f32,
423+
MVT::v16i32, MVT::v16f32, MVT::v32i32, MVT::v32f32},
424+
Custom);
424425

425426
if (Subtarget->hasPkMovB32()) {
426427
// TODO: 16-bit element vectors should be legal with even aligned elements.
@@ -7589,15 +7590,38 @@ static bool elementPairIsContiguous(ArrayRef<int> Mask, int Elt) {
75897590
return Mask[Elt + 1] == Mask[Elt] + 1 && (Mask[Elt] % 2 == 0);
75907591
}
75917592

7593+
/// Check whether the pair of shuffle mask entries at positions \p Elt and
/// \p Elt + 1 selects an odd-numbered source lane followed by an
/// even-numbered source lane.
///
/// \p Elt must be even. A negative mask entry in either position makes the
/// pair ineligible.
static bool elementPairIsOddToEven(ArrayRef<int> Mask, int Elt) {
  assert(Elt % 2 == 0);

  // Evaluate the entries in order so the second one is only read once the
  // first is known to be non-negative.
  const int First = Mask[Elt];
  if (First < 0)
    return false;

  const int Second = Mask[Elt + 1];
  if (Second < 0)
    return false;

  const bool FirstIsOdd = (First & 1) != 0;
  const bool SecondIsEven = (Second & 1) == 0;
  return FirstIsOdd && SecondIsEven;
}
7598+
75927599
SDValue SITargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
75937600
SelectionDAG &DAG) const {
75947601
SDLoc SL(Op);
75957602
EVT ResultVT = Op.getValueType();
75967603
ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op);
75977604
MVT EltVT = ResultVT.getVectorElementType().getSimpleVT();
7598-
MVT PackVT = MVT::getVectorVT(EltVT, 2);
7605+
const int NewSrcNumElts = 2;
7606+
MVT PackVT = MVT::getVectorVT(EltVT, NewSrcNumElts);
75997607
int SrcNumElts = Op.getOperand(0).getValueType().getVectorNumElements();
76007608

7609+
// Break up the shuffle into register-sized pieces.
7610+
//
7611+
// We're trying to form sub-shuffles that the register allocation pipeline
7612+
// won't be able to figure out, like how to use v_pk_mov_b32 to do a register
7613+
// blend or 16-bit op_sel. It should be able to figure out how to reassemble a
7614+
// pair of copies into a consecutive register copy, so use the ordinary
7615+
// extract_vector_elt lowering unless we can use the shuffle.
7616+
//
7617+
// TODO: This is a bit of a hack, and we should probably always use
7618+
// extract_subvector for the largest possible subvector we can (or at least
7619+
// use it for PackVT aligned pieces). However, we have worse support for
7620+
// combines on them, and we don't directly treat extract_subvector /
7621+
// insert_subvector as legal. The DAG scheduler also ends up doing a worse
7622+
// job with the extract_subvectors.
7623+
const bool ShouldUseConsecutiveExtract = EltVT.getSizeInBits() == 16;
7624+
76017625
// vector_shuffle <0,1,6,7> lhs, rhs
76027626
// -> concat_vectors (extract_subvector lhs, 0), (extract_subvector rhs, 2)
76037627
//
@@ -7608,16 +7632,67 @@ SDValue SITargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
76087632
// -> concat_vectors (extract_subvector rhs, 2), (extract_subvector lhs, 0)
76097633

76107634
// Avoid scalarizing when both halves are reading from consecutive elements.
7611-
SmallVector<SDValue, 4> Pieces;
7635+
7636+
// If we're treating 2 element shuffles as legal, also create odd-to-even
7637+
// shuffles of neighboring pairs.
7638+
//
7639+
// vector_shuffle <3,2,7,6> lhs, rhs
7640+
// -> concat_vectors vector_shuffle <1, 0> (extract_subvector lhs, 2)
7641+
// vector_shuffle <1, 0> (extract_subvector rhs, 2)
7642+
7643+
SmallVector<SDValue, 16> Pieces;
76127644
for (int I = 0, N = ResultVT.getVectorNumElements(); I != N; I += 2) {
7613-
if (elementPairIsContiguous(SVN->getMask(), I)) {
7645+
if (ShouldUseConsecutiveExtract &&
7646+
elementPairIsContiguous(SVN->getMask(), I)) {
76147647
const int Idx = SVN->getMaskElt(I);
76157648
int VecIdx = Idx < SrcNumElts ? 0 : 1;
76167649
int EltIdx = Idx < SrcNumElts ? Idx : Idx - SrcNumElts;
76177650
SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SL, PackVT,
76187651
SVN->getOperand(VecIdx),
76197652
DAG.getConstant(EltIdx, SL, MVT::i32));
76207653
Pieces.push_back(SubVec);
7654+
} else if (elementPairIsOddToEven(SVN->getMask(), I) &&
7655+
isOperationLegal(ISD::VECTOR_SHUFFLE, PackVT)) {
7656+
int Idx0 = SVN->getMaskElt(I);
7657+
int Idx1 = SVN->getMaskElt(I + 1);
7658+
7659+
SDValue SrcOp0 = SVN->getOperand(0);
7660+
SDValue SrcOp1 = SrcOp0;
7661+
if (Idx0 >= SrcNumElts) {
7662+
SrcOp0 = SVN->getOperand(1);
7663+
Idx0 -= SrcNumElts;
7664+
}
7665+
7666+
if (Idx1 >= SrcNumElts) {
7667+
SrcOp1 = SVN->getOperand(1);
7668+
Idx1 -= SrcNumElts;
7669+
}
7670+
7671+
int AlignedIdx0 = Idx0 & ~(NewSrcNumElts - 1);
7672+
int AlignedIdx1 = Idx1 & ~(NewSrcNumElts - 1);
7673+
7674+
// Extract nearest even aligned piece.
7675+
SDValue SubVec0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SL, PackVT, SrcOp0,
7676+
DAG.getConstant(AlignedIdx0, SL, MVT::i32));
7677+
SDValue SubVec1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SL, PackVT, SrcOp1,
7678+
DAG.getConstant(AlignedIdx1, SL, MVT::i32));
7679+
7680+
int NewMaskIdx0 = Idx0 - AlignedIdx0;
7681+
int NewMaskIdx1 = Idx1 - AlignedIdx1;
7682+
7683+
SDValue Result0 = SubVec0;
7684+
SDValue Result1 = SubVec0;
7685+
7686+
if (SubVec0 != SubVec1) {
7687+
NewMaskIdx1 += NewSrcNumElts;
7688+
Result1 = SubVec1;
7689+
} else {
7690+
Result1 = DAG.getUNDEF(PackVT);
7691+
}
7692+
7693+
SDValue Shuf = DAG.getVectorShuffle(PackVT, SL, Result0, Result1,
7694+
{NewMaskIdx0, NewMaskIdx1});
7695+
Pieces.push_back(Shuf);
76217696
} else {
76227697
const int Idx0 = SVN->getMaskElt(I);
76237698
const int Idx1 = SVN->getMaskElt(I + 1);

0 commit comments

Comments
 (0)