Skip to content

Commit 7abea04

Browse files
committed
[VectorCombine] Handle shuffle of selects
(shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) -> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m)) The behaviour of SelectInst on vectors is the same as for `V'select[i] = Condition[i] ? V'True[i] : V'False[i]`. If a ShuffleVector is performed on two selects, it will be like: `V'[mask] = (V'select[i] = Condition[i] ? V'True[i] : V'False[i])` That's why a ShuffleVector with two SelectInst is equivalent to first ShuffleVector Condition/True/False and then SelectInst that result. This patch implements the transforming described above. Proof: https://alive2.llvm.org/ce/z/97wfHp Fixed: #120775
1 parent 4441c2d commit 7abea04

File tree

4 files changed

+82
-32
lines changed

4 files changed

+82
-32
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class VectorCombine {
119119
bool foldConcatOfBoolMasks(Instruction &I);
120120
bool foldPermuteOfBinops(Instruction &I);
121121
bool foldShuffleOfBinops(Instruction &I);
122+
bool foldShuffleOfSelects(Instruction &I);
122123
bool foldShuffleOfCastops(Instruction &I);
123124
bool foldShuffleOfShuffles(Instruction &I);
124125
bool foldShuffleOfIntrinsics(Instruction &I);
@@ -1899,6 +1900,55 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
18991900
return true;
19001901
}
19011902

1903+
/// Try to convert,
1904+
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
1905+
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
1906+
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
1907+
ArrayRef<int> Mask;
1908+
Value *C1, *T1, *F1, *C2, *T2, *F2;
1909+
if (!match(&I, m_Shuffle(
1910+
m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
1911+
m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
1912+
m_Mask(Mask))))
1913+
return false;
1914+
1915+
auto SelectOp = Instruction::Select;
1916+
1917+
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
1918+
auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
1919+
auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
1920+
auto *T1VecTy = dyn_cast<FixedVectorType>(T1->getType());
1921+
auto *F1VecTy = dyn_cast<FixedVectorType>(F1->getType());
1922+
1923+
if (!C1VecTy || !C2VecTy)
1924+
return false;
1925+
1926+
InstructionCost OldCost = TTI.getCmpSelInstrCost(
1927+
SelectOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
1928+
OldCost += TTI.getCmpSelInstrCost(SelectOp, T2->getType(), C2VecTy,
1929+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1930+
OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, DstVecTy,
1931+
Mask, CostKind);
1932+
1933+
InstructionCost NewCost =
1934+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, C1VecTy, Mask,
1935+
CostKind, 0, nullptr, {C1, C2}, &I);
1936+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, T1VecTy,
1937+
Mask, CostKind, 0, nullptr, {T1, T2}, &I);
1938+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, F1VecTy,
1939+
Mask, CostKind, 0, nullptr, {F1, F2}, &I);
1940+
NewCost += TTI.getCmpSelInstrCost(SelectOp, T1->getType(), DstVecTy,
1941+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1942+
1943+
Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
1944+
Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
1945+
Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
1946+
Value *NewRes = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
1947+
1948+
replaceValue(I, *NewRes);
1949+
return true;
1950+
}
1951+
19021952
/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
19031953
/// into "castop (shuffle)".
19041954
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
@@ -3352,6 +3402,7 @@ bool VectorCombine::run() {
33523402
case Instruction::ShuffleVector:
33533403
MadeChange |= foldPermuteOfBinops(I);
33543404
MadeChange |= foldShuffleOfBinops(I);
3405+
MadeChange |= foldShuffleOfSelects(I);
33553406
MadeChange |= foldShuffleOfCastops(I);
33563407
MadeChange |= foldShuffleOfShuffles(I);
33573408
MadeChange |= foldShuffleOfIntrinsics(I);

llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ define <4 x i64> @x86_pblendvb_v8i32_v4i32(<4 x i64> %a, <4 x i64> %b, <4 x i64>
8787
; CHECK-LABEL: @x86_pblendvb_v8i32_v4i32(
8888
; CHECK-NEXT: [[C_BC:%.*]] = bitcast <4 x i64> [[C:%.*]] to <8 x i32>
8989
; CHECK-NEXT: [[D_BC:%.*]] = bitcast <4 x i64> [[D:%.*]] to <8 x i32>
90-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <8 x i32> [[C_BC]], [[D_BC]]
91-
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
90+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <8 x i32> [[C_BC]], [[D_BC]]
9291
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <4 x i64> [[B:%.*]] to <8 x i32>
93-
; CHECK-NEXT: [[TMP3:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[TMP2]], <8 x i32> [[TMP1]]
94-
; CHECK-NEXT: [[RES:%.*]] = bitcast <8 x i32> [[TMP3]] to <4 x i64>
92+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
93+
; CHECK-NEXT: [[TMP4:%.*]] = select <8 x i1> [[TMP1]], <8 x i32> [[TMP2]], <8 x i32> [[TMP3]]
94+
; CHECK-NEXT: [[RES:%.*]] = bitcast <8 x i32> [[TMP4]] to <4 x i64>
9595
; CHECK-NEXT: ret <4 x i64> [[RES]]
9696
;
9797
%a.bc = bitcast <4 x i64> %a to <32 x i8>
@@ -118,11 +118,11 @@ define <4 x i64> @x86_pblendvb_v16i16_v8i16(<4 x i64> %a, <4 x i64> %b, <4 x i64
118118
; CHECK-LABEL: @x86_pblendvb_v16i16_v8i16(
119119
; CHECK-NEXT: [[C_BC:%.*]] = bitcast <4 x i64> [[C:%.*]] to <16 x i16>
120120
; CHECK-NEXT: [[D_BC:%.*]] = bitcast <4 x i64> [[D:%.*]] to <16 x i16>
121-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <16 x i16> [[C_BC]], [[D_BC]]
122-
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i64> [[A:%.*]] to <16 x i16>
121+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <16 x i16> [[C_BC]], [[D_BC]]
123122
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <4 x i64> [[B:%.*]] to <16 x i16>
124-
; CHECK-NEXT: [[TMP3:%.*]] = select <16 x i1> [[CMP]], <16 x i16> [[TMP2]], <16 x i16> [[TMP1]]
125-
; CHECK-NEXT: [[RES:%.*]] = bitcast <16 x i16> [[TMP3]] to <4 x i64>
123+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i64> [[A:%.*]] to <16 x i16>
124+
; CHECK-NEXT: [[TMP4:%.*]] = select <16 x i1> [[TMP1]], <16 x i16> [[TMP2]], <16 x i16> [[TMP3]]
125+
; CHECK-NEXT: [[RES:%.*]] = bitcast <16 x i16> [[TMP4]] to <4 x i64>
126126
; CHECK-NEXT: ret <4 x i64> [[RES]]
127127
;
128128
%a.bc = bitcast <4 x i64> %a to <32 x i8>
@@ -255,11 +255,11 @@ define <8 x i64> @x86_pblendvb_v16i32_v8i32(<8 x i64> %a, <8 x i64> %b, <8 x i64
255255
; CHECK-LABEL: @x86_pblendvb_v16i32_v8i32(
256256
; CHECK-NEXT: [[C_BC:%.*]] = bitcast <8 x i64> [[C:%.*]] to <16 x i32>
257257
; CHECK-NEXT: [[D_BC:%.*]] = bitcast <8 x i64> [[D:%.*]] to <16 x i32>
258-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <16 x i32> [[C_BC]], [[D_BC]]
259-
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i64> [[A:%.*]] to <16 x i32>
258+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <16 x i32> [[C_BC]], [[D_BC]]
260259
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i64> [[B:%.*]] to <16 x i32>
261-
; CHECK-NEXT: [[TMP3:%.*]] = select <16 x i1> [[CMP]], <16 x i32> [[TMP2]], <16 x i32> [[TMP1]]
262-
; CHECK-NEXT: [[RES:%.*]] = bitcast <16 x i32> [[TMP3]] to <8 x i64>
260+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i64> [[A:%.*]] to <16 x i32>
261+
; CHECK-NEXT: [[TMP4:%.*]] = select <16 x i1> [[TMP1]], <16 x i32> [[TMP2]], <16 x i32> [[TMP3]]
262+
; CHECK-NEXT: [[RES:%.*]] = bitcast <16 x i32> [[TMP4]] to <8 x i64>
263263
; CHECK-NEXT: ret <8 x i64> [[RES]]
264264
;
265265
%a.bc = bitcast <8 x i64> %a to <64 x i8>
@@ -286,11 +286,11 @@ define <8 x i64> @x86_pblendvb_v32i16_v16i16(<8 x i64> %a, <8 x i64> %b, <8 x i6
286286
; CHECK-LABEL: @x86_pblendvb_v32i16_v16i16(
287287
; CHECK-NEXT: [[C_BC:%.*]] = bitcast <8 x i64> [[C:%.*]] to <32 x i16>
288288
; CHECK-NEXT: [[D_BC:%.*]] = bitcast <8 x i64> [[D:%.*]] to <32 x i16>
289-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <32 x i16> [[C_BC]], [[D_BC]]
290-
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i64> [[A:%.*]] to <32 x i16>
289+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <32 x i16> [[C_BC]], [[D_BC]]
291290
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i64> [[B:%.*]] to <32 x i16>
292-
; CHECK-NEXT: [[TMP3:%.*]] = select <32 x i1> [[CMP]], <32 x i16> [[TMP2]], <32 x i16> [[TMP1]]
293-
; CHECK-NEXT: [[RES:%.*]] = bitcast <32 x i16> [[TMP3]] to <8 x i64>
291+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i64> [[A:%.*]] to <32 x i16>
292+
; CHECK-NEXT: [[TMP4:%.*]] = select <32 x i1> [[TMP1]], <32 x i16> [[TMP2]], <32 x i16> [[TMP3]]
293+
; CHECK-NEXT: [[RES:%.*]] = bitcast <32 x i16> [[TMP4]] to <8 x i64>
294294
; CHECK-NEXT: ret <8 x i64> [[RES]]
295295
;
296296
%a.bc = bitcast <8 x i64> %a to <64 x i8>

llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -451,19 +451,14 @@ define <8 x i8> @icmpsel(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
451451

452452
define <8 x i8> @icmpsel_diffentcond(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
453453
; CHECK-LABEL: @icmpsel_diffentcond(
454-
; CHECK-NEXT: [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
455-
; CHECK-NEXT: [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
456-
; CHECK-NEXT: [[BB:%.*]] = shufflevector <8 x i8> [[B:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
457-
; CHECK-NEXT: [[BT:%.*]] = shufflevector <8 x i8> [[B]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
458454
; CHECK-NEXT: [[CB:%.*]] = shufflevector <8 x i8> [[C:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
459455
; CHECK-NEXT: [[CT:%.*]] = shufflevector <8 x i8> [[C]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
460456
; CHECK-NEXT: [[DB:%.*]] = shufflevector <8 x i8> [[D:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
461457
; CHECK-NEXT: [[DT:%.*]] = shufflevector <8 x i8> [[D]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
462-
; CHECK-NEXT: [[ABT1:%.*]] = icmp slt <4 x i8> [[AT]], [[BT]]
463-
; CHECK-NEXT: [[ABB1:%.*]] = icmp ult <4 x i8> [[AB]], [[BB]]
464-
; CHECK-NEXT: [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT]], <4 x i8> [[DT]]
465-
; CHECK-NEXT: [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB]], <4 x i8> [[DB]]
466-
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x i8> [[ABT]], <4 x i8> [[ABB]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
458+
; CHECK-NEXT: [[ABT1:%.*]] = icmp slt <4 x i8> [[CT]], [[DT]]
459+
; CHECK-NEXT: [[ABB1:%.*]] = icmp ult <4 x i8> [[CB]], [[DB]]
460+
; CHECK-NEXT: [[SHUFFLE_CMP:%.*]] = shufflevector <4 x i1> [[ABT1]], <4 x i1> [[ABB1]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
461+
; CHECK-NEXT: [[R:%.*]] = select <8 x i1> [[SHUFFLE_CMP]], <8 x i8> [[C1:%.*]], <8 x i8> [[D1:%.*]]
467462
; CHECK-NEXT: ret <8 x i8> [[R]]
468463
;
469464
%ab = shufflevector <8 x i8> %a, <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -992,14 +987,17 @@ define void @maximal_legal_fpmath(ptr %addr1, ptr %addr2, ptr %result, float %va
992987
}
993988

994989
; Peek through (repeated) bitcasts to find a common source value.
990+
; TODO : We can remove the Shufflevector for A, B.
995991
define <4 x i64> @bitcast_smax_v8i32_v4i32(<4 x i64> %a, <4 x i64> %b) {
996992
; CHECK-LABEL: @bitcast_smax_v8i32_v4i32(
997993
; CHECK-NEXT: [[A_BC0:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
998994
; CHECK-NEXT: [[B_BC0:%.*]] = bitcast <4 x i64> [[B:%.*]] to <8 x i32>
999-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
1000-
; CHECK-NEXT: [[A_BC1:%.*]] = bitcast <4 x i64> [[A]] to <8 x i32>
1001-
; CHECK-NEXT: [[B_BC1:%.*]] = bitcast <4 x i64> [[B]] to <8 x i32>
1002-
; CHECK-NEXT: [[CONCAT:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[B_BC1]], <8 x i32> [[A_BC1]]
995+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
996+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i64> [[B]], <4 x i64> [[B]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
997+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i64> [[TMP2]] to <8 x i32>
998+
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x i64> [[A]], <4 x i64> [[A]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
999+
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <4 x i64> [[TMP4]] to <8 x i32>
1000+
; CHECK-NEXT: [[CONCAT:%.*]] = select <8 x i1> [[TMP1]], <8 x i32> [[TMP3]], <8 x i32> [[TMP5]]
10031001
; CHECK-NEXT: [[RES:%.*]] = bitcast <8 x i32> [[CONCAT]] to <4 x i64>
10041002
; CHECK-NEXT: ret <4 x i64> [[RES]]
10051003
;

llvm/test/Transforms/VectorCombine/X86/select-shuffle.ll

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ define <8 x i16> @shuffle_select_select(<4 x i16> %x, <4 x i16> %y, <4 x i16> %z
4141
; CHECK-LABEL: @shuffle_select_select(
4242
; CHECK-NEXT: [[CMP_XY:%.*]] = icmp slt <4 x i16> [[X:%.*]], [[Y:%.*]]
4343
; CHECK-NEXT: [[CMP_YZ:%.*]] = icmp slt <4 x i16> [[Y]], [[Z:%.*]]
44-
; CHECK-NEXT: [[SELECT_XZ:%.*]] = select <4 x i1> [[CMP_XY]], <4 x i16> [[X]], <4 x i16> [[Z]]
45-
; CHECK-NEXT: [[SELECT_YX:%.*]] = select <4 x i1> [[CMP_YZ]], <4 x i16> [[Y]], <4 x i16> [[X]]
46-
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x i16> [[SELECT_XZ]], <4 x i16> [[SELECT_YX]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
44+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i1> [[CMP_XY]], <4 x i1> [[CMP_YZ]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
45+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i16> [[X]], <4 x i16> [[Y]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
46+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x i16> [[Z]], <4 x i16> [[X]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
47+
; CHECK-NEXT: [[RES:%.*]] = select <8 x i1> [[TMP1]], <8 x i16> [[TMP2]], <8 x i16> [[TMP3]]
4748
; CHECK-NEXT: ret <8 x i16> [[RES]]
4849
;
4950
%cmp.xy = icmp slt <4 x i16> %x, %y

0 commit comments

Comments
 (0)