Skip to content

Commit 42daf16

Browse files
committed
[VectorCombine] Fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" -> "(bitcast (concat X, Y))" MOVMSK bool mask style patterns
Mask/bool vectors are often bitcast to/from scalar integers, in particular when concatenating mask results, often because of the difficulty of working with vectors of bool in C/C++. On x86 this typically involves the MOVMSK/KMOV instructions. To concatenate bool masks, they are typically cast to scalars, which are then zero-extended, shifted, and OR'd together. This patch attempts to match these scalar concatenation patterns and convert them to vector shuffles instead. This in turn often assists with further vector combines, depending on the cost model. Fixes #111431
1 parent 673c324 commit 42daf16

File tree

2 files changed

+267
-114
lines changed

2 files changed

+267
-114
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class VectorCombine {
115115
bool foldExtractedCmps(Instruction &I);
116116
bool foldSingleElementStore(Instruction &I);
117117
bool scalarizeLoadExtract(Instruction &I);
118+
bool foldConcatOfBoolMasks(Instruction &I);
118119
bool foldPermuteOfBinops(Instruction &I);
119120
bool foldShuffleOfBinops(Instruction &I);
120121
bool foldShuffleOfCastops(Instruction &I);
@@ -1423,6 +1424,112 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
14231424
return true;
14241425
}
14251426

1427+
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
/// to "(bitcast (concat X, Y))"
/// where X/Y are bitcasted from i1 mask vectors.
///
/// Both operands of the disjoint-or must be (optionally shifted) zexts of
/// scalar bitcasts of vXi1 mask vectors of the same type. The shift-amount
/// difference must equal the mask element count, so the two masks occupy
/// adjacent bit ranges and can be rebuilt as a single 2*vXi1 concat shuffle
/// that is bitcast back to a scalar integer (plus any residual zext/shl).
///
/// \returns true (and replaces \p I) if the fold was profitable per TTI.
bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
  // Only scalar-integer typed ors are candidates.
  Type *Ty = I.getType();
  if (!Ty->isIntegerTy())
    return false;

  // TODO: Add big endian test coverage
  if (DL->isBigEndian())
    return false;

  // Restrict to disjoint cases so the mask vectors aren't overlapping.
  Instruction *X, *Y;
  if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
    return false;

  // Allow both sources to contain shl, to handle more generic pattern:
  // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
  // ShAmtX stays 0 when the un-shifted "zext (bitcast X)" form matches.
  // One-use constraints keep the fold from duplicating work when the
  // intermediate zext/bitcast values have other consumers.
  Value *SrcX;
  uint64_t ShAmtX = 0;
  if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
      !match(X, m_OneUse(
                    m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
                          m_ConstantInt(ShAmtX)))))
    return false;

  // Same pattern for the other operand; ShAmtY stays 0 if there is no shl.
  Value *SrcY;
  uint64_t ShAmtY = 0;
  if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
      !match(Y, m_OneUse(
                    m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
                          m_ConstantInt(ShAmtY)))))
    return false;

  // Canonicalize larger shift to the RHS.
  // After this, X/SrcX/ShAmtX hold the lower (less-shifted) mask.
  if (ShAmtX > ShAmtY) {
    std::swap(X, Y);
    std::swap(SrcX, SrcY);
    std::swap(ShAmtX, ShAmtY);
  }

  // Ensure both sources are matching vXi1 bool mask types, and that the shift
  // difference is the mask width so they can be easily concatenated together.
  // The BitWidth/2 bound guarantees both concatenated masks fit inside Ty.
  uint64_t ShAmtDiff = ShAmtY - ShAmtX;
  unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
  unsigned BitWidth = Ty->getPrimitiveSizeInBits();
  auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
  if (!MaskTy || SrcX->getType() != SrcY->getType() ||
      !MaskTy->getElementType()->isIntegerTy(1) ||
      MaskTy->getNumElements() != ShAmtDiff ||
      MaskTy->getNumElements() > (BitWidth / 2))
    return false;

  // ConcatTy = <2*N x i1>; ConcatIntTy = i(2*N) (the bitcast destination);
  // MaskIntTy = iN (the scalar type of each original mask bitcast).
  auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
  auto *ConcatIntTy =
      Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
  auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);

  // Identity mask [0, 1, ..., 2N-1]: selects all of SrcX then all of SrcY,
  // i.e. a straight concatenation shuffle.
  SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
  std::iota(ConcatMask.begin(), ConcatMask.end(), 0);

  // TODO: Is it worth supporting multi use cases?
  // Old pattern: 1 or, NumSHL shl, 2 zext, 2 bitcast (all matched one-use).
  InstructionCost OldCost = 0;
  OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
  OldCost +=
      NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
  OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
                                      TTI::CastContextHint::None, CostKind);
  OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
                                      TTI::CastContextHint::None, CostKind);

  // New pattern: concat shuffle + bitcast, plus an optional widening zext
  // (when Ty is wider than i(2*N)) and an optional residual shl (when the
  // smaller original shift ShAmtX was non-zero).
  InstructionCost NewCost = 0;
  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
                                ConcatMask, CostKind);
  NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
                                  TTI::CastContextHint::None, CostKind);
  if (Ty != ConcatIntTy)
    NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
                                    TTI::CastContextHint::None, CostKind);
  if (ShAmtX > 0)
    NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);

  // Bail out unless the rewrite is at least cost-neutral.
  if (NewCost > OldCost)
    return false;

  // Build bool mask concatenation, bitcast back to scalar integer, and perform
  // any residual zero-extension or shifting.
  // Intermediate values are pushed onto the worklist for further combining.
  Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
  Worklist.pushValue(Concat);

  Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);

  if (Ty != ConcatIntTy) {
    Worklist.pushValue(Result);
    Result = Builder.CreateZExt(Result, Ty);
  }

  if (ShAmtX > 0) {
    Worklist.pushValue(Result);
    Result = Builder.CreateShl(Result, ShAmtX);
  }

  replaceValue(I, *Result);
  return true;
}
1532+
14261533
/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
14271534
/// --> "binop (shuffle), (shuffle)".
14281535
bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
@@ -2908,6 +3015,9 @@ bool VectorCombine::run() {
29083015
if (TryEarlyFoldsOnly)
29093016
return;
29103017

3018+
if (I.getType()->isIntegerTy())
3019+
MadeChange |= foldConcatOfBoolMasks(I);
3020+
29113021
// Otherwise, try folds that improve codegen but may interfere with
29123022
// early IR canonicalizations.
29133023
// The type checking is for run-time efficiency. We can avoid wasting time

0 commit comments

Comments
 (0)