Skip to content

Commit 4c989b9

Browse files
committed
[VectorCombine] Handle shuffle of selects
(shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) -> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m)) The behaviour of SelectInst on vectors is the same as for `V'select[i] = Condition[i] ? V'True[i] : V'False[i]`. If a ShuffleVector is performed on two selects, it will be like: `V'[mask] = (V'select[i] = Condition[i] ? V'True[i] : V'False[i])` That's why a ShuffleVector with two SelectInst is equivalent to first ShuffleVector Condition/True/False and then SelectInst that result. This patch implements the transforming described above. Proof: https://alive2.llvm.org/ce/z/97wfHp Fixed: #120775
1 parent 2a4dfdf commit 4c989b9

File tree

2 files changed

+361
-89
lines changed

2 files changed

+361
-89
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class VectorCombine {
119119
bool foldConcatOfBoolMasks(Instruction &I);
120120
bool foldPermuteOfBinops(Instruction &I);
121121
bool foldShuffleOfBinops(Instruction &I);
122+
bool foldShuffleOfSelects(Instruction &I);
122123
bool foldShuffleOfCastops(Instruction &I);
123124
bool foldShuffleOfShuffles(Instruction &I);
124125
bool foldShuffleOfIntrinsics(Instruction &I);
@@ -1899,6 +1900,56 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
18991900
return true;
19001901
}
19011902

1903+
/// Try to convert,
1904+
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
1905+
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
1906+
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
1907+
ArrayRef<int> Mask;
1908+
Value *C1, *T1, *F1, *C2, *T2, *F2;
1909+
if (!match(&I, m_Shuffle(
1910+
m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
1911+
m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
1912+
m_Mask(Mask))))
1913+
return false;
1914+
1915+
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
1916+
auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
1917+
auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
1918+
if (!C1VecTy || !C2VecTy)
1919+
return false;
1920+
1921+
auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
1922+
auto SelOp = Instruction::Select;
1923+
InstructionCost OldCost = TTI.getCmpSelInstrCost(
1924+
SelOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
1925+
OldCost += TTI.getCmpSelInstrCost(SelOp, T2->getType(), C2VecTy,
1926+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1927+
OldCost += TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr,
1928+
{I.getOperand(0), I.getOperand(1)}, &I);
1929+
1930+
auto *C1C2VecTy = cast<FixedVectorType>(
1931+
toVectorTy(Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements()));
1932+
InstructionCost NewCost =
1933+
TTI.getShuffleCost(SK, C1C2VecTy, Mask, CostKind, 0, nullptr, {C1, C2});
1934+
NewCost +=
1935+
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {T1, T2});
1936+
NewCost +=
1937+
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {F1, F2});
1938+
NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, DstVecTy,
1939+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1940+
1941+
if (NewCost > OldCost)
1942+
return false;
1943+
1944+
Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
1945+
Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
1946+
Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
1947+
Value *NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
1948+
1949+
replaceValue(I, *NewSel);
1950+
return true;
1951+
}
1952+
19021953
/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
19031954
/// into "castop (shuffle)".
19041955
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
@@ -3352,6 +3403,7 @@ bool VectorCombine::run() {
33523403
case Instruction::ShuffleVector:
33533404
MadeChange |= foldPermuteOfBinops(I);
33543405
MadeChange |= foldShuffleOfBinops(I);
3406+
MadeChange |= foldShuffleOfSelects(I);
33553407
MadeChange |= foldShuffleOfCastops(I);
33563408
MadeChange |= foldShuffleOfShuffles(I);
33573409
MadeChange |= foldShuffleOfIntrinsics(I);

0 commit comments

Comments
 (0)