-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[RISCV][VLOPT] Add support for checkUsers when UserMI is a Single-Width Integer Reduction #120345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 22 commits
3b0709e
3acf8c4
27de222
dac7acf
3fc5297
fcf4d81
6a27058
f8ce58c
f4dc6b3
68cd004
ab12f92
4656b45
4e463d7
b8c3411
e6f2468
5232e6f
8d50a38
61e681c
595a709
fe55f64
4666a7f
f8c54d4
85f49a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,7 +50,10 @@ class RISCVVLOptimizer : public MachineFunctionPass { | |
StringRef getPassName() const override { return PASS_NAME; } | ||
|
||
private: | ||
bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI); | ||
std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp); | ||
/// Returns the largest common VL MachineOperand that may be used to optimize | ||
/// MI. Returns std::nullopt if it failed to find a suitable VL. | ||
std::optional<MachineOperand> checkUsers(MachineInstr &MI); | ||
bool tryReduceVL(MachineInstr &MI); | ||
bool isCandidate(const MachineInstr &MI) const; | ||
}; | ||
|
@@ -95,6 +98,8 @@ struct OperandInfo { | |
OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW) | ||
: S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {} | ||
|
||
OperandInfo(unsigned Log2EEW) : S(State::Known), Log2EEW(Log2EEW) {} | ||
|
||
OperandInfo() : S(State::Unknown) {} | ||
|
||
bool isUnknown() const { return S == State::Unknown; } | ||
|
@@ -107,6 +112,11 @@ struct OperandInfo { | |
A.EMUL->second == B.EMUL->second; | ||
} | ||
|
||
static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { | ||
assert(A.isKnown() && B.isKnown() && "Both operands must be known"); | ||
return A.Log2EEW == B.Log2EEW; | ||
} | ||
|
||
void print(raw_ostream &OS) const { | ||
if (isUnknown()) { | ||
OS << "Unknown"; | ||
|
@@ -716,6 +726,23 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, | |
return OperandInfo(MIVLMul, MILog2SEW); | ||
} | ||
|
||
// Vector Reduction Operations | ||
// Vector Single-Width Integer Reduction Instructions | ||
// The Dest and VS1 only read element 0 of the vector register. Return just | ||
// the EEW for these. VS2 has EEW=SEW and EMUL=LMUL. | ||
case RISCV::VREDAND_VS: | ||
case RISCV::VREDMAX_VS: | ||
case RISCV::VREDMAXU_VS: | ||
case RISCV::VREDMIN_VS: | ||
case RISCV::VREDMINU_VS: | ||
case RISCV::VREDOR_VS: | ||
case RISCV::VREDSUM_VS: | ||
case RISCV::VREDXOR_VS: { | ||
if (MO.getOperandNo() == 2) | ||
return OperandInfo(MIVLMul, MILog2SEW); | ||
return OperandInfo(MILog2SEW); | ||
} | ||
|
||
default: | ||
return {}; | ||
} | ||
|
@@ -1028,79 +1055,106 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { | |
return true; | ||
} | ||
|
||
bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL, | ||
MachineInstr &MI) { | ||
std::optional<MachineOperand> | ||
RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) { | ||
const MachineInstr &UserMI = *UserOp.getParent(); | ||
const MCInstrDesc &Desc = UserMI.getDesc(); | ||
|
||
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { | ||
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" | ||
" use VLMAX\n"); | ||
return std::nullopt; | ||
} | ||
|
||
// Instructions like reductions may use a vector register as a scalar | ||
// register. In this case, we should treat it like a scalar register which | ||
// does not impact the decision on whether to optimize VL. But if there is | ||
// another user of MI and it may have VL=0, we need to be sure not to reduce | ||
// the VL of MI to zero when the VLOp of UserOp may be non-zero. The most | ||
// we can reduce it to is one. | ||
if (isVectorOpUsedAsScalarOp(UserOp)) { | ||
[[maybe_unused]] Register R = UserOp.getReg(); | ||
[[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); | ||
assert(RISCV::VRRegClass.hasSubClassEq(RC) && | ||
"Expect LMUL 1 register class for vector as scalar operands!"); | ||
LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n"); | ||
|
||
return MachineOperand::CreateImm(1); | ||
} | ||
|
||
unsigned VLOpNum = RISCVII::getVLOpNum(Desc); | ||
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); | ||
// Looking for an immediate or a register VL that isn't X0. | ||
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && | ||
"Did not expect X0 VL"); | ||
return VLOp; | ||
} | ||
|
||
std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) { | ||
// FIXME: Avoid visiting each user for each time we visit something on the | ||
// worklist, combined with an extra visit from the outer loop. Restructure | ||
// along lines of an instcombine style worklist which integrates the outer | ||
// pass. | ||
bool CanReduceVL = true; | ||
std::optional<MachineOperand> CommonVL; | ||
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { | ||
const MachineInstr &UserMI = *UserOp.getParent(); | ||
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); | ||
|
||
// Instructions like reductions may use a vector register as a scalar | ||
// register. In this case, we should treat it like a scalar register which | ||
// does not impact the decision on whether to optimize VL. | ||
// TODO: Treat it like a scalar register instead of bailing out. | ||
if (isVectorOpUsedAsScalarOp(UserOp)) { | ||
CanReduceVL = false; | ||
break; | ||
} | ||
|
||
if (mayReadPastVL(UserMI)) { | ||
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); | ||
CanReduceVL = false; | ||
break; | ||
return std::nullopt; | ||
} | ||
|
||
// Tied operands might pass through. | ||
if (UserOp.isTied()) { | ||
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); | ||
CanReduceVL = false; | ||
break; | ||
} | ||
|
||
const MCInstrDesc &Desc = UserMI.getDesc(); | ||
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { | ||
LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that" | ||
" use VLMAX\n"); | ||
CanReduceVL = false; | ||
break; | ||
return std::nullopt; | ||
} | ||
|
||
unsigned VLOpNum = RISCVII::getVLOpNum(Desc); | ||
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); | ||
|
||
// Looking for an immediate or a register VL that isn't X0. | ||
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && | ||
"Did not expect X0 VL"); | ||
auto VLOp = getMinimumVLForUser(UserOp); | ||
if (!VLOp) | ||
return std::nullopt; | ||
|
||
// Use the largest VL among all the users. If we cannot determine this | ||
// statically, then we cannot optimize the VL. | ||
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, VLOp)) { | ||
CommonVL = &VLOp; | ||
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { | ||
CommonVL = *VLOp; | ||
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); | ||
} else if (!RISCV::isVLKnownLE(VLOp, *CommonVL)) { | ||
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { | ||
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); | ||
CanReduceVL = false; | ||
break; | ||
return std::nullopt; | ||
} | ||
|
||
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { | ||
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); | ||
return std::nullopt; | ||
} | ||
|
||
// The SEW and LMUL of destination and source registers need to match. | ||
OperandInfo ConsumerInfo = getOperandInfo(UserOp, MRI); | ||
OperandInfo ProducerInfo = getOperandInfo(MI.getOperand(0), MRI); | ||
if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() || | ||
!OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) { | ||
LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown " | ||
"information for EMUL or EEW.\n"); | ||
if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown()) { | ||
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); | ||
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); | ||
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); | ||
CanReduceVL = false; | ||
break; | ||
return std::nullopt; | ||
} | ||
|
||
// If the operand is used as a scalar operand, then the EEW must be | ||
// compatible. Otherwise, the EMUL *and* EEW must be compatible. | ||
bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp); | ||
if ((IsVectorOpUsedAsScalarOp && | ||
!OperandInfo::EEWAreEqual(ConsumerInfo, ProducerInfo)) || | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As a possible follow up - in general, I think a larger source LMUL is fine. You could unify the code here by returning the smallest legal LMUL for the given SEW for the scalar source operand, and then using a greater than comparison for EMUL here. Not sure if that improves readability or not. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another way of phrasing this could be "is vlmax known greater than or equal to" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am happy to address this in a follow up patch. |
||
(!IsVectorOpUsedAsScalarOp && | ||
!OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo))) { | ||
LLVM_DEBUG( | ||
dbgs() | ||
<< " Abort due to incompatible information for EMUL or EEW.\n"); | ||
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); | ||
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); | ||
return std::nullopt; | ||
} | ||
} | ||
return CanReduceVL; | ||
|
||
return CommonVL; | ||
} | ||
|
||
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { | ||
|
@@ -1112,12 +1166,11 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { | |
MachineInstr &MI = *Worklist.pop_back_val(); | ||
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); | ||
|
||
const MachineOperand *CommonVL = nullptr; | ||
bool CanReduceVL = true; | ||
if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) | ||
CanReduceVL = checkUsers(CommonVL, MI); | ||
if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI)) | ||
continue; | ||
|
||
if (!CanReduceVL || !CommonVL) | ||
auto CommonVL = checkUsers(MI); | ||
if (!CommonVL) | ||
continue; | ||
|
||
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && | ||
|
Uh oh!
There was an error while loading. Please reload this page.