@@ -112,6 +112,7 @@ class VectorCombine {
112
112
bool foldExtractedCmps (Instruction &I);
113
113
bool foldSingleElementStore (Instruction &I);
114
114
bool scalarizeLoadExtract (Instruction &I);
115
+ bool foldPermuteOfBinops (Instruction &I);
115
116
bool foldShuffleOfBinops (Instruction &I);
116
117
bool foldShuffleOfCastops (Instruction &I);
117
118
bool foldShuffleOfShuffles (Instruction &I);
@@ -1400,6 +1401,92 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1400
1401
return true ;
1401
1402
}
1402
1403
1404
+ // / Try to convert "shuffle (binop (shuffle, shuffle)), undef" into "binop (shuffle), (shuffle)".
1405
+ bool VectorCombine::foldPermuteOfBinops (Instruction &I) {
1406
+ BinaryOperator *BinOp;
1407
+ ArrayRef<int > OuterMask;
1408
+ if (!match (&I,
1409
+ m_Shuffle (m_OneUse (m_BinOp (BinOp)), m_Undef (), m_Mask (OuterMask))))
1410
+ return false ;
1411
+
1412
+ // Don't introduce poison into div/rem.
1413
+ if (llvm::is_contained (OuterMask, PoisonMaskElem) && BinOp->isIntDivRem ())
1414
+ return false ;
1415
+
1416
+ Value *Op00, *Op01;
1417
+ ArrayRef<int > Mask0;
1418
+ if (!match (BinOp->getOperand (0 ),
1419
+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1420
+ return false ;
1421
+
1422
+ Value *Op10, *Op11;
1423
+ ArrayRef<int > Mask1;
1424
+ if (!match (BinOp->getOperand (1 ),
1425
+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1426
+ return false ;
1427
+
1428
+ Instruction::BinaryOps Opcode = BinOp->getOpcode ();
1429
+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1430
+ auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType ());
1431
+ auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType ());
1432
+ auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType ());
1433
+ if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1434
+ return false ;
1435
+
1436
+ unsigned NumSrcElts = BinOpTy->getNumElements ();
1437
+
1438
+ // Don't accept shuffles that reference the second (undef/poison) operand.
1439
+ if (any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
1440
+ return false ;
1441
+
1442
+ // Merge outer / inner shuffles.
1443
+ SmallVector<int > NewMask0, NewMask1;
1444
+ for (int M : OuterMask) {
1445
+ NewMask0.push_back (M >= 0 ? Mask0[M] : -1 );
1446
+ NewMask1.push_back (M >= 0 ? Mask1[M] : -1 );
1447
+ }
1448
+
1449
+ // Try to merge shuffles across the binop if the new shuffles are not costly.
1450
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1451
+
1452
+ InstructionCost OldCost =
1453
+ TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
1454
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1455
+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1456
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1457
+ CostKind, 0 , nullptr , {Op00, Op01},
1458
+ cast<Instruction>(BinOp->getOperand (0 ))) +
1459
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1460
+ CostKind, 0 , nullptr , {Op10, Op11},
1461
+ cast<Instruction>(BinOp->getOperand (1 )));
1462
+
1463
+ InstructionCost NewCost =
1464
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1465
+ NewMask0, CostKind, 0 , nullptr , {Op00, Op01}) +
1466
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1467
+ NewMask1, CostKind, 0 , nullptr , {Op10, Op11}) +
1468
+ TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
1469
+
1470
+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
1471
+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1472
+ << " \n " );
1473
+ if (NewCost >= OldCost)
1474
+ return false ;
1475
+
1476
+ Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1477
+ Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1478
+ Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1479
+
1480
+ // Intersect flags from the old binops.
1481
+ if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1482
+ NewInst->copyIRFlags (BinOp);
1483
+
1484
+ Worklist.pushValue (Shuf0);
1485
+ Worklist.pushValue (Shuf1);
1486
+ replaceValue (I, *NewBO);
1487
+ return true ;
1488
+ }
1489
+
1403
1490
/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
1404
1491
bool VectorCombine::foldShuffleOfBinops (Instruction &I) {
1405
1492
BinaryOperator *B0, *B1;
@@ -2736,6 +2823,7 @@ bool VectorCombine::run() {
2736
2823
MadeChange |= foldInsExtFNeg (I);
2737
2824
break ;
2738
2825
case Instruction::ShuffleVector:
2826
+ MadeChange |= foldPermuteOfBinops (I);
2739
2827
MadeChange |= foldShuffleOfBinops (I);
2740
2828
MadeChange |= foldShuffleOfCastops (I);
2741
2829
MadeChange |= foldShuffleOfShuffles (I);
0 commit comments