[InstCombine] Refactor matchFunnelShift to allow more pattern (NFC) #68474

Merged: 2 commits, Oct 19, 2023
172 changes: 93 additions & 79 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2732,100 +2732,114 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
// rotate matching code under visitSelect and visitTrunc?
unsigned Width = Or.getType()->getScalarSizeInBits();

// First, find an or'd pair of opposite shifts:
// or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
BinaryOperator *Or0, *Or1;
if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
!match(Or.getOperand(1), m_BinOp(Or1)))
return nullptr;

Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
!match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
Or0->getOpcode() == Or1->getOpcode())
Instruction *Or0, *Or1;
if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
!match(Or.getOperand(1), m_Instruction(Or1)))
return nullptr;

// Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
if (Or0->getOpcode() == BinaryOperator::LShr) {
std::swap(Or0, Or1);
std::swap(ShVal0, ShVal1);
std::swap(ShAmt0, ShAmt1);
}
assert(Or0->getOpcode() == BinaryOperator::Shl &&
Or1->getOpcode() == BinaryOperator::LShr &&
"Illegal or(shift,shift) pair");

// Match the shift amount operands for a funnel shift pattern. This always
// matches a subtraction on the R operand.
auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
// Check for constant shift amounts that sum to the bitwidth.
const APInt *LI, *RI;
if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
return ConstantInt::get(L->getType(), *LI);

Constant *LC, *RC;
if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
return ConstantExpr::mergeUndefsWith(LC, RC);

// (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
// We limit this to X < Width in case the backend re-expands the intrinsic,
// and has to reintroduce a shift modulo operation (InstCombine might remove
// it after this fold). This still doesn't guarantee that the final codegen
// will match this original pattern.
if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
return KnownL.getMaxValue().ult(Width) ? L : nullptr;
}
bool IsFshl = true; // Sub on LSHR.
SmallVector<Value *, 3> FShiftArgs;

// For non-constant cases, the following patterns currently only work for
// rotation patterns.
// TODO: Add general funnel-shift compatible patterns.
if (ShVal0 != ShVal1)
// First, find an or'd pair of opposite shifts:
// or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
Contributor:

What is the rationale for this check? Isn't it covered by the matching of logical shifts?

But either way, if it is indeed useful, can you invert it and return early to keep the nesting down?

Contributor Author:

We need to match either
or (shl %x, 16) (shr %y, 16)
or (shl %x, 16) (zext %y)

if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) is used to filter in the first case, and that is what the original code matched.
68ab662 filters in the second case.
If we want to further reduce nested scopes, we may need to split this function into several static functions to match each case, but that requires passing some arguments and may not be concise.
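
For reference, a minimal IR sketch of the two shapes (the function names, %x/%y operands, and the i32/i16 widths are illustrative, not taken from the patch). The first is the opposite-shift pair that matchFunnelShift already folds to a funnel-shift intrinsic; the second is the shl/zext shape this NFC refactoring prepares the matcher to handle in follow-up work:

define i32 @shift_pair(i32 %x, i32 %y) {
  %hi = shl i32 %x, 16
  %lo = lshr i32 %y, 16
  ; folded to: %or = call i32 @llvm.fshl.i32(i32 %x, i32 %y, i32 16)
  %or = or i32 %hi, %lo
  ret i32 %or
}

define i32 @shl_zext_pair(i32 %x, i16 %y) {
  %hi = shl i32 %x, 16
  %lo = zext i16 %y to i32
  ; same high/low packing, but the low half comes from a zext rather than an lshr
  %or = or i32 %hi, %lo
  ret i32 %or
}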

Contributor:

This code seems new? Not that it's wrong or anything.

But either way, I don't understand why it can't be an early return. If we don't enter the if, FShiftArgs will be empty and we will return anyway.

Contributor Author:

> This code seems new? Not that it's wrong or anything.

I added this if statement in this patch.

It could be an early return in this patch, but I'd like to add another else if in 5b3b1bb. If I removed this if statement, I would need to add it back and re-indent the code inside it in #68502. That would make the diff less clear.

[screenshot omitted]

Contributor:

Okay fair enough.

Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
if (!match(Or0,
m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
!match(Or1,
m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
Or0->getOpcode() == Or1->getOpcode())
return nullptr;

// For non-constant cases we don't support non-pow2 shift masks.
// TODO: Is it worth matching urem as well?
if (!isPowerOf2_32(Width))
return nullptr;
// Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
if (Or0->getOpcode() == BinaryOperator::LShr) {
std::swap(Or0, Or1);
std::swap(ShVal0, ShVal1);
std::swap(ShAmt0, ShAmt1);
}
assert(Or0->getOpcode() == BinaryOperator::Shl &&
Or1->getOpcode() == BinaryOperator::LShr &&
"Illegal or(shift,shift) pair");

// Match the shift amount operands for a funnel shift pattern. This always
// matches a subtraction on the R operand.
auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
// Check for constant shift amounts that sum to the bitwidth.
const APInt *LI, *RI;
if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
return ConstantInt::get(L->getType(), *LI);

Constant *LC, *RC;
if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
match(L,
m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
match(R,
m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
return ConstantExpr::mergeUndefsWith(LC, RC);

// (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
// We limit this to X < Width in case the backend re-expands the
// intrinsic, and has to reintroduce a shift modulo operation (InstCombine
// might remove it after this fold). This still doesn't guarantee that the
// final codegen will match this original pattern.
if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
return KnownL.getMaxValue().ult(Width) ? L : nullptr;
}

// The shift amount may be masked with negation:
// (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
Value *X;
unsigned Mask = Width - 1;
if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
return X;
// For non-constant cases, the following patterns currently only work for
// rotation patterns.
// TODO: Add general funnel-shift compatible patterns.
if (ShVal0 != ShVal1)
return nullptr;

// Similar to above, but the shift amount may be extended after masking,
// so return the extended value as the parameter for the intrinsic.
if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
m_SpecificInt(Mask))))
return L;
// For non-constant cases we don't support non-pow2 shift masks.
// TODO: Is it worth matching urem as well?
if (!isPowerOf2_32(Width))
return nullptr;

if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
return L;
// The shift amount may be masked with negation:
// (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
Value *X;
unsigned Mask = Width - 1;
if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
return X;

// Similar to above, but the shift amount may be extended after masking,
// so return the extended value as the parameter for the intrinsic.
if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
match(R,
m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
m_SpecificInt(Mask))))
return L;

if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
return L;

return nullptr;
};
return nullptr;
};

Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
bool IsFshl = true; // Sub on LSHR.
if (!ShAmt) {
ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
IsFshl = false; // Sub on SHL.
Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
if (!ShAmt) {
ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
IsFshl = false; // Sub on SHL.
}
if (!ShAmt)
return nullptr;

FShiftArgs = {ShVal0, ShVal1, ShAmt};
}
if (!ShAmt)

if (FShiftArgs.empty())
return nullptr;

Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
return CallInst::Create(F, {ShVal0, ShVal1, ShAmt});
return CallInst::Create(F, FShiftArgs);
}

/// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
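
As a worked example of the non-constant rotation case accepted by the matchShiftAmount lambda above, here is a minimal sketch; the %v/%shamt names and the i32 width are illustrative, not from the patch:

define i32 @rotl_var(i32 %v, i32 %shamt) {
  ; Both shift amounts are masked to Width - 1 = 31, and the right amount is the negation of the left.
  %lamt = and i32 %shamt, 31
  %neg = sub i32 0, %shamt
  %ramt = and i32 %neg, 31
  %hi = shl i32 %v, %lamt
  %lo = lshr i32 %v, %ramt
  ; matchShiftAmount returns the unmasked %shamt, so the or (and the masking) becomes
  ;   %or = call i32 @llvm.fshl.i32(i32 %v, i32 %v, i32 %shamt)
  ; which is a rotate-left, since fshl interprets the shift amount modulo the bit width.
  %or = or i32 %hi, %lo
  ret i32 %or
}

Because both shifted values are %v, this is the rotation-only case noted in the code (ShVal0 == ShVal1), and the Width - 1 masking is why the non-constant patterns require Width to be a power of two.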