Skip to content

Commit c2a211f

Browse files
committed
[WIP][VectorCombine] Fold "shuffle (binop (shuffle, shuffle)), undef" --> "binop (shuffle), (shuffle)"
Add foldPermuteOfBinops - to fold a permute (single source shuffle) through a binary op that is being fed by other shuffles. WIP - still need to add additional test coverage. Fixes #94546
1 parent 4e1b9d3 commit c2a211f

File tree

3 files changed

+94
-8
lines changed

3 files changed

+94
-8
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class VectorCombine {
112112
bool foldExtractedCmps(Instruction &I);
113113
bool foldSingleElementStore(Instruction &I);
114114
bool scalarizeLoadExtract(Instruction &I);
115+
bool foldPermuteOfBinops(Instruction &I);
115116
bool foldShuffleOfBinops(Instruction &I);
116117
bool foldShuffleOfCastops(Instruction &I);
117118
bool foldShuffleOfShuffles(Instruction &I);
@@ -1400,6 +1401,92 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
14001401
return true;
14011402
}
14021403

1404+
/// Try to convert "shuffle (binop (shuffle, shuffle)), undef" into "binop (shuffle), (shuffle)".
1405+
bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1406+
BinaryOperator *BinOp;
1407+
ArrayRef<int> OuterMask;
1408+
if (!match(&I,
1409+
m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask))))
1410+
return false;
1411+
1412+
// Don't introduce poison into div/rem.
1413+
if (llvm::is_contained(OuterMask, PoisonMaskElem) && BinOp->isIntDivRem())
1414+
return false;
1415+
1416+
Value *Op00, *Op01;
1417+
ArrayRef<int> Mask0;
1418+
if (!match(BinOp->getOperand(0),
1419+
m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)))))
1420+
return false;
1421+
1422+
Value *Op10, *Op11;
1423+
ArrayRef<int> Mask1;
1424+
if (!match(BinOp->getOperand(1),
1425+
m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)))))
1426+
return false;
1427+
1428+
Instruction::BinaryOps Opcode = BinOp->getOpcode();
1429+
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1430+
auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
1431+
auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
1432+
auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
1433+
if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1434+
return false;
1435+
1436+
unsigned NumSrcElts = BinOpTy->getNumElements();
1437+
1438+
// Don't accept shuffles that reference the second (undef/poison) operand.
1439+
if (any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
1440+
return false;
1441+
1442+
// Merge outer / inner shuffles.
1443+
SmallVector<int> NewMask0, NewMask1;
1444+
for (int M : OuterMask) {
1445+
NewMask0.push_back(M >= 0 ? Mask0[M] : -1);
1446+
NewMask1.push_back(M >= 0 ? Mask1[M] : -1);
1447+
}
1448+
1449+
// Try to merge shuffles across the binop if the new shuffles are not costly.
1450+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1451+
1452+
InstructionCost OldCost =
1453+
TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) +
1454+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1455+
OuterMask, CostKind, 0, nullptr, {BinOp}, &I) +
1456+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1457+
CostKind, 0, nullptr, {Op00, Op01},
1458+
cast<Instruction>(BinOp->getOperand(0))) +
1459+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1460+
CostKind, 0, nullptr, {Op10, Op11},
1461+
cast<Instruction>(BinOp->getOperand(1)));
1462+
1463+
InstructionCost NewCost =
1464+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1465+
NewMask0, CostKind, 0, nullptr, {Op00, Op01}) +
1466+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1467+
NewMask1, CostKind, 0, nullptr, {Op10, Op11}) +
1468+
TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
1469+
1470+
LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
1471+
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1472+
<< "\n");
1473+
if (NewCost >= OldCost)
1474+
return false;
1475+
1476+
Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0);
1477+
Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1);
1478+
Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1479+
1480+
// Intersect flags from the old binops.
1481+
if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1482+
NewInst->copyIRFlags(BinOp);
1483+
1484+
Worklist.pushValue(Shuf0);
1485+
Worklist.pushValue(Shuf1);
1486+
replaceValue(I, *NewBO);
1487+
return true;
1488+
}
1489+
14031490
/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
14041491
bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
14051492
BinaryOperator *B0, *B1;
@@ -2736,6 +2823,7 @@ bool VectorCombine::run() {
27362823
MadeChange |= foldInsExtFNeg(I);
27372824
break;
27382825
case Instruction::ShuffleVector:
2826+
MadeChange |= foldPermuteOfBinops(I);
27392827
MadeChange |= foldShuffleOfBinops(I);
27402828
MadeChange |= foldShuffleOfCastops(I);
27412829
MadeChange |= foldShuffleOfShuffles(I);

llvm/test/Transforms/PhaseOrdering/X86/horiz-math-inseltpoison.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,10 @@ define <8 x float> @hadd_reverse_v8f32(<8 x float> %a, <8 x float> %b) #0 {
108108

109109
define <8 x float> @reverse_hadd_v8f32(<8 x float> %a, <8 x float> %b) #0 {
110110
; CHECK-LABEL: @reverse_hadd_v8f32(
111-
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 0, i32 2, i32 8, i32 10, i32 4, i32 6, i32 12, i32 14>
112-
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 1, i32 3, i32 9, i32 11, i32 5, i32 7, i32 13, i32 15>
111+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 14, i32 12, i32 6, i32 4, i32 10, i32 8, i32 2, i32 0>
112+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 15, i32 13, i32 7, i32 5, i32 11, i32 9, i32 3, i32 1>
113113
; CHECK-NEXT: [[TMP3:%.*]] = fadd <8 x float> [[TMP1]], [[TMP2]]
114-
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <8 x float> [[TMP3]], <8 x float> poison, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
115-
; CHECK-NEXT: ret <8 x float> [[SHUFFLE]]
114+
; CHECK-NEXT: ret <8 x float> [[TMP3]]
116115
;
117116
%vecext = extractelement <8 x float> %a, i32 0
118117
%vecext1 = extractelement <8 x float> %a, i32 1

llvm/test/Transforms/PhaseOrdering/X86/horiz-math.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,10 @@ define <8 x float> @hadd_reverse_v8f32(<8 x float> %a, <8 x float> %b) #0 {
108108

109109
define <8 x float> @reverse_hadd_v8f32(<8 x float> %a, <8 x float> %b) #0 {
110110
; CHECK-LABEL: @reverse_hadd_v8f32(
111-
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 0, i32 2, i32 8, i32 10, i32 4, i32 6, i32 12, i32 14>
112-
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 1, i32 3, i32 9, i32 11, i32 5, i32 7, i32 13, i32 15>
111+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x i32> <i32 14, i32 12, i32 6, i32 4, i32 10, i32 8, i32 2, i32 0>
112+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[A]], <8 x float> [[B]], <8 x i32> <i32 15, i32 13, i32 7, i32 5, i32 11, i32 9, i32 3, i32 1>
113113
; CHECK-NEXT: [[TMP3:%.*]] = fadd <8 x float> [[TMP1]], [[TMP2]]
114-
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <8 x float> [[TMP3]], <8 x float> poison, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
115-
; CHECK-NEXT: ret <8 x float> [[SHUFFLE]]
114+
; CHECK-NEXT: ret <8 x float> [[TMP3]]
116115
;
117116
%vecext = extractelement <8 x float> %a, i32 0
118117
%vecext1 = extractelement <8 x float> %a, i32 1

0 commit comments

Comments
 (0)