Skip to content

[RISCV] Move performCombineVMergeAndVOps into RISCVFoldMasks #71764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 276 additions & 7 deletions llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,26 @@
// This pass performs various peephole optimisations that fold masks into vector
// pseudo instructions after instruction selection.
//
// Currently it converts
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// It performs the following transforms:
//
// %true = PseudoFOO %passthru, ..., %vl, %sew
// %x = PseudoVMERGE_VVM %passthru, %passthru, %true, %mask, %vl, %sew
// ->
// %x = PseudoFOO_MASK %false, ..., %mask, %vl, %sew
//
// %x = PseudoFOO_MASK ..., %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
// %x = PseudoFOO ..., %vl, %sew
//
// %x = PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// %x = PseudoVMV_V_V %false, %true, %vl, %sew
//
//===---------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVISelDAGToDAG.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;

Expand Down Expand Up @@ -50,6 +56,7 @@ class RISCVFoldMasks : public MachineFunctionPass {

private:
bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef);
bool foldVMergeIntoOps(MachineInstr &MI, MachineInstr *MaskDef);
bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef);

bool isAllOnesMask(MachineInstr *MaskDef);
Expand Down Expand Up @@ -89,6 +96,261 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) {
}
}

// Returns the PseudoVMSET_M_B* opcode that produces an all-ones mask for the
// given LMUL (the mask-register element count B1..B64 tracks the LMUL).
static unsigned getVMSetForLMul(RISCVII::VLMUL VLMul) {
  switch (VLMul) {
  case RISCVII::LMUL_RESERVED:
    llvm_unreachable("Unexpected LMUL");
  case RISCVII::LMUL_F8:
    return RISCV::PseudoVMSET_M_B1;
  case RISCVII::LMUL_F4:
    return RISCV::PseudoVMSET_M_B2;
  case RISCVII::LMUL_F2:
    return RISCV::PseudoVMSET_M_B4;
  case RISCVII::LMUL_1:
    return RISCV::PseudoVMSET_M_B8;
  case RISCVII::LMUL_2:
    return RISCV::PseudoVMSET_M_B16;
  case RISCVII::LMUL_4:
    return RISCV::PseudoVMSET_M_B32;
  case RISCVII::LMUL_8:
    return RISCV::PseudoVMSET_M_B64;
  }
  llvm_unreachable("Unknown VLMUL enum");
}

// Try to fold away VMERGE_VVM instructions. We handle these cases:
// -Masked TU VMERGE_VVM combined with an unmasked TA instruction
// folds to a masked TU instruction. VMERGE_VVM must have merge operand
// same as false operand.
// -Masked TA VMERGE_VVM combined with an unmasked TA instruction fold to a
// masked TA instruction.
// -Unmasked TU VMERGE_VVM combined with a masked MU TA instruction folds to
// masked TU instruction. Both instructions must have the same merge operand.
// VMERGE_VVM must have merge operand same as false operand.
// Note: The VMERGE_VVM forms above (TA, and TU) refer to the policy implied,
// not the pseudo name. That is, a TA VMERGE_VVM can be either the _TU pseudo
// form with an IMPLICIT_DEF passthrough operand or the unsuffixed (TA) pseudo
// form.
//
// \p MI is the candidate vmerge (or vmv.v.v); \p MaskDef is the most recent
// instruction in the block that defines V0, or null if none has been seen.
// Returns true and erases \p MI if the fold succeeded.
bool RISCVFoldMasks::foldVMergeIntoOps(MachineInstr &MI,
                                       MachineInstr *MaskDef) {
  MachineOperand *True;
  MachineOperand *Merge;
  MachineOperand *False;

  const unsigned BaseOpc = RISCV::getRVVMCOpcode(MI.getOpcode());
  // A vmv.v.v is equivalent to a vmerge with an all-ones mask.
  if (BaseOpc == RISCV::VMV_V_V) {
    // vmv.v.v has no separate false operand; the passthru doubles as false.
    Merge = &MI.getOperand(1);
    False = &MI.getOperand(1);
    True = &MI.getOperand(2);
  } else if (BaseOpc == RISCV::VMERGE_VVM) {
    Merge = &MI.getOperand(1);
    False = &MI.getOperand(2);
    True = &MI.getOperand(3);
  } else
    return false;

  // Only fold when the producer of %true is in the same block: the sinking
  // logic below walks instructions between TrueMI and MI.
  MachineInstr &TrueMI = *MRI->getVRegDef(True->getReg());
  if (TrueMI.getParent() != MI.getParent())
    return false;

  // We require that either merge and false are the same, or that merge
  // is undefined.
  if (Merge->getReg() != RISCV::NoRegister &&
      TRI->lookThruCopyLike(Merge->getReg(), MRI) !=
          TRI->lookThruCopyLike(False->getReg(), MRI))
    return false;

  // N must be the only user of True.
  if (!MRI->hasOneUse(True->getReg()))
    return false;

  unsigned TrueOpc = TrueMI.getOpcode();
  const MCInstrDesc &TrueMCID = TrueMI.getDesc();
  bool HasTiedDest = RISCVII::isFirstDefTiedToFirstUse(TrueMCID);

  // MIIsMasked: the vmerge has a real (not all-ones) mask. A vmv.v.v is by
  // definition unmasked.
  const bool MIIsMasked =
      BaseOpc == RISCV::VMERGE_VVM && !isAllOnesMask(MaskDef);
  bool TrueIsMasked = false;
  // First try treating TrueMI as an unmasked pseudo with a masked
  // counterpart; failing that, TrueMI may itself already be masked.
  const RISCV::RISCVMaskedPseudoInfo *Info =
      RISCV::lookupMaskedIntrinsicByUnmasked(TrueOpc);
  if (!Info && HasTiedDest) {
    Info = RISCV::getMaskedPseudoInfo(TrueOpc);
    TrueIsMasked = true;
  }

  if (!Info)
    return false;

  // When Mask is not a true mask, this transformation is illegal for some
  // operations whose results are affected by mask, like viota.m.
  if (Info->MaskAffectsResult && MIIsMasked)
    return false;

  MachineOperand &TrueMergeOp = TrueMI.getOperand(1);
  if (HasTiedDest && TrueMergeOp.getReg() != RISCV::NoRegister) {
    // The vmerge instruction must be TU.
    // FIXME: This could be relaxed, but we need to handle the policy for the
    // resulting op correctly.
    if (Merge->getReg() == RISCV::NoRegister)
      return false;
    // Both the vmerge instruction and the True instruction must have the same
    // merge operand.
    if (TrueMergeOp.getReg() != RISCV::NoRegister &&
        TrueMergeOp.getReg() != False->getReg())
      return false;
  }

  if (TrueIsMasked) {
    assert(HasTiedDest && "Expected tied dest");
    // The vmerge instruction must be TU.
    if (Merge->getReg() == RISCV::NoRegister)
      return false;
    // MI must have an all 1s mask since we're going to keep the mask from the
    // True instruction.
    // FIXME: Support mask agnostic True instruction which would have an undef
    // merge operand.
    if (MIIsMasked)
      return false;
  }

  // Skip if True has side effect.
  // TODO: Support vleff and vlsegff.
  if (TII->get(TrueOpc).hasUnmodeledSideEffects())
    return false;

  // Returns the smaller of two VL operands, treating the VLMAX sentinel as
  // larger than any concrete value. Returns nullopt when the two are
  // incomparable (e.g. two different registers).
  auto GetMinVL =
      [](const MachineOperand &LHS,
         const MachineOperand &RHS) -> std::optional<MachineOperand> {
    if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
        LHS.getReg() == RHS.getReg())
      return LHS;
    if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
      return RHS;
    if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
      return LHS;
    if (!LHS.isImm() || !RHS.isImm())
      return std::nullopt;
    return LHS.getImm() <= RHS.getImm() ? LHS : RHS;
  };

  // Because MI and True must have the same merge operand (or True's operand is
  // implicit_def), the "effective" body is the minimum of their VLs.
  const MachineOperand &TrueVL =
      TrueMI.getOperand(RISCVII::getVLOpNum(TrueMCID));
  const MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  auto MinVL = GetMinVL(TrueVL, VL);
  if (!MinVL)
    return false;
  bool VLChanged = !MinVL->isIdenticalTo(VL);

  // If we end up changing the VL or mask of True, then we need to make sure it
  // doesn't raise any observable fp exceptions, since changing the active
  // elements will affect how fflags is set.
  if (VLChanged || !TrueIsMasked)
    if (TrueMCID.mayRaiseFPException() &&
        !TrueMI.getFlag(MachineInstr::MIFlag::NoFPExcept))
      return false;

  unsigned MaskedOpc = Info->MaskedPseudo;
  const MCInstrDesc &MaskedMCID = TII->get(MaskedOpc);
#ifndef NDEBUG
  assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) &&
         "Expected instructions with mask have policy operand.");
  assert(MaskedMCID.getOperandConstraint(MaskedMCID.getNumDefs(),
                                         MCOI::TIED_TO) == 0 &&
         "Expected instructions with mask have a tied dest.");
#endif

  // Sink True down to MI so that it can access MI's operands.
  assert(!TrueMI.hasImplicitDef());
  bool SawStore = false;
  for (MachineBasicBlock::instr_iterator II = TrueMI.getIterator();
       II != MI.getIterator(); II++) {
    if (II->mayStore()) {
      SawStore = true;
      break;
    }
  }
  if (!TrueMI.isSafeToMove(nullptr, SawStore))
    return false;
  // Everything below this point mutates the function; no more bail-outs.
  TrueMI.moveBefore(&MI);

  // Set the merge to the false operand of the merge.
  TrueMI.getOperand(1).setReg(False->getReg());

  bool NeedToMoveOldMask = TrueIsMasked;
  // If we're converting it to a masked pseudo, reuse MI's mask.
  if (!TrueIsMasked) {
    if (BaseOpc == RISCV::VMV_V_V) {
      // If MI is a vmv.v.v, it won't have a mask operand. So insert an all-ones
      // mask just before True.
      unsigned VMSetOpc =
          getVMSetForLMul(RISCVII::getLMul(MI.getDesc().TSFlags));
      Register Dest = MRI->createVirtualRegister(&RISCV::VRRegClass);
      BuildMI(*MI.getParent(), TrueMI, MI.getDebugLoc(), TII->get(VMSetOpc),
              Dest)
          .add(VL)
          .add(TrueMI.getOperand(RISCVII::getSEWOpNum(TrueMCID)));
      // NOTE(review): this COPY inserts a clobber of V0 that did not
      // previously exist. The old V0 def is re-materialized after MI below,
      // but only when MaskDef is non-null — confirm V0 cannot be live here
      // from a def outside this block's tracking.
      BuildMI(*MI.getParent(), TrueMI, MI.getDebugLoc(), TII->get(RISCV::COPY),
              RISCV::V0)
          .addReg(Dest);
      NeedToMoveOldMask = true;
    }

    // NOTE(review): the operands are mutated in place after this setDesc, and
    // later reads use the *masked* descriptor's operand indices — confirm no
    // index read aliases a slot already rewritten (reviewers suggested
    // rebuilding the operand list in one go instead).
    TrueMI.setDesc(MaskedMCID);

    // TODO: Increment MaskOpIdx by number of explicit defs in tablegen?
    unsigned MaskOpIdx = Info->MaskOpIdx + TrueMI.getNumExplicitDefs();
    TrueMI.insert(&TrueMI.getOperand(MaskOpIdx),
                  MachineOperand::CreateReg(RISCV::V0, false));
  }

  // Update the AVL.
  if (MinVL->isReg())
    TrueMI.getOperand(RISCVII::getVLOpNum(MaskedMCID))
        .ChangeToRegister(MinVL->getReg(), false);
  else
    TrueMI.getOperand(RISCVII::getVLOpNum(MaskedMCID))
        .ChangeToImmediate(MinVL->getImm());

  // Use a tumu policy, relaxing it to tail agnostic provided that the merge
  // operand is undefined.
  //
  // However, if the VL became smaller than what the vmerge had originally, then
  // elements past VL that were previously in the vmerge's body will have moved
  // to the tail. In that case we always need to use tail undisturbed to
  // preserve them.
  uint64_t Policy = (Merge->getReg() == RISCV::NoRegister && !VLChanged)
                        ? RISCVII::TAIL_AGNOSTIC
                        : RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
  TrueMI.getOperand(RISCVII::getVecPolicyOpNum(MaskedMCID)).setImm(Policy);

  const TargetRegisterClass *V0RC =
      TII->getRegClass(MaskedMCID, 0, TRI, *MI.getMF());

  // The destination and passthru can no longer be in V0.
  MRI->constrainRegClass(TrueMI.getOperand(0).getReg(), V0RC);
  Register PassthruReg = TrueMI.getOperand(1).getReg();
  if (PassthruReg != RISCV::NoRegister)
    MRI->constrainRegClass(PassthruReg, V0RC);

  MRI->replaceRegWith(MI.getOperand(0).getReg(), TrueMI.getOperand(0).getReg());

  // We need to move the old mask copy to after MI if:
  // - TrueMI is masked and we are using its mask instead
  // - We created a new all ones mask that clobbers V0
  if (NeedToMoveOldMask && MaskDef) {
    // NOTE(review): assumes MaskDef is in this block — TODO confirm the
    // caller's per-block V0 tracking guarantees this.
    assert(MaskDef->getParent() == MI.getParent());
    MaskDef->removeFromParent();
    MI.getParent()->insertAfter(MI.getIterator(), MaskDef);
  }

  MI.eraseFromParent();

  return true;
}

// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) {
Expand Down Expand Up @@ -202,6 +464,13 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
// on each pseudo.
MachineInstr *CurrentV0Def;
for (MachineBasicBlock &MBB : MF) {
CurrentV0Def = nullptr;
for (MachineInstr &MI : make_early_inc_range(MBB)) {
Changed |= foldVMergeIntoOps(MI, CurrentV0Def);
if (MI.definesRegister(RISCV::V0, TRI))
CurrentV0Def = &MI;
}

CurrentV0Def = nullptr;
for (MachineInstr &MI : MBB) {
Changed |= convertToUnmasked(MI, CurrentV0Def);
Expand Down
Loading