Skip to content

[RISCV][VLOPT] Allow propagation even when VL isn't VLMAX #112228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 16, 2024
14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4102,3 +4102,17 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
assert(Scaled >= 3 && Scaled <= 6);
return Scaled;
}

/// Conservatively answers whether the VL operand \p LHS is known to be less
/// than or equal to the VL operand \p RHS. A result of false only means the
/// relation could not be established from the two operands alone.
bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
  // One and the same virtual register always compares equal to itself.
  if (LHS.isReg() && RHS.isReg()) {
    if (LHS.getReg().isVirtual() && LHS.getReg() == RHS.getReg())
      return true;
  }

  const bool LHSIsImm = LHS.isImm();
  const bool RHSIsImm = RHS.isImm();

  // A right-hand side of VLMAX bounds every possible left-hand side.
  if (RHSIsImm && RHS.getImm() == RISCV::VLMaxSentinel)
    return true;

  // Beyond this point only a pair of immediates can be ordered.
  if (!(LHSIsImm && RHSIsImm))
    return false;

  // The VLMAX sentinel must be rejected explicitly on the left-hand side;
  // comparing its raw immediate value would not reflect that it means VLMAX.
  const int64_t LHSVal = LHS.getImm();
  if (LHSVal == RISCV::VLMaxSentinel)
    return false;

  return LHSVal <= RHS.getImm();
}
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ unsigned getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW);
// Special immediate for AVL operand of V pseudo instructions to indicate VLMax.
static constexpr int64_t VLMaxSentinel = -1LL;

/// Given two VL operands, do we know that LHS <= RHS?
///
/// The answer is conservative: true is returned only when the relation can be
/// established from the operands themselves (both are the same virtual
/// register, RHS is the VLMaxSentinel immediate, or both are non-sentinel
/// immediates with LHS <= RHS). A false result means "not known", not
/// "LHS > RHS".
bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS);

// Mask assignments for floating-point
static constexpr unsigned FPMASK_Negative_Infinity = 0x001;
static constexpr unsigned FPMASK_Negative_Normal = 0x002;
Expand Down
84 changes: 61 additions & 23 deletions llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
StringRef getPassName() const override { return PASS_NAME; }

private:
bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI);
bool tryReduceVL(MachineInstr &MI);
bool isCandidate(const MachineInstr &MI) const;
};
Expand Down Expand Up @@ -658,10 +658,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
if (MI.getNumDefs() != 1)
return false;

// If we're not using VLMAX, then we need to be careful whether we are using
// TA/TU when there is a non-undef Passthru. But when we are using VLMAX, it
// does not matter whether we are using TA/TU with a non-undef Passthru, since
// there are no tail elements to be preserved.
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = MI.getOperand(VLOpNum);
if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
if (VLOp.isReg() || VLOp.getImm() != RISCV::VLMaxSentinel) {
// If MI has a non-undef passthru, we will not try to optimize it since
// that requires us to preserve tail elements according to TA/TU.
// Otherwise, the MI has an undef Passthru, so it doesn't matter whether we
// are using TA/TU.
bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
unsigned PassthruOpIdx = MI.getNumExplicitDefs();
if (HasPassthru &&
MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) {
LLVM_DEBUG(
dbgs() << " Not a candidate because it uses non-undef passthru"
" with non-VLMAX VL\n");
return false;
}
}

// If the VL is 1, then there is no need to reduce it. This is an
// optimization, not needed to preserve correctness.
if (VLOp.isImm() && VLOp.getImm() == 1) {
LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n");
return false;
}

// Some instructions that produce vectors have semantics that make it more
// difficult to determine whether the VL can be reduced. For example, some
Expand All @@ -684,7 +708,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
return true;
}

bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
MachineInstr &MI) {
// FIXME: Avoid visiting each user for each time we visit something on the
// worklist, combined with an extra visit from the outer loop. Restructure
Expand Down Expand Up @@ -730,16 +754,17 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,

unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
// Looking for a register VL that isn't X0.
if (!VLOp.isReg() || VLOp.getReg() == RISCV::X0) {
LLVM_DEBUG(dbgs() << " Abort due to user uses X0 as VL.\n");
CanReduceVL = false;
break;
}

// Looking for an immediate or a register VL that isn't X0.
assert(!VLOp.isReg() ||
VLOp.getReg() != RISCV::X0 && "Did not expect X0 VL");

if (!CommonVL) {
CommonVL = VLOp.getReg();
} else if (*CommonVL != VLOp.getReg()) {
CommonVL = &VLOp;
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
} else if (!CommonVL->isIdenticalTo(VLOp)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR, but this requires all users to have the same VL. One possibility for another PR is to relax this and get the largest VL amongst all users

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.

// FIXME: This check requires all users to have the same VL. We can relax
// this and get the largest VL amongst all users.
LLVM_DEBUG(dbgs() << " Abort because users have different VL\n");
CanReduceVL = false;
break;
Expand Down Expand Up @@ -776,29 +801,42 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
MachineInstr &MI = *Worklist.pop_back_val();
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");

std::optional<Register> CommonVL;
const MachineOperand *CommonVL = nullptr;
bool CanReduceVL = true;
if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
CanReduceVL = checkUsers(CommonVL, MI);

if (!CanReduceVL || !CommonVL)
continue;

if (!CommonVL->isVirtual()) {
LLVM_DEBUG(
dbgs() << " Abort due to new VL is not virtual register.\n");
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
"Expected VL to be an Imm or virtual Reg");

unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
MachineOperand &VLOp = MI.getOperand(VLOpNum);

if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
continue;
}

const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
if (!MDT->dominates(VLMI, &MI))
continue;
if (CommonVL->isImm()) {
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
<< CommonVL->getImm() << " for " << MI << "\n");
VLOp.ChangeToImmediate(CommonVL->getImm());
} else {
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
if (!MDT->dominates(VLMI, &MI))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a later followup, note that this check can be extended to move the defining instruction in some cases. See ensureDominates in RISCVVectorPeephole.cpp. Just noting this so it doesn't get lost.

continue;
LLVM_DEBUG(
dbgs() << " Reduce VL from " << VLOp << " to "
<< printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
<< " for " << MI << "\n");

// All our checks passed. We can reduce VL.
VLOp.ChangeToRegister(CommonVL->getReg(), false);
}

// All our checks passed. We can reduce VL.
LLVM_DEBUG(dbgs() << " Reducing VL for: " << MI << "\n");
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
MachineOperand &VLOp = MI.getOperand(VLOpNum);
VLOp.ChangeToRegister(*CommonVL, false);
MadeChange = true;

// Now add all inputs to this instruction to the worklist.
Expand Down
22 changes: 4 additions & 18 deletions llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,6 @@ char RISCVVectorPeephole::ID = 0;
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
false)

/// Given two VL operands, do we know that LHS <= RHS?
///
/// Conservative: returns true only when the relation can be established from
/// the operands themselves; false means "not known" rather than "LHS > RHS".
static bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
// The same virtual register on both sides is trivially equal.
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
LHS.getReg() == RHS.getReg())
return true;
// An RHS of VLMAX is an upper bound for any LHS.
if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
return true;
// An LHS of VLMAX is not known to be <= any other RHS. This check must come
// before the numeric comparison below, which would otherwise compare the
// sentinel's raw immediate value.
if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
return false;
// Anything that is not a pair of immediates cannot be ordered here.
if (!LHS.isImm() || !RHS.isImm())
return false;
// Two ordinary immediates: compare them directly.
return LHS.getImm() <= RHS.getImm();
}

/// Given \p User that has an input operand with EEW=SEW, which uses the dest
/// operand of \p Src with an unknown EEW, return true if their EEWs match.
bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
Expand Down Expand Up @@ -191,7 +177,7 @@ bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
return false;

MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
if (VL.isIdenticalTo(SrcVL) || !isVLKnownLE(VL, SrcVL))
if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL))
return false;

if (!ensureDominates(VL, *Src))
Expand Down Expand Up @@ -580,7 +566,7 @@ bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
MachineOperand &SrcPolicy =
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()));

if (isVLKnownLE(MIVL, SrcVL))
if (RISCV::isVLKnownLE(MIVL, SrcVL))
SrcPolicy.setImm(SrcPolicy.getImm() | RISCVII::TAIL_AGNOSTIC);
}

Expand Down Expand Up @@ -631,7 +617,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
// so we don't need to handle a smaller source VL here. However, the
// user's VL may be larger
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
if (!isVLKnownLE(SrcVL, MI.getOperand(3)))
if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3)))
return false;

// If the new passthru doesn't dominate Src, try to move Src so it does.
Expand All @@ -650,7 +636,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
// If MI was tail agnostic and the VL didn't increase, preserve it.
int64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
if ((MI.getOperand(5).getImm() & RISCVII::TAIL_AGNOSTIC) &&
isVLKnownLE(MI.getOperand(3), SrcVL))
RISCV::isVLKnownLE(MI.getOperand(3), SrcVL))
Policy |= RISCVII::TAIL_AGNOSTIC;
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy);

Expand Down
50 changes: 36 additions & 14 deletions llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,46 @@
declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen)

define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: different_imm_vl_with_ta:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v12
; CHECK-NEXT: vsetivli zero, 4, e32, m2, ta, ma
; CHECK-NEXT: vadd.vv v8, v8, v10
; CHECK-NEXT: ret
; NOVLOPT-LABEL: different_imm_vl_with_ta:
; NOVLOPT: # %bb.0:
; NOVLOPT-NEXT: vsetivli zero, 5, e32, m2, ta, ma
; NOVLOPT-NEXT: vadd.vv v8, v10, v12
; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
; NOVLOPT-NEXT: vadd.vv v8, v8, v10
; NOVLOPT-NEXT: ret
;
; VLOPT-LABEL: different_imm_vl_with_ta:
; VLOPT: # %bb.0:
; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
; VLOPT-NEXT: vadd.vv v8, v10, v12
; VLOPT-NEXT: vadd.vv v8, v8, v10
; VLOPT-NEXT: ret
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 5)
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
ret <vscale x 4 x i32> %w
}

; No benificial to propagate VL since VL is larger in the use side.
define <vscale x 4 x i32> @vlmax_and_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; NOVLOPT-LABEL: vlmax_and_imm_vl_with_ta:
; NOVLOPT: # %bb.0:
; NOVLOPT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
; NOVLOPT-NEXT: vadd.vv v8, v10, v12
; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
; NOVLOPT-NEXT: vadd.vv v8, v8, v10
; NOVLOPT-NEXT: ret
;
; VLOPT-LABEL: vlmax_and_imm_vl_with_ta:
; VLOPT: # %bb.0:
; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
; VLOPT-NEXT: vadd.vv v8, v10, v12
; VLOPT-NEXT: vadd.vv v8, v8, v10
; VLOPT-NEXT: ret
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen -1)
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4)
ret <vscale x 4 x i32> %w
}

; Not beneficial to propagate VL since VL is larger in the use side.
define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:
; CHECK: # %bb.0:
Expand All @@ -50,8 +77,7 @@ define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %pass
ret <vscale x 4 x i32> %w
}


; No benificial to propagate VL since VL is already one.
; Not beneficial to propagate VL since VL is already one.
define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
; CHECK-LABEL: different_imm_vl_with_ta_1:
; CHECK: # %bb.0:
Expand Down Expand Up @@ -110,7 +136,3 @@ define <vscale x 4 x i32> @different_imm_vl_with_tu(<vscale x 4 x i32> %passthru
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen 4)
ret <vscale x 4 x i32> %w
}

;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; NOVLOPT: {{.*}}
; VLOPT: {{.*}}