Skip to content

Commit bf195ad

Browse files
committed
[VectorCombine] Handle shuffle of selects
Fold (shuffle (select c1,t1,f1), (select c2,t2,f2), m) -> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m)). A SelectInst on vectors operates lane-wise: `V'select[i] = Condition[i] ? V'True[i] : V'False[i]`. A ShuffleVector of two such selects therefore picks, for each mask element, one lane of that lane-wise result: `V'[i] = V'select[mask[i]]`. That is why a ShuffleVector of two SelectInsts is equivalent to first applying the ShuffleVector to the Condition/True/False operands and then performing a single SelectInst on the shuffled results. This patch implements the transform described above. Proof: https://alive2.llvm.org/ce/z/97wfHp Fixes: #120775
1 parent 28917e8 commit bf195ad

File tree

3 files changed

+94
-21
lines changed

3 files changed

+94
-21
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class VectorCombine {
119119
bool foldConcatOfBoolMasks(Instruction &I);
120120
bool foldPermuteOfBinops(Instruction &I);
121121
bool foldShuffleOfBinops(Instruction &I);
122+
bool foldShuffleOfSelects(Instruction &I);
122123
bool foldShuffleOfCastops(Instruction &I);
123124
bool foldShuffleOfShuffles(Instruction &I);
124125
bool foldShuffleOfIntrinsics(Instruction &I);
@@ -1899,6 +1900,56 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
18991900
return true;
19001901
}
19011902

1903+
/// Try to convert,
1904+
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
1905+
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
1906+
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
1907+
ArrayRef<int> Mask;
1908+
Value *C1, *T1, *F1, *C2, *T2, *F2;
1909+
if (!match(&I, m_Shuffle(
1910+
m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
1911+
m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
1912+
m_Mask(Mask))))
1913+
return false;
1914+
1915+
auto SelectOp = Instruction::Select;
1916+
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
1917+
auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
1918+
auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
1919+
if (!C1VecTy || !C2VecTy)
1920+
return false;
1921+
1922+
InstructionCost OldCost = TTI.getCmpSelInstrCost(
1923+
SelectOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
1924+
OldCost += TTI.getCmpSelInstrCost(SelectOp, T2->getType(), C2VecTy,
1925+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1926+
OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, DstVecTy,
1927+
Mask, CostKind, 0, nullptr, {I.getOperand(0), I.getOperand(1)}, &I);
1928+
1929+
auto *C1C2VecTy = cast<FixedVectorType>(
1930+
toVectorTy(Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements()));
1931+
InstructionCost NewCost =
1932+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, C1C2VecTy, Mask,
1933+
CostKind, 0, nullptr, {C1, C2});
1934+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, DstVecTy,
1935+
Mask, CostKind, 0, nullptr, {T1, T2});
1936+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, DstVecTy,
1937+
Mask, CostKind, 0, nullptr, {F1, F2});
1938+
NewCost += TTI.getCmpSelInstrCost(SelectOp, DstVecTy, DstVecTy,
1939+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1940+
1941+
if (NewCost > OldCost)
1942+
return false;
1943+
1944+
Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
1945+
Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
1946+
Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
1947+
Value *NewShuf = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
1948+
1949+
replaceValue(I, *NewShuf);
1950+
return true;
1951+
}
1952+
19021953
/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
19031954
/// into "castop (shuffle)".
19041955
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
@@ -3352,6 +3403,7 @@ bool VectorCombine::run() {
33523403
case Instruction::ShuffleVector:
33533404
MadeChange |= foldPermuteOfBinops(I);
33543405
MadeChange |= foldShuffleOfBinops(I);
3406+
MadeChange |= foldShuffleOfSelects(I);
33553407
MadeChange |= foldShuffleOfCastops(I);
33563408
MadeChange |= foldShuffleOfShuffles(I);
33573409
MadeChange |= foldShuffleOfIntrinsics(I);

llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -451,18 +451,18 @@ define <8 x i8> @icmpsel(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
451451

452452
define <8 x i8> @icmpsel_diffentcond(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
453453
; CHECK-LABEL: @icmpsel_diffentcond(
454-
; CHECK-NEXT: [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
455-
; CHECK-NEXT: [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
456-
; CHECK-NEXT: [[BB:%.*]] = shufflevector <8 x i8> [[B:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
457-
; CHECK-NEXT: [[BT:%.*]] = shufflevector <8 x i8> [[B]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
458454
; CHECK-NEXT: [[CB:%.*]] = shufflevector <8 x i8> [[C:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
459455
; CHECK-NEXT: [[CT:%.*]] = shufflevector <8 x i8> [[C]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
460456
; CHECK-NEXT: [[DB:%.*]] = shufflevector <8 x i8> [[D:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
461457
; CHECK-NEXT: [[DT:%.*]] = shufflevector <8 x i8> [[D]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
462-
; CHECK-NEXT: [[ABT1:%.*]] = icmp slt <4 x i8> [[AT]], [[BT]]
463-
; CHECK-NEXT: [[ABB1:%.*]] = icmp ult <4 x i8> [[AB]], [[BB]]
464-
; CHECK-NEXT: [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT]], <4 x i8> [[DT]]
465-
; CHECK-NEXT: [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB]], <4 x i8> [[DB]]
458+
; CHECK-NEXT: [[CB1:%.*]] = shufflevector <8 x i8> [[C1:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
459+
; CHECK-NEXT: [[CT1:%.*]] = shufflevector <8 x i8> [[C1]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
460+
; CHECK-NEXT: [[DB1:%.*]] = shufflevector <8 x i8> [[D1:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
461+
; CHECK-NEXT: [[DT1:%.*]] = shufflevector <8 x i8> [[D1]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
462+
; CHECK-NEXT: [[ABT1:%.*]] = icmp slt <4 x i8> [[CT]], [[DT]]
463+
; CHECK-NEXT: [[ABB1:%.*]] = icmp ult <4 x i8> [[CB]], [[DB]]
464+
; CHECK-NEXT: [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT1]], <4 x i8> [[DT1]]
465+
; CHECK-NEXT: [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB1]], <4 x i8> [[DB1]]
466466
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x i8> [[ABT]], <4 x i8> [[ABB]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
467467
; CHECK-NEXT: ret <8 x i8> [[R]]
468468
;
@@ -992,14 +992,15 @@ define void @maximal_legal_fpmath(ptr %addr1, ptr %addr2, ptr %result, float %va
992992
}
993993

994994
; Peek through (repeated) bitcasts to find a common source value.
995+
; TODO : We can remove the Shufflevector for A, B.
995996
define <4 x i64> @bitcast_smax_v8i32_v4i32(<4 x i64> %a, <4 x i64> %b) {
996997
; CHECK-LABEL: @bitcast_smax_v8i32_v4i32(
997998
; CHECK-NEXT: [[A_BC0:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
998999
; CHECK-NEXT: [[B_BC0:%.*]] = bitcast <4 x i64> [[B:%.*]] to <8 x i32>
999-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
1000-
; CHECK-NEXT: [[A_BC1:%.*]] = bitcast <4 x i64> [[A]] to <8 x i32>
1001-
; CHECK-NEXT: [[B_BC1:%.*]] = bitcast <4 x i64> [[B]] to <8 x i32>
1002-
; CHECK-NEXT: [[CONCAT:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[B_BC1]], <8 x i32> [[A_BC1]]
1000+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
1001+
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <4 x i64> [[A]] to <8 x i32>
1002+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i64> [[B]] to <8 x i32>
1003+
; CHECK-NEXT: [[CONCAT:%.*]] = select <8 x i1> [[TMP1]], <8 x i32> [[TMP3]], <8 x i32> [[TMP5]]
10031004
; CHECK-NEXT: [[RES:%.*]] = bitcast <8 x i32> [[CONCAT]] to <4 x i64>
10041005
; CHECK-NEXT: ret <4 x i64> [[RES]]
10051006
;

llvm/test/Transforms/VectorCombine/X86/shuffle-of-selects.ll

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,34 @@
44
; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64-- -mcpu=x86-64-v4 | FileCheck %s --check-prefixes=CHECK,AVX512
55

66
define <8 x i16> @src_v4tov8_i16(<4 x i16> %x, <4 x i16> %y, <4 x i16> %z) {
7-
; CHECK-LABEL: define <8 x i16> @src_v4tov8_i16(
8-
; CHECK-SAME: <4 x i16> [[X:%.*]], <4 x i16> [[Y:%.*]], <4 x i16> [[Z:%.*]]) #[[ATTR0:[0-9]+]] {
9-
; CHECK-NEXT: [[CMP_XY:%.*]] = icmp slt <4 x i16> [[X]], [[Y]]
10-
; CHECK-NEXT: [[CMP_YZ:%.*]] = icmp slt <4 x i16> [[Y]], [[Z]]
11-
; CHECK-NEXT: [[SELECT_XZ:%.*]] = select <4 x i1> [[CMP_XY]], <4 x i16> [[X]], <4 x i16> [[Z]]
12-
; CHECK-NEXT: [[SELECT_YX:%.*]] = select <4 x i1> [[CMP_YZ]], <4 x i16> [[Y]], <4 x i16> [[X]]
13-
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x i16> [[SELECT_XZ]], <4 x i16> [[SELECT_YX]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
14-
; CHECK-NEXT: ret <8 x i16> [[RES]]
7+
; SSE-LABEL: define <8 x i16> @src_v4tov8_i16(
8+
; SSE-SAME: <4 x i16> [[X:%.*]], <4 x i16> [[Y:%.*]], <4 x i16> [[Z:%.*]]) #[[ATTR0:[0-9]+]] {
9+
; SSE-NEXT: [[CMP_XY:%.*]] = icmp slt <4 x i16> [[X]], [[Y]]
10+
; SSE-NEXT: [[CMP_YZ:%.*]] = icmp slt <4 x i16> [[Y]], [[Z]]
11+
; SSE-NEXT: [[TMP1:%.*]] = shufflevector <4 x i1> [[CMP_XY]], <4 x i1> [[CMP_YZ]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
12+
; SSE-NEXT: [[TMP2:%.*]] = shufflevector <4 x i16> [[X]], <4 x i16> [[Y]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
13+
; SSE-NEXT: [[TMP3:%.*]] = shufflevector <4 x i16> [[Z]], <4 x i16> [[X]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
14+
; SSE-NEXT: [[RES:%.*]] = select <8 x i1> [[TMP1]], <8 x i16> [[TMP2]], <8 x i16> [[TMP3]]
15+
; SSE-NEXT: ret <8 x i16> [[RES]]
16+
;
17+
; AVX2-LABEL: define <8 x i16> @src_v4tov8_i16(
18+
; AVX2-SAME: <4 x i16> [[X:%.*]], <4 x i16> [[Y:%.*]], <4 x i16> [[Z:%.*]]) #[[ATTR0:[0-9]+]] {
19+
; AVX2-NEXT: [[CMP_XY:%.*]] = icmp slt <4 x i16> [[X]], [[Y]]
20+
; AVX2-NEXT: [[CMP_YZ:%.*]] = icmp slt <4 x i16> [[Y]], [[Z]]
21+
; AVX2-NEXT: [[TMP1:%.*]] = shufflevector <4 x i1> [[CMP_XY]], <4 x i1> [[CMP_YZ]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
22+
; AVX2-NEXT: [[TMP2:%.*]] = shufflevector <4 x i16> [[X]], <4 x i16> [[Y]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
23+
; AVX2-NEXT: [[TMP3:%.*]] = shufflevector <4 x i16> [[Z]], <4 x i16> [[X]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
24+
; AVX2-NEXT: [[RES:%.*]] = select <8 x i1> [[TMP1]], <8 x i16> [[TMP2]], <8 x i16> [[TMP3]]
25+
; AVX2-NEXT: ret <8 x i16> [[RES]]
26+
;
27+
; AVX512-LABEL: define <8 x i16> @src_v4tov8_i16(
28+
; AVX512-SAME: <4 x i16> [[X:%.*]], <4 x i16> [[Y:%.*]], <4 x i16> [[Z:%.*]]) #[[ATTR0:[0-9]+]] {
29+
; AVX512-NEXT: [[CMP_XY:%.*]] = icmp slt <4 x i16> [[X]], [[Y]]
30+
; AVX512-NEXT: [[CMP_YZ:%.*]] = icmp slt <4 x i16> [[Y]], [[Z]]
31+
; AVX512-NEXT: [[SELECT_XZ:%.*]] = select <4 x i1> [[CMP_XY]], <4 x i16> [[X]], <4 x i16> [[Z]]
32+
; AVX512-NEXT: [[SELECT_YX:%.*]] = select <4 x i1> [[CMP_YZ]], <4 x i16> [[Y]], <4 x i16> [[X]]
33+
; AVX512-NEXT: [[RES:%.*]] = shufflevector <4 x i16> [[SELECT_XZ]], <4 x i16> [[SELECT_YX]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
34+
; AVX512-NEXT: ret <8 x i16> [[RES]]
1535
;
1636
%cmp.xy = icmp slt <4 x i16> %x, %y
1737
%cmp.yz = icmp slt <4 x i16> %y, %z
@@ -173,7 +193,7 @@ define <16 x i32> @src_v8tov16_i32(<8 x i32> %x, <8 x i32> %y, <8 x i32> %z) {
173193

174194
define <32 x i32> @src_v16tov32_i32(<16 x i32> %x, <16 x i32> %y, <16 x i32> %z) {
175195
; CHECK-LABEL: define <32 x i32> @src_v16tov32_i32(
176-
; CHECK-SAME: <16 x i32> [[X:%.*]], <16 x i32> [[Y:%.*]], <16 x i32> [[Z:%.*]]) #[[ATTR0]] {
196+
; CHECK-SAME: <16 x i32> [[X:%.*]], <16 x i32> [[Y:%.*]], <16 x i32> [[Z:%.*]]) #[[ATTR0:[0-9]+]] {
177197
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <16 x i32> [[X]], <16 x i32> [[Y]], <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
178198
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <16 x i32> [[Y]], <16 x i32> [[Z]], <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
179199
; CHECK-NEXT: [[TMP3:%.*]] = icmp slt <32 x i32> [[TMP1]], [[TMP2]]

0 commit comments

Comments
 (0)