[RISCV] Introduce local peephole to reduce VLs based on demanded VL #104689

Merged
merged 5 commits into from
Aug 22, 2024
149 changes: 99 additions & 50 deletions llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -61,6 +61,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
}

private:
bool tryToReduceVL(MachineInstr &MI) const;
bool convertToVLMAX(MachineInstr &MI) const;
bool convertToWholeRegister(MachineInstr &MI) const;
bool convertToUnmasked(MachineInstr &MI) const;
@@ -81,6 +82,96 @@ char RISCVVectorPeephole::ID = 0;
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
false)

/// Given two VL operands, do we know that LHS <= RHS?
static bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
LHS.getReg() == RHS.getReg())
return true;
if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
return true;
if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
return false;
if (!LHS.isImm() || !RHS.isImm())
return false;
return LHS.getImm() <= RHS.getImm();
}
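
The comparison above is deliberately one-sided: `true` means "provably LHS <= RHS", while `false` means either "greater" or "unknown". A minimal standalone sketch of the same decision table follows. `VLOperandModel`, `imm`, `reg`, `isVLKnownLEModel`, and `VLMaxSentinelModel` are hypothetical stand-ins for illustration, not LLVM API, and the sentinel value of -1 is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>

// Assumption for this sketch: VLMAX is encoded as the immediate -1,
// standing in for RISCV::VLMaxSentinel.
constexpr std::int64_t VLMaxSentinelModel = -1;

// Hypothetical stand-in for a VL MachineOperand: either an immediate or a
// virtual register id.
struct VLOperandModel {
  bool IsReg;
  std::int64_t Value;
};

VLOperandModel imm(std::int64_t V) { return {false, V}; }
VLOperandModel reg(std::int64_t Id) { return {true, Id}; }

// Returns true only when LHS <= RHS is provable. A false result means
// "greater or unknown", so callers must treat it as "do not transform".
bool isVLKnownLEModel(const VLOperandModel &LHS, const VLOperandModel &RHS) {
  if (LHS.IsReg && RHS.IsReg && LHS.Value == RHS.Value)
    return true; // Same register, hence the same runtime value.
  if (!RHS.IsReg && RHS.Value == VLMaxSentinelModel)
    return true; // Every VL is <= VLMAX.
  if (!LHS.IsReg && LHS.Value == VLMaxSentinelModel)
    return false; // VLMAX against something not known to be VLMAX.
  if (LHS.IsReg || RHS.IsReg)
    return false; // A register against anything else: unknown.
  return LHS.Value <= RHS.Value; // Two plain immediates.
}
```

In this model, two different registers always compare as unknown, which matches the conservative behavior of the function in the patch.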

static unsigned getSEWLMULRatio(const MachineInstr &MI) {
RISCVII::VLMUL LMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
return RISCVVType::getSEWLMULRatio(1 << Log2SEW, LMUL);
}

// Attempt to reduce the VL of an instruction whose sole use is feeding an
// instruction with a narrower VL. This currently works backwards from the
// user instruction (which might have a smaller VL).
bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
// Note that the goal here is a bit multifaceted.
// 1) For stores, reducing the VL of the value being stored may help to
// reduce VL toggles. This is somewhat of an artifact of the fact we
// promote arithmetic instructions but VL predicate stores.
// 2) For vmv.v.v reducing VL eagerly on the source instruction allows us
// to share code with the foldVMV_V_V transform below.
//
// Note that to the best of our knowledge, reducing VL is generally not
// a significant win on real hardware unless we can also reduce LMUL which
// this code doesn't try to do.
//
// TODO: We can handle a bunch more instructions here, and probably
// recurse backwards through operands too.
unsigned SrcIdx = 0;
switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
default:
return false;
case RISCV::VSE8_V:
case RISCV::VSE16_V:
case RISCV::VSE32_V:
case RISCV::VSE64_V:
break;
case RISCV::VMV_V_V:
SrcIdx = 2;
break;
}

MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
if (VL.isImm() && VL.getImm() == RISCV::VLMaxSentinel)
return false;

Register SrcReg = MI.getOperand(SrcIdx).getReg();
// Note: one *use*, not one *user*.
if (!MRI->hasOneUse(SrcReg))
return false;

MachineInstr *Src = MRI->getVRegDef(SrcReg);
if (!Src || Src->hasUnmodeledSideEffects() ||
Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
!RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
!RISCVII::hasSEWOp(Src->getDesc().TSFlags))
return false;

// Src needs to have the same VLMAX as MI
if (getSEWLMULRatio(MI) != getSEWLMULRatio(*Src))
Collaborator

I'm not sure this is sufficient. You need to know that the EEW and EMUL of Src's destination is identical to the EEW and EMUL of MI's operand. If there was a bitcast between them in IR, it's possible for them to be different. Bitcasts don't codegen to anything.

Contributor

@lukel97 lukel97 Aug 23, 2024

I think you're right, it's possible to have the same VLMAX but different EEWs at fractional LMULs. I think this was relying on the register classes to enforce that the EMULs are the same.

define <vscale x 1 x i8> @unfoldable_mismatched_sew_2(<vscale x 1 x i8> %passthru, <vscale x 1 x i16> %x, <vscale x 1 x i16> %y, i64 %avl) {
  %a = call <vscale x 1 x i16> @llvm.riscv.vadd.nxv1i16.nxv1i16(<vscale x 1 x i16> poison, <vscale x 1 x i16> %x, <vscale x 1 x i16> %y, i64 %avl)
  %a.bitcast = bitcast <vscale x 1 x i16> %a to <vscale x 2 x i8>
  %a.insert = call <vscale x 1 x i8> @llvm.vector.extract(<vscale x 2 x i8> %a.bitcast, i64 0)
  %b = call <vscale x 1 x i8> @llvm.riscv.vmv.v.v.nx1i8(<vscale x 1 x i8> %passthru, <vscale x 1 x i8> %a.insert, i64 %avl)
  ret <vscale x 1 x i8> %b
}

However, it looks like we avoid miscompiling this only by a fluke: there's a trivial copy that gets in the way of the fold:

  %4:vr = PseudoVADD_VV_MF4 $noreg(tied-def 0), %1:vr, %2:vr, %3:gprnox0, 4, 0
  %5:vr = COPY %4:vr
  %6:vr = PseudoVMV_V_V_MF8 %0:vr(tied-def 0), killed %5:vr, %3:gprnox0, 3, 0

Last time I checked we didn't encode the EEW of the operands anywhere.
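
To make the point above concrete, here is a minimal sketch of why the ratio check cannot distinguish the two instructions in the MIR snippet. It assumes the usual relation VLMAX = VLEN / (SEW / LMUL) and models LMUL in eighths so that fractional values stay integral. The function names and the example VLEN of 128 are illustrative assumptions, not LLVM API.

```cpp
#include <cassert>

// LMUL is modeled in eighths so fractional values stay integral:
// mf8 = 1, mf4 = 2, mf2 = 4, m1 = 8, m2 = 16, and so on.
constexpr unsigned sewLMULRatioModel(unsigned SEW, unsigned LMULEighths) {
  return SEW * 8 / LMULEighths;
}

// VLMAX = VLEN / (SEW / LMUL), so equal ratios imply equal VLMAX.
constexpr unsigned vlmaxModel(unsigned VLEN, unsigned SEW,
                              unsigned LMULEighths) {
  return VLEN / sewLMULRatioModel(SEW, LMULEighths);
}

// Models the check under review: it only compares the two ratios.
constexpr bool ratioCheckPasses(unsigned SrcSEW, unsigned SrcLMULEighths,
                                unsigned MISEW, unsigned MILMULEighths) {
  return sewLMULRatioModel(SrcSEW, SrcLMULEighths) ==
         sewLMULRatioModel(MISEW, MILMULEighths);
}
```

With these definitions, e16/mf4 (the `PseudoVADD_VV_MF4`) and e8/mf8 (the `PseudoVMV_V_V_MF8`) both have a ratio of 64, so `ratioCheckPasses(16, 2, 8, 1)` holds even though the element widths differ.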

Contributor

I actually added a destination EEW TSFlag to the pseudos for this, the patch is lying around my git stash somewhere. I thought I needed this information initially but ditched it for the VLMAX approach. I can pull this out again, but if someone has a better idea I'd be open to it.

Collaborator Author

I'm trying to wrap my head around the discussed problem here and am struggling.

The original purpose of this check was to guard against the case where the VL on MI refers to a different number of bytes than the same VL on Src's dest. (#100367 (comment))

A key point is that we are only reducing the VL here, not increasing it. As a result, by assumption SrcVL <= SrcVLMAX and VL <= MIVLMAX. Consequently, I don't think that LMUL is really relevant here: any resulting VL is <= min(SrcVLMAX, MIVLMAX) by assumption.

Given that, I don't see where the VLMAX discussion comes into this. I agree that we need to check Src's Dest EEW instead of Src's SEW, but other than that change, I don't get the rest of the discussion here. Am I missing something?
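
The "different number of bytes" concern can be illustrated with a little arithmetic. The sketch below assumes a simplified model in which elements are packed contiguously from the start of the register, so a VL of N elements at a given EEW covers N * EEW / 8 bytes. `bytesCovered` and `minProducerVL` are hypothetical helpers for illustration only.

```cpp
#include <cassert>

// Bytes written (or read) by the first VL elements at the given EEW.
constexpr unsigned bytesCovered(unsigned VL, unsigned EEWBits) {
  return VL * EEWBits / 8;
}

// Smallest producer VL whose written bytes cover everything the consumer
// reads, rounding up to a whole producer element.
constexpr unsigned minProducerVL(unsigned ConsumerVL, unsigned ConsumerEEW,
                                 unsigned ProducerEEW) {
  unsigned Bits = ConsumerVL * ConsumerEEW;
  return (Bits + ProducerEEW - 1) / ProducerEEW;
}
```

Under this model, a consumer reading 4 elements at e16 demands 8 bytes, so an e8 producer would need a VL of 8. Copying the consumer's VL of 4 into the producer would leave half of the demanded bytes uncomputed, which is why the element widths have to match before the VL is shared.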

Collaborator

> I agree that we need to check Src's Dest EEW instead of Src's SEW,

What do you mean by "instead of" here? We aren't comparing SEW directly right now. Are you talking about for the ratio or a direct comparison?

Collaborator Author

Instead of in the sense that our ratio test was in terms of Src's SEW, not Src's Dest's EEW.

We added the ratio check to prevent the case of unequal SEWs raised in the original review. In retrospect, using the ratio was never the right approach since we don't actually care about the LMUL here and the ratio confused concepts.

return false;

bool ActiveElementsAffectResult = RISCVII::activeElementsAffectResult(
TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);
if (ActiveElementsAffectResult || Src->mayRaiseFPException())
return false;

MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
if (VL.isIdenticalTo(SrcVL) || !isVLKnownLE(VL, SrcVL))
return false;

if (VL.isImm())
SrcVL.ChangeToImmediate(VL.getImm());
else if (VL.isReg())
SrcVL.ChangeToRegister(VL.getReg(), false);

// TODO: For instructions with a passthru, we could clear the passthru
// and tail policy since we've just proven the tail is not demanded.
return true;
}
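
The reasoning behind this transform can be checked on a toy model: if an elementwise producer's only use is a store with a smaller VL, the elements past the store's VL are never observed, so running the producer at the smaller VL yields the same memory contents. The sketch below is such a model under simplifying assumptions (8-element vectors of `int`, tail-undisturbed producer). `vaddModel`, `vseModel`, and `storeAfterAdd` are hypothetical names, not LLVM code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<int>;

// Elementwise add of the first VL elements; the tail keeps Passthru.
Vec vaddModel(const Vec &Passthru, const Vec &A, const Vec &B,
              std::size_t VL) {
  Vec Result = Passthru;
  for (std::size_t I = 0; I < VL; ++I)
    Result[I] = A[I] + B[I];
  return Result;
}

// Unit-stride store: only the first VL elements reach memory.
void vseModel(Vec &Mem, const Vec &Src, std::size_t VL) {
  for (std::size_t I = 0; I < VL; ++I)
    Mem[I] = Src[I];
}

// Hypothetical driver: run the add at ProducerVL, store at StoreVL, and
// return the resulting memory image.
Vec storeAfterAdd(std::size_t ProducerVL, std::size_t StoreVL) {
  const Vec A{1, 2, 3, 4, 5, 6, 7, 8};
  const Vec B{10, 20, 30, 40, 50, 60, 70, 80};
  const Vec Passthru(8, -1);
  Vec Mem(8, 0);
  vseModel(Mem, vaddModel(Passthru, A, B, ProducerVL), StoreVL);
  return Mem;
}
```

In this model `storeAfterAdd(8, 6)` and `storeAfterAdd(6, 6)` produce identical memory, while `storeAfterAdd(4, 6)` does not, which reflects why the peephole only ever lowers the producer's VL to the user's VL and never goes below it.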

/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
std::optional<unsigned>
RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
@@ -325,22 +416,6 @@ bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
return true;
}

-/// Given two VL operands, returns the one known to be the smallest or nullptr
-/// if unknown.
-static const MachineOperand *getKnownMinVL(const MachineOperand *LHS,
-const MachineOperand *RHS) {
-if (LHS->isReg() && RHS->isReg() && LHS->getReg().isVirtual() &&
-LHS->getReg() == RHS->getReg())
-return LHS;
-if (LHS->isImm() && LHS->getImm() == RISCV::VLMaxSentinel)
-return RHS;
-if (RHS->isImm() && RHS->getImm() == RISCV::VLMaxSentinel)
-return LHS;
-if (!LHS->isImm() || !RHS->isImm())
-return nullptr;
-return LHS->getImm() <= RHS->getImm() ? LHS : RHS;
-}
-
/// Check if it's safe to move From down to To, checking that no physical
/// registers are clobbered.
static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
@@ -362,21 +437,16 @@ static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
return From.isSafeToMove(SawStore);
}

-static unsigned getSEWLMULRatio(const MachineInstr &MI) {
-RISCVII::VLMUL LMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
-unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
-return RISCVVType::getSEWLMULRatio(1 << Log2SEW, LMUL);
-}
-
/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
/// into it.
///
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
+/// (where %vl1 <= %vl2, see related tryToReduceVL)
///
/// ->
///
-/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, min(vl1, vl2), sew, policy
+/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, vl1, sew, policy
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
return false;
@@ -404,33 +474,16 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
SrcPassthru.getReg() != Passthru.getReg())
return false;

-// Because Src and MI have the same passthru, we can use either AVL as long as
-// it's the smaller of the two.
-//
-// (src pt, ..., vl=5) x x x x x|. . .
-// (vmv.v.v pt, src, vl=3) x x x|. . . . .
-// ->
-// (src pt, ..., vl=3) x x x|. . . . .
-//
-// (src pt, ..., vl=3) x x x|. . . . .
-// (vmv.v.v pt, src, vl=6) x x x . . .|. .
-// ->
-// (src pt, ..., vl=3) x x x|. . . . .
+// Src VL will have already been reduced if legal (see tryToReduceVL),
+// so we don't need to handle a smaller source VL here. However, the
+// user's VL may be larger
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
-const MachineOperand *MinVL = getKnownMinVL(&MI.getOperand(3), &SrcVL);
-if (!MinVL)
-return false;
-
-bool VLChanged = !MinVL->isIdenticalTo(SrcVL);
-bool ActiveElementsAffectResult = RISCVII::activeElementsAffectResult(
-TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);
-
-if (VLChanged && (ActiveElementsAffectResult || Src->mayRaiseFPException()))
Comment on lines -421 to -428
Contributor

Very satisfying to see this generalized. Nit: could you update the comment above foldVMV_V_V to reflect that Src's VL doesn't change anymore and that instead we check for MI.VL >= Src.VL?
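
The condition MI.VL >= Src.VL can be sanity-checked on a toy model of the fold. When the producer and the `vmv.v.v` share a passthru and the producer's VL is no larger than the move's VL, the move copies back exactly what the producer already wrote, so the producer alone (with a tail-undisturbed policy) gives the same result. The sketch below assumes 8-element `int` vectors; `producerModel`, `vmvModel`, `unfolded`, and `folded` are hypothetical names for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<int>;

// Elementwise add of the first VL elements; the tail keeps Passthru.
Vec producerModel(const Vec &Passthru, const Vec &A, const Vec &B,
                  std::size_t VL) {
  Vec Result = Passthru;
  for (std::size_t I = 0; I < VL; ++I)
    Result[I] = A[I] + B[I];
  return Result;
}

// vmv.v.v: copy the first VL elements of Src; the tail keeps Passthru.
Vec vmvModel(const Vec &Passthru, const Vec &Src, std::size_t VL) {
  Vec Result = Passthru;
  for (std::size_t I = 0; I < VL; ++I)
    Result[I] = Src[I];
  return Result;
}

const Vec InA{1, 2, 3, 4, 5, 6, 7, 8};
const Vec InB{10, 20, 30, 40, 50, 60, 70, 80};
const Vec SharedPassthru(8, -1);

// Original sequence: producer at SrcVL feeding a vmv.v.v at MvVL.
Vec unfolded(std::size_t SrcVL, std::size_t MvVL) {
  return vmvModel(SharedPassthru,
                  producerModel(SharedPassthru, InA, InB, SrcVL), MvVL);
}

// Folded sequence: the producer alone, keeping its own VL.
Vec folded(std::size_t SrcVL) {
  return producerModel(SharedPassthru, InA, InB, SrcVL);
}
```

In this model `unfolded(3, 6)` equals `folded(3)`, whereas `unfolded(5, 3)` differs from `folded(5)` and instead equals `folded(3)`. The latter case is the one the patch leaves to tryToReduceVL, which lowers the producer's VL before the fold runs.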

+if (!isVLKnownLE(SrcVL, MI.getOperand(3)))
return false;

// If Src ends up using MI's passthru/VL, move it so it can access it.
// TODO: We don't need to do this if they already dominate Src.
-if (!SrcVL.isIdenticalTo(*MinVL) || !SrcPassthru.isIdenticalTo(Passthru)) {
+if (!SrcPassthru.isIdenticalTo(Passthru)) {
if (!isSafeToMove(*Src, MI))
return false;
Src->moveBefore(&MI);
@@ -445,11 +498,6 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
*Src->getParent()->getParent()));
}

-if (MinVL->isImm())
-SrcVL.ChangeToImmediate(MinVL->getImm());
-else if (MinVL->isReg())
-SrcVL.ChangeToRegister(MinVL->getReg(), false);
-
// Use a conservative tu,mu policy, RISCVInsertVSETVLI will relax it if
// passthru is undef.
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()))
@@ -498,6 +546,7 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : make_early_inc_range(MBB)) {
Changed |= convertToVLMAX(MI);
Changed |= tryToReduceVL(MI);
Changed |= convertToUnmasked(MI);
Changed |= convertToWholeRegister(MI);
Changed |= convertVMergeToVMv(MI);
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/RISCV/rvv/fixed-vectors-abs.ll
@@ -41,8 +41,8 @@ define void @abs_v6i16(ptr %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; CHECK-NEXT: vrsub.vi v9, v8, 0
-; CHECK-NEXT: vmax.vv v8, v8, v9
; CHECK-NEXT: vsetivli zero, 6, e16, m1, ta, ma
+; CHECK-NEXT: vmax.vv v8, v8, v9
; CHECK-NEXT: vse16.v v8, (a0)
; CHECK-NEXT: ret
%a = load <6 x i16>, ptr %x