Skip to content

Commit d0e913f

Browse files
committed
[RISCV] Move vmv.v.v peephole from SelectionDAG to RISCVVectorPeephole
This is split off from llvm#71764 and moves only the vmv.v.v part of performCombineVMergeAndVOps to work on MachineInstrs. In retrospect, trying to handle PseudoVMV_V_V and PseudoVMERGE_VVM in the same function makes the code quite hard to read, so this just does it in a separate peephole. This turns out to be simpler, since for PseudoVMV_V_V we don't need to convert the Src instruction to a masked variant and we don't need to create a fake all-ones mask.
1 parent 9b14831 commit d0e913f

File tree

2 files changed

+154
-71
lines changed

2 files changed

+154
-71
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 14 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3663,32 +3663,6 @@ static bool IsVMerge(SDNode *N) {
36633663
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMERGE_VVM;
36643664
}
36653665

3666-
static bool IsVMv(SDNode *N) {
3667-
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMV_V_V;
3668-
}
3669-
3670-
static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
3671-
switch (LMUL) {
3672-
case RISCVII::LMUL_F8:
3673-
return RISCV::PseudoVMSET_M_B1;
3674-
case RISCVII::LMUL_F4:
3675-
return RISCV::PseudoVMSET_M_B2;
3676-
case RISCVII::LMUL_F2:
3677-
return RISCV::PseudoVMSET_M_B4;
3678-
case RISCVII::LMUL_1:
3679-
return RISCV::PseudoVMSET_M_B8;
3680-
case RISCVII::LMUL_2:
3681-
return RISCV::PseudoVMSET_M_B16;
3682-
case RISCVII::LMUL_4:
3683-
return RISCV::PseudoVMSET_M_B32;
3684-
case RISCVII::LMUL_8:
3685-
return RISCV::PseudoVMSET_M_B64;
3686-
case RISCVII::LMUL_RESERVED:
3687-
llvm_unreachable("Unexpected LMUL");
3688-
}
3689-
llvm_unreachable("Unknown VLMUL enum");
3690-
}
3691-
36923666
// Try to fold away VMERGE_VVM instructions into their true operands:
36933667
//
36943668
// %true = PseudoVADD_VV ...
@@ -3703,35 +3677,22 @@ static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
37033677
// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
37043678
// mask is all ones.
37053679
//
3706-
// We can also fold a VMV_V_V into its true operand, since it is equivalent to a
3707-
// VMERGE_VVM with an all ones mask.
3708-
//
37093680
// The resulting VL is the minimum of the two VLs.
37103681
//
37113682
// The resulting policy is the effective policy the vmerge would have had,
37123683
// i.e. whether or not its merge operand was implicit-def.
37133684
bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
37143685
SDValue Merge, False, True, VL, Mask, Glue;
3715-
// A vmv.v.v is equivalent to a vmerge with an all-ones mask.
3716-
if (IsVMv(N)) {
3717-
Merge = N->getOperand(0);
3718-
False = N->getOperand(0);
3719-
True = N->getOperand(1);
3720-
VL = N->getOperand(2);
3721-
// A vmv.v.v won't have a Mask or Glue, instead we'll construct an all-ones
3722-
// mask later below.
3723-
} else {
3724-
assert(IsVMerge(N));
3725-
Merge = N->getOperand(0);
3726-
False = N->getOperand(1);
3727-
True = N->getOperand(2);
3728-
Mask = N->getOperand(3);
3729-
VL = N->getOperand(4);
3730-
// We always have a glue node for the mask at v0.
3731-
Glue = N->getOperand(N->getNumOperands() - 1);
3732-
}
3733-
assert(!Mask || cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
3734-
assert(!Glue || Glue.getValueType() == MVT::Glue);
3686+
assert(IsVMerge(N));
3687+
Merge = N->getOperand(0);
3688+
False = N->getOperand(1);
3689+
True = N->getOperand(2);
3690+
Mask = N->getOperand(3);
3691+
VL = N->getOperand(4);
3692+
// We always have a glue node for the mask at v0.
3693+
Glue = N->getOperand(N->getNumOperands() - 1);
3694+
assert(cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
3695+
assert(Glue.getValueType() == MVT::Glue);
37353696

37363697
// We require that either merge and false are the same, or that merge
37373698
// is undefined.
@@ -3775,7 +3736,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
37753736

37763737
// If True is masked then the vmerge must have either the same mask or an all
37773738
// 1s mask, since we're going to keep the mask from True.
3778-
if (IsMasked && Mask) {
3739+
if (IsMasked) {
37793740
// FIXME: Support mask agnostic True instruction which would have an
37803741
// undef merge operand.
37813742
SDValue TrueMask =
@@ -3805,11 +3766,9 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
38053766
SmallVector<const SDNode *, 4> LoopWorklist;
38063767
SmallPtrSet<const SDNode *, 16> Visited;
38073768
LoopWorklist.push_back(False.getNode());
3808-
if (Mask)
3809-
LoopWorklist.push_back(Mask.getNode());
3769+
LoopWorklist.push_back(Mask.getNode());
38103770
LoopWorklist.push_back(VL.getNode());
3811-
if (Glue)
3812-
LoopWorklist.push_back(Glue.getNode());
3771+
LoopWorklist.push_back(Glue.getNode());
38133772
if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
38143773
return false;
38153774
}
@@ -3869,21 +3828,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
38693828
Glue = True->getOperand(True->getNumOperands() - 1);
38703829
assert(Glue.getValueType() == MVT::Glue);
38713830
}
3872-
// If we end up using the vmerge mask the vmerge is actually a vmv.v.v, create
3873-
// an all-ones mask to use.
3874-
else if (IsVMv(N)) {
3875-
unsigned TSFlags = TII->get(N->getMachineOpcode()).TSFlags;
3876-
unsigned VMSetOpc = GetVMSetForLMul(RISCVII::getLMul(TSFlags));
3877-
ElementCount EC = N->getValueType(0).getVectorElementCount();
3878-
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);
3879-
3880-
SDValue AllOnesMask =
3881-
SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
3882-
SDValue MaskCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
3883-
RISCV::V0, AllOnesMask, SDValue());
3884-
Mask = CurDAG->getRegister(RISCV::V0, MaskVT);
3885-
Glue = MaskCopy.getValue(1);
3886-
}
38873831

38883832
unsigned MaskedOpc = Info->MaskedPseudo;
38893833
#ifndef NDEBUG
@@ -3962,7 +3906,7 @@ bool RISCVDAGToDAGISel::doPeepholeMergeVVMFold() {
39623906
if (N->use_empty() || !N->isMachineOpcode())
39633907
continue;
39643908

3965-
if (IsVMerge(N) || IsVMv(N))
3909+
if (IsVMerge(N))
39663910
MadeChange |= performCombineVMergeAndVOps(N);
39673911
}
39683912
return MadeChange;

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
6565
bool convertToWholeRegister(MachineInstr &MI) const;
6666
bool convertToUnmasked(MachineInstr &MI) const;
6767
bool convertVMergeToVMv(MachineInstr &MI) const;
68+
bool foldVMV_V_V(MachineInstr &MI);
6869

6970
bool isAllOnesMask(const MachineInstr *MaskDef) const;
7071
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
@@ -323,6 +324,143 @@ bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
323324
return true;
324325
}
325326

327+
/// Given two VL operands, returns the one known to be the smallest or nullptr
328+
/// if unknown.
329+
static const MachineOperand *getKnownMinVL(const MachineOperand *LHS,
330+
const MachineOperand *RHS) {
331+
if (LHS->isReg() && RHS->isReg() && LHS->getReg().isVirtual() &&
332+
LHS->getReg() == RHS->getReg())
333+
return LHS;
334+
if (LHS->isImm() && LHS->getImm() == RISCV::VLMaxSentinel)
335+
return RHS;
336+
if (RHS->isImm() && RHS->getImm() == RISCV::VLMaxSentinel)
337+
return LHS;
338+
if (!LHS->isImm() || !RHS->isImm())
339+
return nullptr;
340+
return LHS->getImm() <= RHS->getImm() ? LHS : RHS;
341+
}
342+
343+
/// Check if it's safe to move From down to To, checking that no physical
344+
/// registers are clobbered.
345+
static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
346+
assert(From.getParent() == To.getParent() && !From.hasImplicitDef());
347+
SmallVector<Register> PhysUses;
348+
for (const MachineOperand &MO : From.all_uses())
349+
if (MO.getReg().isPhysical())
350+
PhysUses.push_back(MO.getReg());
351+
bool SawStore = false;
352+
for (auto II = From.getIterator(); II != To.getIterator(); II++) {
353+
for (Register PhysReg : PhysUses)
354+
if (II->definesRegister(PhysReg, nullptr))
355+
return false;
356+
if (II->mayStore()) {
357+
SawStore = true;
358+
break;
359+
}
360+
}
361+
return From.isSafeToMove(nullptr, SawStore);
362+
}
363+
364+
/// Finds the RISCVMaskedPseudoInfo entry associated with MI's opcode.
///
/// The opcode is first looked up via lookupMaskedIntrinsicByUnmasked; if that
/// finds nothing, getMaskedPseudoInfo is tried with the same opcode.
/// \returns the matching entry, or nullptr if neither lookup succeeds.
static const RISCV::RISCVMaskedPseudoInfo *
lookupMaskedPseudoInfo(const MachineInstr &MI) {
  const unsigned Opc = MI.getOpcode();
  if (const RISCV::RISCVMaskedPseudoInfo *ByUnmasked =
          RISCV::lookupMaskedIntrinsicByUnmasked(Opc))
    return ByUnmasked;
  return RISCV::getMaskedPseudoInfo(Opc);
}
372+
373+
/// If a PseudoVMV_V_V is the only user of it's input, fold its passthru and VL
374+
/// into it.
375+
///
376+
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl, sew, policy
377+
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl, sew, policy
378+
///
379+
/// ->
380+
///
381+
/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl, sew, policy
382+
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
383+
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
384+
return false;
385+
386+
MachineOperand &Passthru = MI.getOperand(1);
387+
MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
388+
389+
if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
390+
return false;
391+
392+
if (!Src || Src->hasUnmodeledSideEffects() ||
393+
Src->getParent() != MI.getParent())
394+
return false;
395+
396+
// Src needs to be a pseudo that's opted into this transform.
397+
const RISCV::RISCVMaskedPseudoInfo *Info = lookupMaskedPseudoInfo(*Src);
398+
if (!Info)
399+
return false;
400+
401+
assert(Src->getNumDefs() == 1 &&
402+
RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) &&
403+
RISCVII::hasVLOp(Src->getDesc().TSFlags) &&
404+
RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags));
405+
406+
// Src needs to have the same passthru as VMV_V_V
407+
if (Src->getOperand(1).getReg() != RISCV::NoRegister &&
408+
Src->getOperand(1).getReg() != Passthru.getReg())
409+
return false;
410+
411+
// Because Src and MI have the same passthru, we can use either AVL as long as
412+
// it's the smaller of the two.
413+
//
414+
// (src pt, ..., vl=5) x x x x x|. . .
415+
// (vmv.v.v pt, src, vl=3) x x x|. . . . .
416+
// ->
417+
// (src pt, ..., vl=3) x x x|. . . . .
418+
//
419+
// (src pt, ..., vl=3) x x x|. . . . .
420+
// (vmv.v.v pt, src, vl=6) x x x . . .|. .
421+
// ->
422+
// (src pt, ..., vl=3) x x x|. . . . .
423+
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
424+
const MachineOperand *MinVL = getKnownMinVL(&MI.getOperand(3), &SrcVL);
425+
if (!MinVL)
426+
return false;
427+
428+
bool VLChanged = !MinVL->isIdenticalTo(SrcVL);
429+
bool RaisesFPExceptions = MI.getDesc().mayRaiseFPException() &&
430+
!MI.getFlag(MachineInstr::MIFlag::NoFPExcept);
431+
if (VLChanged && (Info->ActiveElementsAffectResult || RaisesFPExceptions))
432+
return false;
433+
434+
if (!isSafeToMove(*Src, MI))
435+
return false;
436+
437+
// Move Src down to MI, then replace all uses of MI with it.
438+
Src->moveBefore(&MI);
439+
440+
Src->getOperand(1).setReg(Passthru.getReg());
441+
// If Src is masked then its passthru needs to be in VRNoV0.
442+
if (Passthru.getReg() != RISCV::NoRegister)
443+
MRI->constrainRegClass(Passthru.getReg(),
444+
TII->getRegClass(Src->getDesc(), 1, TRI,
445+
*Src->getParent()->getParent()));
446+
447+
if (MinVL->isImm())
448+
SrcVL.ChangeToImmediate(MinVL->getImm());
449+
else if (MinVL->isReg())
450+
SrcVL.ChangeToRegister(MinVL->getReg(), false);
451+
452+
// Use a conservative tu,mu policy, RISCVInsertVSETVLI will relax it if
453+
// passthru is undef.
454+
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()))
455+
.setImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED);
456+
457+
MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
458+
MI.eraseFromParent();
459+
V0Defs.erase(&MI);
460+
461+
return true;
462+
}
463+
326464
bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
327465
if (skipFunction(MF.getFunction()))
328466
return false;
@@ -357,11 +495,12 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
357495
}
358496

359497
for (MachineBasicBlock &MBB : MF) {
360-
for (MachineInstr &MI : MBB) {
498+
for (MachineInstr &MI : make_early_inc_range(MBB)) {
361499
Changed |= convertToVLMAX(MI);
362500
Changed |= convertToUnmasked(MI);
363501
Changed |= convertToWholeRegister(MI);
364502
Changed |= convertVMergeToVMv(MI);
503+
Changed |= foldVMV_V_V(MI);
365504
}
366505
}
367506

0 commit comments

Comments
 (0)