@@ -112,6 +112,7 @@ class VectorCombine {
112
112
bool foldExtractedCmps (Instruction &I);
113
113
bool foldSingleElementStore (Instruction &I);
114
114
bool scalarizeLoadExtract (Instruction &I);
115
+ bool foldPermuteOfBinops (Instruction &I);
115
116
bool foldShuffleOfBinops (Instruction &I);
116
117
bool foldShuffleOfCastops (Instruction &I);
117
118
bool foldShuffleOfShuffles (Instruction &I);
@@ -1400,6 +1401,93 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1400
1401
return true ;
1401
1402
}
1402
1403
1404
+ // / Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1405
+ // / --> "binop (shuffle), (shuffle)".
1406
+ bool VectorCombine::foldPermuteOfBinops (Instruction &I) {
1407
+ BinaryOperator *BinOp;
1408
+ ArrayRef<int > OuterMask;
1409
+ if (!match (&I,
1410
+ m_Shuffle (m_OneUse (m_BinOp (BinOp)), m_Undef (), m_Mask (OuterMask))))
1411
+ return false ;
1412
+
1413
+ // Don't introduce poison into div/rem.
1414
+ if (llvm::is_contained (OuterMask, PoisonMaskElem) && BinOp->isIntDivRem ())
1415
+ return false ;
1416
+
1417
+ Value *Op00, *Op01;
1418
+ ArrayRef<int > Mask0;
1419
+ if (!match (BinOp->getOperand (0 ),
1420
+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1421
+ return false ;
1422
+
1423
+ Value *Op10, *Op11;
1424
+ ArrayRef<int > Mask1;
1425
+ if (!match (BinOp->getOperand (1 ),
1426
+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1427
+ return false ;
1428
+
1429
+ Instruction::BinaryOps Opcode = BinOp->getOpcode ();
1430
+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1431
+ auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType ());
1432
+ auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType ());
1433
+ auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType ());
1434
+ if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1435
+ return false ;
1436
+
1437
+ unsigned NumSrcElts = BinOpTy->getNumElements ();
1438
+
1439
+ // Don't accept shuffles that reference the second (undef/poison) operand.
1440
+ if (any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
1441
+ return false ;
1442
+
1443
+ // Merge outer / inner shuffles.
1444
+ SmallVector<int > NewMask0, NewMask1;
1445
+ for (int M : OuterMask) {
1446
+ NewMask0.push_back (M >= 0 ? Mask0[M] : -1 );
1447
+ NewMask1.push_back (M >= 0 ? Mask1[M] : -1 );
1448
+ }
1449
+
1450
+ // Try to merge shuffles across the binop if the new shuffles are not costly.
1451
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1452
+
1453
+ InstructionCost OldCost =
1454
+ TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
1455
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1456
+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1457
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1458
+ CostKind, 0 , nullptr , {Op00, Op01},
1459
+ cast<Instruction>(BinOp->getOperand (0 ))) +
1460
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1461
+ CostKind, 0 , nullptr , {Op10, Op11},
1462
+ cast<Instruction>(BinOp->getOperand (1 )));
1463
+
1464
+ InstructionCost NewCost =
1465
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1466
+ CostKind, 0 , nullptr , {Op00, Op01}) +
1467
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1468
+ CostKind, 0 , nullptr , {Op10, Op11}) +
1469
+ TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
1470
+
1471
+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
1472
+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1473
+ << " \n " );
1474
+ if (NewCost >= OldCost)
1475
+ return false ;
1476
+
1477
+ Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1478
+ Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1479
+ Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1480
+
1481
+ // Intersect flags from the old binops.
1482
+ if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1483
+ NewInst->copyIRFlags (BinOp);
1484
+
1485
+ Worklist.pushValue (Shuf0);
1486
+ Worklist.pushValue (Shuf1);
1487
+ replaceValue (I, *NewBO);
1488
+ return true ;
1489
+ }
1490
+
1403
1491
// / Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
1404
1492
bool VectorCombine::foldShuffleOfBinops (Instruction &I) {
1405
1493
BinaryOperator *B0, *B1;
@@ -2736,6 +2824,7 @@ bool VectorCombine::run() {
2736
2824
MadeChange |= foldInsExtFNeg (I);
2737
2825
break ;
2738
2826
case Instruction::ShuffleVector:
2827
+ MadeChange |= foldPermuteOfBinops (I);
2739
2828
MadeChange |= foldShuffleOfBinops (I);
2740
2829
MadeChange |= foldShuffleOfCastops (I);
2741
2830
MadeChange |= foldShuffleOfShuffles (I);
0 commit comments