Skip to content

Commit 363ec6f

Browse files
committed
[AArch64][GlobalISel] Share some shuffle mask functions between SDAG and GlobalISel.
This removes the GISel versions of isREVMask, isTRNMask, isUZPMask and isZipMask. They are combined with the existing versions from SDAG into AArch64PerfectShuffle.h.
1 parent e4d2427 commit 363ec6f

File tree

7 files changed

+117
-244
lines changed

7 files changed

+117
-244
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 32 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -11880,47 +11880,6 @@ static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
1188011880
return true;
1188111881
}
1188211882

11883-
/// isREVMask - Check if a vector shuffle corresponds to a REV
11884-
/// instruction with the specified blocksize. (The order of the elements
11885-
/// within each block of the vector is reversed.)
11886-
static bool isREVMask(ArrayRef<int> M, EVT VT, unsigned BlockSize) {
11887-
assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64 ||
11888-
BlockSize == 128) &&
11889-
"Only possible block sizes for REV are: 16, 32, 64, 128");
11890-
11891-
unsigned EltSz = VT.getScalarSizeInBits();
11892-
unsigned NumElts = VT.getVectorNumElements();
11893-
unsigned BlockElts = M[0] + 1;
11894-
// If the first shuffle index is UNDEF, be optimistic.
11895-
if (M[0] < 0)
11896-
BlockElts = BlockSize / EltSz;
11897-
11898-
if (BlockSize <= EltSz || BlockSize != BlockElts * EltSz)
11899-
return false;
11900-
11901-
for (unsigned i = 0; i < NumElts; ++i) {
11902-
if (M[i] < 0)
11903-
continue; // ignore UNDEF indices
11904-
if ((unsigned)M[i] != (i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
11905-
return false;
11906-
}
11907-
11908-
return true;
11909-
}
11910-
11911-
static bool isTRNMask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
11912-
unsigned NumElts = VT.getVectorNumElements();
11913-
if (NumElts % 2 != 0)
11914-
return false;
11915-
WhichResult = (M[0] == 0 ? 0 : 1);
11916-
for (unsigned i = 0; i < NumElts; i += 2) {
11917-
if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) ||
11918-
(M[i + 1] >= 0 && (unsigned)M[i + 1] != i + NumElts + WhichResult))
11919-
return false;
11920-
}
11921-
return true;
11922-
}
11923-
1192411883
/// isZIP_v_undef_Mask - Special case of isZIPMask for canonical form of
1192511884
/// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
1192611885
/// Mask is e.g., <0, 0, 1, 1> instead of <0, 4, 1, 5>.
@@ -12585,15 +12544,16 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
1258512544
}
1258612545
}
1258712546

12588-
if (isREVMask(ShuffleMask, VT, 64))
12547+
unsigned NumElts = VT.getVectorNumElements();
12548+
unsigned EltSize = VT.getScalarSizeInBits();
12549+
if (isREVMask(ShuffleMask, EltSize, NumElts, 64))
1258912550
return DAG.getNode(AArch64ISD::REV64, dl, V1.getValueType(), V1, V2);
12590-
if (isREVMask(ShuffleMask, VT, 32))
12551+
if (isREVMask(ShuffleMask, EltSize, NumElts, 32))
1259112552
return DAG.getNode(AArch64ISD::REV32, dl, V1.getValueType(), V1, V2);
12592-
if (isREVMask(ShuffleMask, VT, 16))
12553+
if (isREVMask(ShuffleMask, EltSize, NumElts, 16))
1259312554
return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1, V2);
1259412555

12595-
if (((VT.getVectorNumElements() == 8 && VT.getScalarSizeInBits() == 16) ||
12596-
(VT.getVectorNumElements() == 16 && VT.getScalarSizeInBits() == 8)) &&
12556+
if (((NumElts == 8 && EltSize == 16) || (NumElts == 16 && EltSize == 8)) &&
1259712557
ShuffleVectorInst::isReverseMask(ShuffleMask, ShuffleMask.size())) {
1259812558
SDValue Rev = DAG.getNode(AArch64ISD::REV64, dl, VT, V1);
1259912559
return DAG.getNode(AArch64ISD::EXT, dl, VT, Rev, Rev,
@@ -12615,15 +12575,15 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
1261512575
}
1261612576

1261712577
unsigned WhichResult;
12618-
if (isZIPMask(ShuffleMask, VT, WhichResult)) {
12578+
if (isZIPMask(ShuffleMask, NumElts, WhichResult)) {
1261912579
unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
1262012580
return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
1262112581
}
12622-
if (isUZPMask(ShuffleMask, VT, WhichResult)) {
12582+
if (isUZPMask(ShuffleMask, NumElts, WhichResult)) {
1262312583
unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
1262412584
return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
1262512585
}
12626-
if (isTRNMask(ShuffleMask, VT, WhichResult)) {
12586+
if (isTRNMask(ShuffleMask, NumElts, WhichResult)) {
1262712587
unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
1262812588
return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
1262912589
}
@@ -12655,7 +12615,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
1265512615
int SrcLane = ShuffleMask[Anomaly];
1265612616
if (SrcLane >= NumInputElements) {
1265712617
SrcVec = V2;
12658-
SrcLane -= VT.getVectorNumElements();
12618+
SrcLane -= NumElts;
1265912619
}
1266012620
SDValue SrcLaneV = DAG.getConstant(SrcLane, dl, MVT::i64);
1266112621

@@ -12675,7 +12635,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
1267512635

1267612636
// If the shuffle is not directly supported and it has 4 elements, use
1267712637
// the PerfectShuffle-generated table to synthesize it from other shuffles.
12678-
unsigned NumElts = VT.getVectorNumElements();
1267912638
if (NumElts == 4) {
1268012639
unsigned PFIndexes[4];
1268112640
for (unsigned i = 0; i != 4; ++i) {
@@ -14126,16 +14085,20 @@ bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
1412614085
int DummyInt;
1412714086
unsigned DummyUnsigned;
1412814087

14129-
return (ShuffleVectorSDNode::isSplatMask(&M[0], VT) || isREVMask(M, VT, 64) ||
14130-
isREVMask(M, VT, 32) || isREVMask(M, VT, 16) ||
14088+
unsigned EltSize = VT.getScalarSizeInBits();
14089+
unsigned NumElts = VT.getVectorNumElements();
14090+
return (ShuffleVectorSDNode::isSplatMask(&M[0], VT) ||
14091+
isREVMask(M, EltSize, NumElts, 64) ||
14092+
isREVMask(M, EltSize, NumElts, 32) ||
14093+
isREVMask(M, EltSize, NumElts, 16) ||
1413114094
isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
14132-
// isTBLMask(M, VT) || // FIXME: Port TBL support from ARM.
14133-
isTRNMask(M, VT, DummyUnsigned) || isUZPMask(M, VT, DummyUnsigned) ||
14134-
isZIPMask(M, VT, DummyUnsigned) ||
14095+
isTRNMask(M, NumElts, DummyUnsigned) ||
14096+
isUZPMask(M, NumElts, DummyUnsigned) ||
14097+
isZIPMask(M, NumElts, DummyUnsigned) ||
1413514098
isTRN_v_undef_Mask(M, VT, DummyUnsigned) ||
1413614099
isUZP_v_undef_Mask(M, VT, DummyUnsigned) ||
1413714100
isZIP_v_undef_Mask(M, VT, DummyUnsigned) ||
14138-
isINSMask(M, VT.getVectorNumElements(), DummyBool, DummyInt) ||
14101+
isINSMask(M, NumElts, DummyBool, DummyInt) ||
1413914102
isConcatMask(M, VT, VT.getSizeInBits() == 128));
1414014103
}
1414114104

@@ -27486,15 +27449,15 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
2748627449
return convertFromScalableVector(DAG, VT, Op);
2748727450
}
2748827451

27452+
unsigned EltSize = VT.getScalarSizeInBits();
2748927453
for (unsigned LaneSize : {64U, 32U, 16U}) {
27490-
if (isREVMask(ShuffleMask, VT, LaneSize)) {
27454+
if (isREVMask(ShuffleMask, EltSize, VT.getVectorNumElements(), LaneSize)) {
2749127455
EVT NewVT =
2749227456
getPackedSVEVectorVT(EVT::getIntegerVT(*DAG.getContext(), LaneSize));
2749327457
unsigned RevOp;
27494-
unsigned EltSz = VT.getScalarSizeInBits();
27495-
if (EltSz == 8)
27458+
if (EltSize == 8)
2749627459
RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU;
27497-
else if (EltSz == 16)
27460+
else if (EltSize == 16)
2749827461
RevOp = AArch64ISD::REVH_MERGE_PASSTHRU;
2749927462
else
2750027463
RevOp = AArch64ISD::REVW_MERGE_PASSTHRU;
@@ -27506,8 +27469,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
2750627469
}
2750727470
}
2750827471

27509-
if (Subtarget->hasSVE2p1() && VT.getScalarSizeInBits() == 64 &&
27510-
isREVMask(ShuffleMask, VT, 128)) {
27472+
if (Subtarget->hasSVE2p1() && EltSize == 64 &&
27473+
isREVMask(ShuffleMask, EltSize, VT.getVectorNumElements(), 128)) {
2751127474
if (!VT.isFloatingPoint())
2751227475
return LowerToPredicatedOp(Op, DAG, AArch64ISD::REVD_MERGE_PASSTHRU);
2751327476

@@ -27519,11 +27482,12 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
2751927482
}
2752027483

2752127484
unsigned WhichResult;
27522-
if (isZIPMask(ShuffleMask, VT, WhichResult) && WhichResult == 0)
27485+
if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
27486+
WhichResult == 0)
2752327487
return convertFromScalableVector(
2752427488
DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op2));
2752527489

27526-
if (isTRNMask(ShuffleMask, VT, WhichResult)) {
27490+
if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
2752727491
unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
2752827492
return convertFromScalableVector(
2752927493
DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
@@ -27566,11 +27530,12 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
2756627530
return convertFromScalableVector(DAG, VT, Op);
2756727531
}
2756827532

27569-
if (isZIPMask(ShuffleMask, VT, WhichResult) && WhichResult != 0)
27533+
if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
27534+
WhichResult != 0)
2757027535
return convertFromScalableVector(
2757127536
DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op2));
2757227537

27573-
if (isUZPMask(ShuffleMask, VT, WhichResult)) {
27538+
if (isUZPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
2757427539
unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
2757527540
return convertFromScalableVector(
2757627541
DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));

llvm/lib/Target/AArch64/AArch64PerfectShuffle.h

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6588,7 +6588,7 @@ static const unsigned PerfectShuffleTable[6561 + 1] = {
65886588
835584U, // <u,u,u,u>: Cost 0 copy LHS
65896589
0};
65906590

6591-
static unsigned getPerfectShuffleCost(llvm::ArrayRef<int> M) {
6591+
inline unsigned getPerfectShuffleCost(llvm::ArrayRef<int> M) {
65926592
assert(M.size() == 4 && "Expected a 4 entry perfect shuffle");
65936593

65946594
// Special case zero-cost nop copies, from either LHS or RHS.
@@ -6623,8 +6623,8 @@ static unsigned getPerfectShuffleCost(llvm::ArrayRef<int> M) {
66236623
/// Return true for zip1 or zip2 masks of the form:
66246624
/// <0, 8, 1, 9, 2, 10, 3, 11> or
66256625
/// <4, 12, 5, 13, 6, 14, 7, 15>
6626-
inline bool isZIPMask(ArrayRef<int> M, EVT VT, unsigned &WhichResultOut) {
6627-
unsigned NumElts = VT.getVectorNumElements();
6626+
inline bool isZIPMask(ArrayRef<int> M, unsigned NumElts,
6627+
unsigned &WhichResultOut) {
66286628
if (NumElts % 2 != 0)
66296629
return false;
66306630
// Check the first non-undef element for which half to use.
@@ -6656,8 +6656,8 @@ inline bool isZIPMask(ArrayRef<int> M, EVT VT, unsigned &WhichResultOut) {
66566656
/// Return true for uzp1 or uzp2 masks of the form:
66576657
/// <0, 2, 4, 6, 8, 10, 12, 14> or
66586658
/// <1, 3, 5, 7, 9, 11, 13, 15>
6659-
inline bool isUZPMask(ArrayRef<int> M, EVT VT, unsigned &WhichResultOut) {
6660-
unsigned NumElts = VT.getVectorNumElements();
6659+
inline bool isUZPMask(ArrayRef<int> M, unsigned NumElts,
6660+
unsigned &WhichResultOut) {
66616661
// Check the first non-undef element for which half to use.
66626662
unsigned WhichResult = 2;
66636663
for (unsigned i = 0; i != NumElts; i++) {
@@ -6680,6 +6680,49 @@ inline bool isUZPMask(ArrayRef<int> M, EVT VT, unsigned &WhichResultOut) {
66806680
return true;
66816681
}
66826682

6683+
/// Return true for trn1 or trn2 masks of the form:
6684+
/// <0, 8, 2, 10, 4, 12, 6, 14> or
6685+
/// <1, 9, 3, 11, 5, 13, 7, 15>
6686+
inline bool isTRNMask(ArrayRef<int> M, unsigned NumElts,
6687+
unsigned &WhichResult) {
6688+
if (NumElts % 2 != 0)
6689+
return false;
6690+
WhichResult = (M[0] == 0 ? 0 : 1);
6691+
for (unsigned i = 0; i < NumElts; i += 2) {
6692+
if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) ||
6693+
(M[i + 1] >= 0 && (unsigned)M[i + 1] != i + NumElts + WhichResult))
6694+
return false;
6695+
}
6696+
return true;
6697+
}
6698+
6699+
/// isREVMask - Check if a vector shuffle corresponds to a REV
6700+
/// instruction with the specified blocksize. (The order of the elements
6701+
/// within each block of the vector is reversed.)
6702+
inline bool isREVMask(ArrayRef<int> M, unsigned EltSize, unsigned NumElts,
6703+
unsigned BlockSize) {
6704+
assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64 ||
6705+
BlockSize == 128) &&
6706+
"Only possible block sizes for REV are: 16, 32, 64, 128");
6707+
6708+
unsigned BlockElts = M[0] + 1;
6709+
// If the first shuffle index is UNDEF, be optimistic.
6710+
if (M[0] < 0)
6711+
BlockElts = BlockSize / EltSize;
6712+
6713+
if (BlockSize <= EltSize || BlockSize != BlockElts * EltSize)
6714+
return false;
6715+
6716+
for (unsigned i = 0; i < NumElts; ++i) {
6717+
if (M[i] < 0)
6718+
continue; // ignore UNDEF indices
6719+
if ((unsigned)M[i] != (i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
6720+
return false;
6721+
}
6722+
6723+
return true;
6724+
}
6725+
66836726
} // namespace llvm
66846727

66856728
#endif

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3968,8 +3968,8 @@ InstructionCost AArch64TTIImpl::getShuffleCost(
39683968
if (LT.second.isFixedLengthVector() &&
39693969
LT.second.getVectorNumElements() == Mask.size() &&
39703970
(Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) &&
3971-
(isZIPMask(Mask, LT.second, Unused) ||
3972-
isUZPMask(Mask, LT.second, Unused) ||
3971+
(isZIPMask(Mask, LT.second.getVectorNumElements(), Unused) ||
3972+
isUZPMask(Mask, LT.second.getVectorNumElements(), Unused) ||
39733973
// Check for non-zero lane splats
39743974
all_of(drop_begin(Mask),
39753975
[&Mask](int M) { return M < 0 || M == Mask[0]; })))

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

Lines changed: 2 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
//===----------------------------------------------------------------------===//
2121

2222
#include "AArch64GlobalISelUtils.h"
23+
#include "AArch64PerfectShuffle.h"
2324
#include "AArch64Subtarget.h"
2425
#include "AArch64TargetMachine.h"
2526
#include "GISel/AArch64LegalizerInfo.h"
@@ -77,50 +78,6 @@ struct ShuffleVectorPseudo {
7778
ShuffleVectorPseudo() = default;
7879
};
7980

80-
/// Check if a vector shuffle corresponds to a REV instruction with the
81-
/// specified blocksize.
82-
bool isREVMask(ArrayRef<int> M, unsigned EltSize, unsigned NumElts,
83-
unsigned BlockSize) {
84-
assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
85-
"Only possible block sizes for REV are: 16, 32, 64");
86-
assert(EltSize != 64 && "EltSize cannot be 64 for REV mask.");
87-
88-
unsigned BlockElts = M[0] + 1;
89-
90-
// If the first shuffle index is UNDEF, be optimistic.
91-
if (M[0] < 0)
92-
BlockElts = BlockSize / EltSize;
93-
94-
if (BlockSize <= EltSize || BlockSize != BlockElts * EltSize)
95-
return false;
96-
97-
for (unsigned i = 0; i < NumElts; ++i) {
98-
// Ignore undef indices.
99-
if (M[i] < 0)
100-
continue;
101-
if (static_cast<unsigned>(M[i]) !=
102-
(i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
103-
return false;
104-
}
105-
106-
return true;
107-
}
108-
109-
/// Determines if \p M is a shuffle vector mask for a TRN of \p NumElts.
110-
/// Whether or not G_TRN1 or G_TRN2 should be used is stored in \p WhichResult.
111-
bool isTRNMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
112-
if (NumElts % 2 != 0)
113-
return false;
114-
WhichResult = (M[0] == 0 ? 0 : 1);
115-
for (unsigned i = 0; i < NumElts; i += 2) {
116-
if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != i + WhichResult) ||
117-
(M[i + 1] >= 0 &&
118-
static_cast<unsigned>(M[i + 1]) != i + NumElts + WhichResult))
119-
return false;
120-
}
121-
return true;
122-
}
123-
12481
/// Check if a G_EXT instruction can handle a shuffle mask \p M when the vector
12582
/// sources of the shuffle are different.
12683
std::optional<std::pair<bool, uint64_t>> getExtMask(ArrayRef<int> M,
@@ -163,38 +120,6 @@ std::optional<std::pair<bool, uint64_t>> getExtMask(ArrayRef<int> M,
163120
return std::make_pair(ReverseExt, Imm);
164121
}
165122

166-
/// Determines if \p M is a shuffle vector mask for a UZP of \p NumElts.
167-
/// Whether or not G_UZP1 or G_UZP2 should be used is stored in \p WhichResult.
168-
bool isUZPMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
169-
WhichResult = (M[0] == 0 ? 0 : 1);
170-
for (unsigned i = 0; i != NumElts; ++i) {
171-
// Skip undef indices.
172-
if (M[i] < 0)
173-
continue;
174-
if (static_cast<unsigned>(M[i]) != 2 * i + WhichResult)
175-
return false;
176-
}
177-
return true;
178-
}
179-
180-
/// \return true if \p M is a zip mask for a shuffle vector of \p NumElts.
181-
/// Whether or not G_ZIP1 or G_ZIP2 should be used is stored in \p WhichResult.
182-
bool isZipMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
183-
if (NumElts % 2 != 0)
184-
return false;
185-
186-
// 0 means use ZIP1, 1 means use ZIP2.
187-
WhichResult = (M[0] == 0 ? 0 : 1);
188-
unsigned Idx = WhichResult * NumElts / 2;
189-
for (unsigned i = 0; i != NumElts; i += 2) {
190-
if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != Idx) ||
191-
(M[i + 1] >= 0 && static_cast<unsigned>(M[i + 1]) != Idx + NumElts))
192-
return false;
193-
Idx += 1;
194-
}
195-
return true;
196-
}
197-
198123
/// Helper function for matchINS.
199124
///
200125
/// \returns a value when \p M is an ins mask for \p NumInputElements.
@@ -308,7 +233,7 @@ bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
308233
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
309234
Register Dst = MI.getOperand(0).getReg();
310235
unsigned NumElts = MRI.getType(Dst).getNumElements();
311-
if (!isZipMask(ShuffleMask, NumElts, WhichResult))
236+
if (!isZIPMask(ShuffleMask, NumElts, WhichResult))
312237
return false;
313238
unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
314239
Register V1 = MI.getOperand(1).getReg();

0 commit comments

Comments
 (0)