Skip to content

[RISCV] Don't use V0 directly in patterns #88496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 18 additions & 31 deletions llvm/lib/Target/RISCV/RISCVFoldMasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ class RISCVFoldMasks : public MachineFunctionPass {
bool convertToUnmasked(MachineInstr &MI) const;
bool convertVMergeToVMv(MachineInstr &MI) const;

bool isAllOnesMask(const MachineInstr *MaskDef) const;

/// Maps uses of V0 to the corresponding def of V0.
DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
bool isAllOnesMask(const MachineOperand &MaskOp) const;
};

} // namespace
Expand All @@ -62,12 +59,22 @@ char RISCVFoldMasks::ID = 0;

INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)

bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
assert(MaskDef && MaskDef->isCopy() &&
MaskDef->getOperand(0).getReg() == RISCV::V0);
bool RISCVFoldMasks::isAllOnesMask(const MachineOperand &MaskOp) const {
if (!MaskOp.isReg())
return false;

Register MaskReg = MaskOp.getReg();
if (!MaskReg.isVirtual())
return false;

MachineInstr *MaskDef = MRI->getVRegDef(MaskReg);
if (!MaskDef || !MaskDef->isCopy())
return false;

Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
if (!SrcReg.isVirtual())
return false;

MaskDef = MRI->getVRegDef(SrcReg);
if (!MaskDef)
return false;
Expand Down Expand Up @@ -116,8 +123,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
TRI->lookThruCopyLike(FalseReg, MRI))
return false;

assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
if (!isAllOnesMask(V0Defs.lookup(&MI)))
if (!isAllOnesMask(MI.getOperand(4)))
return false;

MI.setDesc(TII->get(NewOpc));
Expand All @@ -140,7 +146,9 @@ bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
if (!I)
return false;

if (!isAllOnesMask(V0Defs.lookup(&MI)))
// TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
if (!isAllOnesMask(MI.getOperand(MaskOpIdx)))
return false;

// There are two classes of pseudos in the table - compares and
Expand All @@ -160,9 +168,6 @@ bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
(void)HasPolicyOp;

MI.setDesc(MCID);

// TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
MI.removeOperand(MaskOpIdx);

// The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
Expand Down Expand Up @@ -193,24 +198,6 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {

bool Changed = false;

// Masked pseudos coming out of isel will have their mask operand in the form:
//
// $v0:vr = COPY %mask:vr
// %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
//
// Because $v0 isn't in SSA, keep track of its definition at each use so we
// can check mask operands.
for (const MachineBasicBlock &MBB : MF) {
const MachineInstr *CurrentV0Def = nullptr;
for (const MachineInstr &MI : MBB) {
if (MI.readsRegister(RISCV::V0, TRI))
V0Defs[&MI] = CurrentV0Def;

if (MI.definesRegister(RISCV::V0, TRI))
CurrentV0Def = &MI;
}
}

for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
Changed |= convertToUnmasked(MI);
Expand Down
82 changes: 10 additions & 72 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ void RISCVDAGToDAGISel::addVectorLoadStoreOperands(
bool IsMasked, bool IsStridedOrIndexed, SmallVectorImpl<SDValue> &Operands,
bool IsLoad, MVT *IndexVT) {
SDValue Chain = Node->getOperand(0);
SDValue Glue;

Operands.push_back(Node->getOperand(CurOp++)); // Base pointer.

Expand All @@ -307,11 +306,8 @@ void RISCVDAGToDAGISel::addVectorLoadStoreOperands(
}

if (IsMasked) {
// Mask needs to be copied to V0.
SDValue Mask = Node->getOperand(CurOp++);
Chain = CurDAG->getCopyToReg(Chain, DL, RISCV::V0, Mask, SDValue());
Glue = Chain.getValue(1);
Operands.push_back(CurDAG->getRegister(RISCV::V0, Mask.getValueType()));
Operands.push_back(Mask);
}
SDValue VL;
selectVLOp(Node->getOperand(CurOp++), VL);
Expand All @@ -333,8 +329,6 @@ void RISCVDAGToDAGISel::addVectorLoadStoreOperands(
}

Operands.push_back(Chain); // Chain.
if (Glue)
Operands.push_back(Glue);
}

void RISCVDAGToDAGISel::selectVLSEG(SDNode *Node, bool IsMasked,
Expand Down Expand Up @@ -1670,20 +1664,14 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
return;
}

// Mask needs to be copied to V0.
SDValue Chain = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
RISCV::V0, Mask, SDValue());
SDValue Glue = Chain.getValue(1);
SDValue V0 = CurDAG->getRegister(RISCV::V0, VT);

// Otherwise use
// vmslt{u}.vx vd, va, x, v0.t; vmxor.mm vd, vd, v0
// The result is mask undisturbed.
// We use the same instructions to emulate mask agnostic behavior, because
// the agnostic result can be either undisturbed or all 1.
SDValue Cmp = SDValue(
CurDAG->getMachineNode(VMSLTMaskOpcode, DL, VT,
{MaskedOff, Src1, Src2, V0, VL, SEW, Glue}),
{MaskedOff, Src1, Src2, Mask, VL, SEW}),
0);
// vmxor.mm vd, vd, v0 is used to update active value.
ReplaceNode(Node, CurDAG->getMachineNode(VMXOROpcode, DL, VT,
Expand Down Expand Up @@ -3426,32 +3414,7 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
return false;
}

static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
// Check that we're using V0 as a mask register.
if (!isa<RegisterSDNode>(MaskOp) ||
cast<RegisterSDNode>(MaskOp)->getReg() != RISCV::V0)
return false;

// The glued user defines V0.
const auto *Glued = GlueOp.getNode();

if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
return false;

// Check that we're defining V0 as a mask register.
if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
return false;

// Check the instruction defining V0; it needs to be a VMSET pseudo.
SDValue MaskSetter = Glued->getOperand(2);

// Sometimes the VMSET is wrapped in a COPY_TO_REGCLASS, e.g. if the mask came
// from an extract_subvector or insert_subvector.
if (MaskSetter->isMachineOpcode() &&
MaskSetter->getMachineOpcode() == RISCV::COPY_TO_REGCLASS)
MaskSetter = MaskSetter->getOperand(0);

static bool usesAllOnesMask(SDValue MaskOp) {
const auto IsVMSet = [](unsigned Opc) {
return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
Expand All @@ -3462,14 +3425,12 @@ static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
// TODO: Check that the VMSET is the expected bitwidth? The pseudo has
// undefined behaviour if it's the wrong bitwidth, so we could choose to
// assume that it's all-ones? Same applies to its VL.
return MaskSetter->isMachineOpcode() &&
IsVMSet(MaskSetter.getMachineOpcode());
return MaskOp->isMachineOpcode() && IsVMSet(MaskOp.getMachineOpcode());
}

// Return true if we can make sure mask of N is all-ones mask.
static bool usesAllOnesMask(SDNode *N, unsigned MaskOpIdx) {
return usesAllOnesMask(N->getOperand(MaskOpIdx),
N->getOperand(N->getNumOperands() - 1));
return usesAllOnesMask(N->getOperand(MaskOpIdx));
}

static bool isImplicitDef(SDValue V) {
Expand Down Expand Up @@ -3515,11 +3476,6 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(MachineSDNode *N) {
Ops.push_back(Op);
}

// Transitively apply any node glued to our new node.
const auto *Glued = N->getGluedNode();
if (auto *TGlued = Glued->getGluedNode())
Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1));

MachineSDNode *Result =
CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops);

Expand Down Expand Up @@ -3584,7 +3540,7 @@ static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
// The resulting policy is the effective policy the vmerge would have had,
// i.e. whether or not it's merge operand was implicit-def.
bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
SDValue Merge, False, True, VL, Mask, Glue;
SDValue Merge, False, True, VL, Mask;
// A vmv.v.v is equivalent to a vmerge with an all-ones mask.
if (IsVMv(N)) {
Merge = N->getOperand(0);
Expand All @@ -3600,11 +3556,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
True = N->getOperand(2);
Mask = N->getOperand(3);
VL = N->getOperand(4);
// We always have a glue node for the mask at v0.
Glue = N->getOperand(N->getNumOperands() - 1);
}
assert(!Mask || cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
assert(!Glue || Glue.getValueType() == MVT::Glue);

// We require that either merge and false are the same, or that merge
// is undefined.
Expand Down Expand Up @@ -3639,7 +3591,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {

// When Mask is not a true mask, this transformation is illegal for some
// operations whose results are affected by mask, like viota.m.
if (Info->MaskAffectsResult && Mask && !usesAllOnesMask(Mask, Glue))
if (Info->MaskAffectsResult && Mask && !usesAllOnesMask(Mask))
return false;

// If True has a merge operand then it needs to be the same as vmerge's False,
Expand All @@ -3664,7 +3616,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
return false;
// FIXME: Support mask agnostic True instruction which would have an
// undef merge operand.
if (Mask && !usesAllOnesMask(Mask, Glue))
if (Mask && !usesAllOnesMask(Mask))
return false;
}

Expand All @@ -3691,8 +3643,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (Mask)
LoopWorklist.push_back(Mask.getNode());
LoopWorklist.push_back(VL.getNode());
if (Glue)
LoopWorklist.push_back(Glue.getNode());
if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
return false;
}
Expand Down Expand Up @@ -3737,25 +3687,16 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {

// From the preconditions we checked above, we know the mask and thus glue
// for the result node will be taken from True.
if (IsMasked) {
if (IsMasked)
Mask = True->getOperand(Info->MaskOpIdx);
Glue = True->getOperand(True->getNumOperands() - 1);
assert(Glue.getValueType() == MVT::Glue);
}
// If we end up using the vmerge mask the vmerge is actually a vmv.v.v, create
// an all-ones mask to use.
else if (IsVMv(N)) {
unsigned TSFlags = TII->get(N->getMachineOpcode()).TSFlags;
unsigned VMSetOpc = GetVMSetForLMul(RISCVII::getLMul(TSFlags));
ElementCount EC = N->getValueType(0).getVectorElementCount();
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);

SDValue AllOnesMask =
SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
SDValue MaskCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
RISCV::V0, AllOnesMask, SDValue());
Mask = CurDAG->getRegister(RISCV::V0, MaskVT);
Glue = MaskCopy.getValue(1);
Mask = SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
}

unsigned MaskedOpc = Info->MaskedPseudo;
Expand Down Expand Up @@ -3806,9 +3747,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
if (HasChainOp)
Ops.push_back(True.getOperand(TrueChainOpIdx));

// Add the glue for the CopyToReg of mask->v0.
Ops.push_back(Glue);

MachineSDNode *Result =
CurDAG->getMachineNode(MaskedOpc, DL, True->getVTList(), Ops);
Result->setFlags(True->getFlags());
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {

unsigned getUndefInitOpcode(unsigned RegClassID) const override {
switch (RegClassID) {
case RISCV::VMV0RegClassID:
case RISCV::VRRegClassID:
return RISCV::PseudoRVVInitUndefM1;
case RISCV::VRM2RegClassID:
Expand Down
Loading