@@ -112,6 +112,7 @@ class VectorCombine {
112
112
bool foldExtractedCmps (Instruction &I);
113
113
bool foldSingleElementStore (Instruction &I);
114
114
bool scalarizeLoadExtract (Instruction &I);
115
+ bool foldPermuteOfBinops (Instruction &I);
115
116
bool foldShuffleOfBinops (Instruction &I);
116
117
bool foldShuffleOfCastops (Instruction &I);
117
118
bool foldShuffleOfShuffles (Instruction &I);
@@ -1400,6 +1401,100 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1400
1401
return true ;
1401
1402
}
1402
1403
1404
+ // / Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1405
+ // / --> "binop (shuffle), (shuffle)".
1406
+ bool VectorCombine::foldPermuteOfBinops (Instruction &I) {
1407
+ BinaryOperator *BinOp;
1408
+ ArrayRef<int > OuterMask;
1409
+ if (!match (&I,
1410
+ m_Shuffle (m_OneUse (m_BinOp (BinOp)), m_Undef (), m_Mask (OuterMask))))
1411
+ return false ;
1412
+
1413
+ // Don't introduce poison into div/rem.
1414
+ if (BinOp->isIntDivRem () && llvm::is_contained (OuterMask, PoisonMaskElem))
1415
+ return false ;
1416
+
1417
+ Value *Op00, *Op01;
1418
+ ArrayRef<int > Mask0;
1419
+ if (!match (BinOp->getOperand (0 ),
1420
+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1421
+ return false ;
1422
+
1423
+ Value *Op10, *Op11;
1424
+ ArrayRef<int > Mask1;
1425
+ if (!match (BinOp->getOperand (1 ),
1426
+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1427
+ return false ;
1428
+
1429
+ Instruction::BinaryOps Opcode = BinOp->getOpcode ();
1430
+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1431
+ auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType ());
1432
+ auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType ());
1433
+ auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType ());
1434
+ if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1435
+ return false ;
1436
+
1437
+ unsigned NumSrcElts = BinOpTy->getNumElements ();
1438
+
1439
+ // Don't accept shuffles that reference the second (undef/poison) operand in
1440
+ // div/rem..
1441
+ if (BinOp->isIntDivRem () &&
1442
+ any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
1443
+ return false ;
1444
+
1445
+ // Merge outer / inner shuffles.
1446
+ SmallVector<int > NewMask0, NewMask1;
1447
+ for (int M : OuterMask) {
1448
+ if (M < 0 || M >= (int )NumSrcElts) {
1449
+ NewMask0.push_back (PoisonMaskElem);
1450
+ NewMask1.push_back (PoisonMaskElem);
1451
+ } else {
1452
+ NewMask0.push_back (Mask0[M]);
1453
+ NewMask1.push_back (Mask1[M]);
1454
+ }
1455
+ }
1456
+
1457
+ // Try to merge shuffles across the binop if the new shuffles are not costly.
1458
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1459
+
1460
+ InstructionCost OldCost =
1461
+ TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
1462
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1463
+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1464
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1465
+ CostKind, 0 , nullptr , {Op00, Op01},
1466
+ cast<Instruction>(BinOp->getOperand (0 ))) +
1467
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1468
+ CostKind, 0 , nullptr , {Op10, Op11},
1469
+ cast<Instruction>(BinOp->getOperand (1 )));
1470
+
1471
+ InstructionCost NewCost =
1472
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1473
+ CostKind, 0 , nullptr , {Op00, Op01}) +
1474
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1475
+ CostKind, 0 , nullptr , {Op10, Op11}) +
1476
+ TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
1477
+
1478
+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
1479
+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1480
+ << " \n " );
1481
+ if (NewCost >= OldCost)
1482
+ return false ;
1483
+
1484
+ Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1485
+ Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1486
+ Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1487
+
1488
+ // Intersect flags from the old binops.
1489
+ if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1490
+ NewInst->copyIRFlags (BinOp);
1491
+
1492
+ Worklist.pushValue (Shuf0);
1493
+ Worklist.pushValue (Shuf1);
1494
+ replaceValue (I, *NewBO);
1495
+ return true ;
1496
+ }
1497
+
1403
1498
// / Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
1404
1499
bool VectorCombine::foldShuffleOfBinops (Instruction &I) {
1405
1500
BinaryOperator *B0, *B1;
@@ -2736,6 +2831,7 @@ bool VectorCombine::run() {
2736
2831
MadeChange |= foldInsExtFNeg (I);
2737
2832
break ;
2738
2833
case Instruction::ShuffleVector:
2834
+ MadeChange |= foldPermuteOfBinops (I);
2739
2835
MadeChange |= foldShuffleOfBinops (I);
2740
2836
MadeChange |= foldShuffleOfCastops (I);
2741
2837
MadeChange |= foldShuffleOfShuffles (I);
0 commit comments