Skip to content

Commit aba3476

Browse files
authored
[RISCV] Move vmv.v.v peephole from SelectionDAG to RISCVVectorPeephole (#100367)
This is split off from #71764, and moves only the vmv.v.v part of performCombineVMergeAndVOps to work on MachineInstrs. In retrospect trying to handle PseudoVMV_V_V and PseudoVMERGE_VVM in the same function makes the code quite hard to read, so this just does it in a separate peephole. This turns out to be simpler since for PseudoVMV_V_V we don't need to convert the Src instruction to a masked variant, and we don't need to create a fake all ones mask.
1 parent 3f18a0a commit aba3476

File tree

2 files changed

+154
-71
lines changed

2 files changed

+154
-71
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 14 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3708,32 +3708,6 @@ static bool IsVMerge(SDNode *N) {
37083708
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMERGE_VVM;
37093709
}
37103710

3711-
static bool IsVMv(SDNode *N) {
3712-
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMV_V_V;
3713-
}
3714-
3715-
static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
3716-
switch (LMUL) {
3717-
case RISCVII::LMUL_F8:
3718-
return RISCV::PseudoVMSET_M_B1;
3719-
case RISCVII::LMUL_F4:
3720-
return RISCV::PseudoVMSET_M_B2;
3721-
case RISCVII::LMUL_F2:
3722-
return RISCV::PseudoVMSET_M_B4;
3723-
case RISCVII::LMUL_1:
3724-
return RISCV::PseudoVMSET_M_B8;
3725-
case RISCVII::LMUL_2:
3726-
return RISCV::PseudoVMSET_M_B16;
3727-
case RISCVII::LMUL_4:
3728-
return RISCV::PseudoVMSET_M_B32;
3729-
case RISCVII::LMUL_8:
3730-
return RISCV::PseudoVMSET_M_B64;
3731-
case RISCVII::LMUL_RESERVED:
3732-
llvm_unreachable("Unexpected LMUL");
3733-
}
3734-
llvm_unreachable("Unknown VLMUL enum");
3735-
}
3736-
37373711
// Try to fold away VMERGE_VVM instructions into their true operands:
37383712
//
37393713
// %true = PseudoVADD_VV ...
@@ -3748,35 +3722,22 @@ static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
37483722
// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
37493723
// mask is all ones.
37503724
//
3751-
// We can also fold a VMV_V_V into its true operand, since it is equivalent to a
3752-
// VMERGE_VVM with an all ones mask.
3753-
//
37543725
// The resulting VL is the minimum of the two VLs.
37553726
//
37563727
// The resulting policy is the effective policy the vmerge would have had,
37573728
// i.e. whether or not it's passthru operand was implicit-def.
37583729
bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
37593730
SDValue Passthru, False, True, VL, Mask, Glue;
3760-
// A vmv.v.v is equivalent to a vmerge with an all-ones mask.
3761-
if (IsVMv(N)) {
3762-
Passthru = N->getOperand(0);
3763-
False = N->getOperand(0);
3764-
True = N->getOperand(1);
3765-
VL = N->getOperand(2);
3766-
// A vmv.v.v won't have a Mask or Glue, instead we'll construct an all-ones
3767-
// mask later below.
3768-
} else {
3769-
assert(IsVMerge(N));
3770-
Passthru = N->getOperand(0);
3771-
False = N->getOperand(1);
3772-
True = N->getOperand(2);
3773-
Mask = N->getOperand(3);
3774-
VL = N->getOperand(4);
3775-
// We always have a glue node for the mask at v0.
3776-
Glue = N->getOperand(N->getNumOperands() - 1);
3777-
}
3778-
assert(!Mask || cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
3779-
assert(!Glue || Glue.getValueType() == MVT::Glue);
3731+
assert(IsVMerge(N));
3732+
Passthru = N->getOperand(0);
3733+
False = N->getOperand(1);
3734+
True = N->getOperand(2);
3735+
Mask = N->getOperand(3);
3736+
VL = N->getOperand(4);
3737+
// We always have a glue node for the mask at v0.
3738+
Glue = N->getOperand(N->getNumOperands() - 1);
3739+
assert(cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
3740+
assert(Glue.getValueType() == MVT::Glue);
37803741

37813742
// If the EEW of True is different from vmerge's SEW, then we can't fold.
37823743
if (True.getSimpleValueType() != N->getSimpleValueType(0))
@@ -3824,7 +3785,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
38243785

38253786
// If True is masked then the vmerge must have either the same mask or an all
38263787
// 1s mask, since we're going to keep the mask from True.
3827-
if (IsMasked && Mask) {
3788+
if (IsMasked) {
38283789
// FIXME: Support mask agnostic True instruction which would have an
38293790
// undef passthru operand.
38303791
SDValue TrueMask =
@@ -3854,11 +3815,9 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
38543815
SmallVector<const SDNode *, 4> LoopWorklist;
38553816
SmallPtrSet<const SDNode *, 16> Visited;
38563817
LoopWorklist.push_back(False.getNode());
3857-
if (Mask)
3858-
LoopWorklist.push_back(Mask.getNode());
3818+
LoopWorklist.push_back(Mask.getNode());
38593819
LoopWorklist.push_back(VL.getNode());
3860-
if (Glue)
3861-
LoopWorklist.push_back(Glue.getNode());
3820+
LoopWorklist.push_back(Glue.getNode());
38623821
if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
38633822
return false;
38643823
}
@@ -3919,21 +3878,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
39193878
Glue = True->getOperand(True->getNumOperands() - 1);
39203879
assert(Glue.getValueType() == MVT::Glue);
39213880
}
3922-
// If we end up using the vmerge mask the vmerge is actually a vmv.v.v, create
3923-
// an all-ones mask to use.
3924-
else if (IsVMv(N)) {
3925-
unsigned TSFlags = TII->get(N->getMachineOpcode()).TSFlags;
3926-
unsigned VMSetOpc = GetVMSetForLMul(RISCVII::getLMul(TSFlags));
3927-
ElementCount EC = N->getValueType(0).getVectorElementCount();
3928-
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);
3929-
3930-
SDValue AllOnesMask =
3931-
SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
3932-
SDValue MaskCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
3933-
RISCV::V0, AllOnesMask, SDValue());
3934-
Mask = CurDAG->getRegister(RISCV::V0, MaskVT);
3935-
Glue = MaskCopy.getValue(1);
3936-
}
39373881

39383882
unsigned MaskedOpc = Info->MaskedPseudo;
39393883
#ifndef NDEBUG
@@ -4012,7 +3956,7 @@ bool RISCVDAGToDAGISel::doPeepholeMergeVVMFold() {
40123956
if (N->use_empty() || !N->isMachineOpcode())
40133957
continue;
40143958

4015-
if (IsVMerge(N) || IsVMv(N))
3959+
if (IsVMerge(N))
40163960
MadeChange |= performCombineVMergeAndVOps(N);
40173961
}
40183962
return MadeChange;

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
6565
bool convertToWholeRegister(MachineInstr &MI) const;
6666
bool convertToUnmasked(MachineInstr &MI) const;
6767
bool convertVMergeToVMv(MachineInstr &MI) const;
68+
bool foldVMV_V_V(MachineInstr &MI);
6869

6970
bool isAllOnesMask(const MachineInstr *MaskDef) const;
7071
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
@@ -324,6 +325,143 @@ bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
324325
return true;
325326
}
326327

328+
/// Given two VL operands, returns the one known to be the smallest or nullptr
329+
/// if unknown.
330+
static const MachineOperand *getKnownMinVL(const MachineOperand *LHS,
331+
const MachineOperand *RHS) {
332+
if (LHS->isReg() && RHS->isReg() && LHS->getReg().isVirtual() &&
333+
LHS->getReg() == RHS->getReg())
334+
return LHS;
335+
if (LHS->isImm() && LHS->getImm() == RISCV::VLMaxSentinel)
336+
return RHS;
337+
if (RHS->isImm() && RHS->getImm() == RISCV::VLMaxSentinel)
338+
return LHS;
339+
if (!LHS->isImm() || !RHS->isImm())
340+
return nullptr;
341+
return LHS->getImm() <= RHS->getImm() ? LHS : RHS;
342+
}
343+
344+
/// Check if it's safe to move From down to To, checking that no physical
345+
/// registers are clobbered.
346+
static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
347+
assert(From.getParent() == To.getParent() && !From.hasImplicitDef());
348+
SmallVector<Register> PhysUses;
349+
for (const MachineOperand &MO : From.all_uses())
350+
if (MO.getReg().isPhysical())
351+
PhysUses.push_back(MO.getReg());
352+
bool SawStore = false;
353+
for (auto II = From.getIterator(); II != To.getIterator(); II++) {
354+
for (Register PhysReg : PhysUses)
355+
if (II->definesRegister(PhysReg, nullptr))
356+
return false;
357+
if (II->mayStore()) {
358+
SawStore = true;
359+
break;
360+
}
361+
}
362+
return From.isSafeToMove(SawStore);
363+
}
364+
365+
static unsigned getSEWLMULRatio(const MachineInstr &MI) {
366+
RISCVII::VLMUL LMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
367+
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
368+
return RISCVVType::getSEWLMULRatio(1 << Log2SEW, LMUL);
369+
}
370+
371+
/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
372+
/// into it.
373+
///
374+
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
375+
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
376+
///
377+
/// ->
378+
///
379+
/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, min(vl1, vl2), sew, policy
380+
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
381+
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
382+
return false;
383+
384+
MachineOperand &Passthru = MI.getOperand(1);
385+
386+
if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
387+
return false;
388+
389+
MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
390+
if (!Src || Src->hasUnmodeledSideEffects() ||
391+
Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
392+
!RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) ||
393+
!RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
394+
!RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags))
395+
return false;
396+
397+
// Src needs to have the same VLMAX as MI
398+
if (getSEWLMULRatio(MI) != getSEWLMULRatio(*Src))
399+
return false;
400+
401+
// Src needs to have the same passthru as VMV_V_V
402+
MachineOperand &SrcPassthru = Src->getOperand(1);
403+
if (SrcPassthru.getReg() != RISCV::NoRegister &&
404+
SrcPassthru.getReg() != Passthru.getReg())
405+
return false;
406+
407+
// Because Src and MI have the same passthru, we can use either AVL as long as
408+
// it's the smaller of the two.
409+
//
410+
// (src pt, ..., vl=5) x x x x x|. . .
411+
// (vmv.v.v pt, src, vl=3) x x x|. . . . .
412+
// ->
413+
// (src pt, ..., vl=3) x x x|. . . . .
414+
//
415+
// (src pt, ..., vl=3) x x x|. . . . .
416+
// (vmv.v.v pt, src, vl=6) x x x . . .|. .
417+
// ->
418+
// (src pt, ..., vl=3) x x x|. . . . .
419+
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
420+
const MachineOperand *MinVL = getKnownMinVL(&MI.getOperand(3), &SrcVL);
421+
if (!MinVL)
422+
return false;
423+
424+
bool VLChanged = !MinVL->isIdenticalTo(SrcVL);
425+
bool ActiveElementsAffectResult = RISCVII::activeElementsAffectResult(
426+
TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);
427+
428+
if (VLChanged && (ActiveElementsAffectResult || Src->mayRaiseFPException()))
429+
return false;
430+
431+
// If Src ends up using MI's passthru/VL, move it so it can access it.
432+
// TODO: We don't need to do this if they already dominate Src.
433+
if (!SrcVL.isIdenticalTo(*MinVL) || !SrcPassthru.isIdenticalTo(Passthru)) {
434+
if (!isSafeToMove(*Src, MI))
435+
return false;
436+
Src->moveBefore(&MI);
437+
}
438+
439+
if (SrcPassthru.getReg() != Passthru.getReg()) {
440+
SrcPassthru.setReg(Passthru.getReg());
441+
// If Src is masked then its passthru needs to be in VRNoV0.
442+
if (Passthru.getReg() != RISCV::NoRegister)
443+
MRI->constrainRegClass(Passthru.getReg(),
444+
TII->getRegClass(Src->getDesc(), 1, TRI,
445+
*Src->getParent()->getParent()));
446+
}
447+
448+
if (MinVL->isImm())
449+
SrcVL.ChangeToImmediate(MinVL->getImm());
450+
else if (MinVL->isReg())
451+
SrcVL.ChangeToRegister(MinVL->getReg(), false);
452+
453+
// Use a conservative tu,mu policy, RISCVInsertVSETVLI will relax it if
454+
// passthru is undef.
455+
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()))
456+
.setImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED);
457+
458+
MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
459+
MI.eraseFromParent();
460+
V0Defs.erase(&MI);
461+
462+
return true;
463+
}
464+
327465
bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
328466
if (skipFunction(MF.getFunction()))
329467
return false;
@@ -358,11 +496,12 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
358496
}
359497

360498
for (MachineBasicBlock &MBB : MF) {
361-
for (MachineInstr &MI : MBB) {
499+
for (MachineInstr &MI : make_early_inc_range(MBB)) {
362500
Changed |= convertToVLMAX(MI);
363501
Changed |= convertToUnmasked(MI);
364502
Changed |= convertToWholeRegister(MI);
365503
Changed |= convertVMergeToVMv(MI);
504+
Changed |= foldVMV_V_V(MI);
366505
}
367506
}
368507

0 commit comments

Comments
 (0)