Skip to content

[RISCV] Use DenseMap to track V0 definition. NFC #84465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,13 @@ class RISCVFoldMasks : public MachineFunctionPass {
StringRef getPassName() const override { return "RISC-V Fold Masks"; }

private:
bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef) const;
bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef) const;
bool convertToUnmasked(MachineInstr &MI) const;
bool convertVMergeToVMv(MachineInstr &MI) const;

bool isAllOnesMask(MachineInstr *MaskDef) const;
bool isAllOnesMask(const MachineInstr *MaskDef) const;

/// Maps uses of V0 to the corresponding def of V0.
DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
};

} // namespace
Expand All @@ -59,10 +62,9 @@ char RISCVFoldMasks::ID = 0;

INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)

bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
if (!MaskDef)
return false;
assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0);
bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
assert(MaskDef && MaskDef->isCopy() &&
MaskDef->getOperand(0).getReg() == RISCV::V0);
Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
if (!SrcReg.isVirtual())
return false;
Expand All @@ -89,8 +91,7 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {

// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
MachineInstr *V0Def) const {
bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
#define CASE_VMERGE_TO_VMV(lmul) \
case RISCV::PseudoVMERGE_VVM_##lmul: \
NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
Expand All @@ -116,7 +117,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
return false;

assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
if (!isAllOnesMask(V0Def))
if (!isAllOnesMask(V0Defs.lookup(&MI)))
return false;

MI.setDesc(TII->get(NewOpc));
Expand All @@ -133,14 +134,13 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
return true;
}

bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI,
MachineInstr *MaskDef) const {
bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
const RISCV::RISCVMaskedPseudoInfo *I =
RISCV::getMaskedPseudoInfo(MI.getOpcode());
if (!I)
return false;

if (!isAllOnesMask(MaskDef))
if (!isAllOnesMask(V0Defs.lookup(&MI)))
return false;

// There are two classes of pseudos in the table - compares and
Expand Down Expand Up @@ -198,20 +198,26 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
// $v0:vr = COPY %mask:vr
// %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
//
// Because $v0 isn't in SSA, keep track of it so we can check the mask operand
// on each pseudo.
MachineInstr *CurrentV0Def;
for (MachineBasicBlock &MBB : MF) {
CurrentV0Def = nullptr;
for (MachineInstr &MI : MBB) {
Changed |= convertToUnmasked(MI, CurrentV0Def);
Changed |= convertVMergeToVMv(MI, CurrentV0Def);
// Because $v0 isn't in SSA, keep track of its definition at each use so we
// can check mask operands.
for (const MachineBasicBlock &MBB : MF) {
const MachineInstr *CurrentV0Def = nullptr;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ignores defs of V0 from outside the BB, since we're only interested in the local

$v0:vr = COPY %mask:vr
%x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr

format.

for (const MachineInstr &MI : MBB) {
if (MI.readsRegister(RISCV::V0, TRI))
V0Defs[&MI] = CurrentV0Def;

if (MI.definesRegister(RISCV::V0, TRI))
CurrentV0Def = &MI;
}
}

for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
Changed |= convertToUnmasked(MI);
Changed |= convertVMergeToVMv(MI);
}
}

return Changed;
}

Expand Down