Skip to content

[RISCV][VLOPT] Allow propagation even when VL isn't VLMAX #112228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 16, 2024

Conversation

michaelmaitland
Copy link
Contributor

The original goal of this pass was to focus on vector operations with VLMAX.
However, users often utilize only part of the result, and such usage may come
from the vectorizer.

We found that relaxing this constraint can capture more optimization
opportunities, such as non-power-of-2 code generation and vector operation
sequences with different VLs.


Co-authored-by: Kito Cheng [email protected]

@llvmbot
Copy link
Member

llvmbot commented Oct 14, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Michael Maitland (michaelmaitland)

Changes

The original goal of this pass was to focus on vector operations with VLMAX.
However, users often utilize only part of the result, and such usage may come
from the vectorizer.

We found that relaxing this constraint can capture more optimization
opportunities, such as non-power-of-2 code generation and vector operation
sequences with different VLs.


Co-authored-by: Kito Cheng <[email protected]>


Full diff: https://github.com/llvm/llvm-project/pull/112228.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp (+95-16)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vl-opt.ll (+39-20)
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index eb1f4df4ff7264..2c87826ab7883c 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -31,6 +31,44 @@ using namespace llvm;
 
 namespace {
 
+struct VLInfo {
+  VLInfo(const MachineOperand &VLOp) {
+    IsImm = VLOp.isImm();
+    if (IsImm)
+      Imm = VLOp.getImm();
+    else
+      Reg = VLOp.getReg();
+  }
+
+  Register Reg;
+  int64_t Imm;
+  bool IsImm;
+
+  bool isCompatible(const MachineOperand &VLOp) const {
+    if (IsImm != VLOp.isImm())
+      return false;
+    if (IsImm)
+      return Imm == VLOp.getImm();
+    return Reg == VLOp.getReg();
+  }
+
+  bool isValid() const { return IsImm || Reg.isVirtual(); }
+
+  bool hasBenefit(const MachineOperand &VLOp) const {
+    if (IsImm && Imm == RISCV::VLMaxSentinel)
+      return false;
+
+    if (!IsImm || !VLOp.isImm())
+      return true;
+
+    if (VLOp.getImm() == RISCV::VLMaxSentinel)
+      return true;
+
+    // No benefit if the current VL is already smaller than the new one.
+    return Imm < VLOp.getImm();
+  }
+};
+
 class RISCVVLOptimizer : public MachineFunctionPass {
   const MachineRegisterInfo *MRI;
   const MachineDominatorTree *MDT;
@@ -51,7 +89,7 @@ class RISCVVLOptimizer : public MachineFunctionPass {
   StringRef getPassName() const override { return PASS_NAME; }
 
 private:
-  bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+  bool checkUsers(std::optional<VLInfo> &CommonVL, MachineInstr &MI);
   bool tryReduceVL(MachineInstr &MI);
   bool isCandidate(const MachineInstr &MI) const;
 };
@@ -643,8 +681,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
 
   unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
   const MachineOperand &VLOp = MI.getOperand(VLOpNum);
-  if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
+  if (((VLOp.isImm() && VLOp.getImm() != RISCV::VLMaxSentinel) ||
+       VLOp.isReg())) {
+    bool UseTAPolicy = false;
+    bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
+    if (RISCVII::hasVecPolicyOp(Desc.TSFlags)) {
+      unsigned PolicyOpNum = RISCVII::getVecPolicyOpNum(Desc);
+      const MachineOperand &PolicyOp = MI.getOperand(PolicyOpNum);
+      uint64_t Policy = PolicyOp.getImm();
+      UseTAPolicy = (Policy & RISCVII::TAIL_AGNOSTIC) == RISCVII::TAIL_AGNOSTIC;
+      if (HasPassthru) {
+        unsigned PassthruOpIdx = MI.getNumExplicitDefs();
+        UseTAPolicy = UseTAPolicy || (MI.getOperand(PassthruOpIdx).getReg() ==
+                                      RISCV::NoRegister);
+      }
+    }
+    if (!UseTAPolicy) {
+      LLVM_DEBUG(
+          dbgs() << "  Not a candidate because it uses tail-undisturbed policy"
+                    " with non-VLMAX VL\n");
+      return false;
+    }
+  }
+
+  // If the VL is 1, then there is no need to reduce it.
+  if (VLOp.isImm() && VLOp.getImm() == 1) {
+    LLVM_DEBUG(dbgs() << "  Not a candidate because VL is already 1\n");
     return false;
+  }
 
   // Some instructions that produce vectors have semantics that make it more
   // difficult to determine whether the VL can be reduced. For example, some
@@ -667,7 +731,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
   return true;
 }
 
-bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
+bool RISCVVLOptimizer::checkUsers(std::optional<VLInfo> &CommonVL,
                                   MachineInstr &MI) {
   // FIXME: Avoid visiting each user for each time we visit something on the
   // worklist, combined with an extra visit from the outer loop. Restructure
@@ -721,8 +785,9 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL,
     }
 
     if (!CommonVL) {
-      CommonVL = VLOp.getReg();
-    } else if (*CommonVL != VLOp.getReg()) {
+      CommonVL = VLInfo(VLOp);
+      LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
+    } else if (!CommonVL->isCompatible(VLOp)) {
       LLVM_DEBUG(dbgs() << "    Abort because users have different VL\n");
       CanReduceVL = false;
       break;
@@ -759,7 +824,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     MachineInstr &MI = *Worklist.pop_back_val();
     LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
 
-    std::optional<Register> CommonVL;
+    std::optional<VLInfo> CommonVL;
     bool CanReduceVL = true;
     if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
       CanReduceVL = checkUsers(CommonVL, MI);
@@ -767,21 +832,35 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
     if (!CanReduceVL || !CommonVL)
       continue;
 
-    if (!CommonVL->isVirtual()) {
-      LLVM_DEBUG(
-          dbgs() << "    Abort due to new VL is not virtual register.\n");
+    if (!CommonVL->isValid()) {
+      LLVM_DEBUG(dbgs() << "    Abort due to common VL is not valid.\n");
       continue;
     }
 
-    const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL);
-    if (!MDT->dominates(VLMI, &MI))
-      continue;
-
-    // All our checks passed. We can reduce VL.
-    LLVM_DEBUG(dbgs() << "    Reducing VL for: " << MI << "\n");
     unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
     MachineOperand &VLOp = MI.getOperand(VLOpNum);
-    VLOp.ChangeToRegister(*CommonVL, false);
+
+    if (!CommonVL->hasBenefit(VLOp)) {
+      LLVM_DEBUG(dbgs() << "    Abort due to no benefit.\n");
+      continue;
+    }
+
+    if (CommonVL->IsImm) {
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << CommonVL->Imm << " for " << MI << "\n");
+      VLOp.ChangeToImmediate(CommonVL->Imm);
+    } else {
+      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->Reg);
+      if (!MDT->dominates(VLMI, &MI))
+        continue;
+      LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
+                        << printReg(CommonVL->Reg, MRI->getTargetRegisterInfo())
+                        << " for " << MI << "\n");
+
+      // All our checks passed. We can reduce VL.
+      VLOp.ChangeToRegister(CommonVL->Reg, false);
+    }
+
     MadeChange = true;
 
     // Now add all inputs to this instruction to the worklist.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
index b03ba076059503..e8ac4efc770484 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -1,6 +1,12 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
-; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs | FileCheck %s
-; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs | FileCheck %s
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs | \
+; RUN:   FileCheck %s -check-prefixes=CHECK,NOVLOPT
+; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs | \
+; RUN:   FileCheck %s -check-prefixes=CHECK,NOVLOPT
+; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v -riscv-enable-vl-optimizer \
+; RUN:   -verify-machineinstrs | FileCheck %s -check-prefixes=CHECK,VLOPT
+; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v -riscv-enable-vl-optimizer \
+; RUN:   -verify-machineinstrs | FileCheck %s -check-prefixes=CHECK,VLOPT
 
 declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen)
 
@@ -17,7 +23,7 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru
   ret <vscale x 4 x i32> %w
 }
 
-; No benificial to propagate VL since VL is larger in the use side.
+; Not beneficial to propagate VL since VL is larger in the use side.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_larger_vl:
 ; CHECK:       # %bb.0:
@@ -32,20 +38,26 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32>
 }
 
 define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_imm_reg_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v12
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v8, v10
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_imm_reg_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetivli zero, 4, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v12
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v8, v10
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_imm_reg_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v8, v10, v12
+; VLOPT-NEXT:    vadd.vv v8, v8, v10
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 4)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen %vl1)
   ret <vscale x 4 x i32> %w
 }
 
-
-; No benificial to propagate VL since VL is already one.
+; Not beneficial to propagate VL since VL is already one.
 define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
 ; CHECK-LABEL: different_imm_vl_with_ta_1:
 ; CHECK:       # %bb.0:
@@ -63,13 +75,20 @@ define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passth
 ; it's still safe even %vl2 is larger than %vl1, becuase rest of the vector are
 ; undefined value.
 define <vscale x 4 x i32> @different_vl_with_ta(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) {
-; CHECK-LABEL: different_vl_with_ta:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v10, v8, v10
-; CHECK-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    ret
+; NOVLOPT-LABEL: different_vl_with_ta:
+; NOVLOPT:       # %bb.0:
+; NOVLOPT-NEXT:    vsetvli zero, a0, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v10, v8, v10
+; NOVLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; NOVLOPT-NEXT:    vadd.vv v8, v10, v8
+; NOVLOPT-NEXT:    ret
+;
+; VLOPT-LABEL: different_vl_with_ta:
+; VLOPT:       # %bb.0:
+; VLOPT-NEXT:    vsetvli zero, a1, e32, m2, ta, ma
+; VLOPT-NEXT:    vadd.vv v10, v8, v10
+; VLOPT-NEXT:    vadd.vv v8, v10, v8
+; VLOPT-NEXT:    ret
   %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1)
   %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen %vl2)
   ret <vscale x 4 x i32> %w

Copy link

github-actions bot commented Oct 15, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@michaelmaitland michaelmaitland force-pushed the vlopt-mixed branch 2 times, most recently from 9804943 to 163878e Compare October 15, 2024 15:39
Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines 666 to 688
// If the VL is 1, then there is no need to reduce it.
if (VLOp.isImm() && VLOp.getImm() == 1) {
LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n");
return false;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this VL=1 early exit an optimisation or do we need it for correctness?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An optimization, not needed for correctness.

} else if (*CommonVL != VLOp.getReg()) {
CommonVL = &VLOp;
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
} else if (!CommonVL->isIdenticalTo(VLOp)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR, but this requires all users to have the same VL. One possibility for another PR is to relax this and get the largest VL amongst all users

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.

@michaelmaitland michaelmaitland force-pushed the vlopt-mixed branch 2 times, most recently from d3c3227 to 018012a Compare October 16, 2024 16:11
MachineOperand &VLOp = MI.getOperand(VLOpNum);

if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
LLVM_DEBUG(dbgs() << " Abort due to no CommonVL not <= VLOp.\n");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double negative "no" and "not"

;
; VLOPT-LABEL: vdot_lane_s32:
; VLOPT: # %bb.0: # %entry
; VLOPT-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
Copy link
Collaborator

@topperc topperc Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is correct.

The last two vnsrls are treating the vwadd.vv result as 2 64 bit elements but it was originally 4 32 bit elements. So there was a bitcast in there. We can't just change the VL to 2 for the vwadd.vv. That drops 2 of the 32-bit elements.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The vdot_lane_s32 is incorrectly optimized

michaelmaitland and others added 9 commits October 16, 2024 11:29
The original goal of this pass was to focus on vector operations with VLMAX.
However, users often utilize only part of the result, and such usage may come
from the vectorizer.

We found that relaxing this constraint can capture more optimization
opportunities, such as non-power-of-2 code generation and vector operation
sequences with different VLs.t show

---------

Co-authored-by: Kito Cheng <[email protected]>
@michaelmaitland
Copy link
Contributor Author

The vdot_lane_s32 is incorrectly optimized

Sorry, I forgot to rebase. I have rebased and regenerated test checks.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@michaelmaitland michaelmaitland merged commit ae68d53 into llvm:main Oct 16, 2024
5 of 7 checks passed
@michaelmaitland michaelmaitland deleted the vlopt-mixed branch October 16, 2024 18:58
VLOp.ChangeToImmediate(CommonVL->getImm());
} else {
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
if (!MDT->dominates(VLMI, &MI))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a later followup, note that this check can be extended to move the defining instruction in some cases. See ensureDominates in RISCVVectorPeephole.cpp. Just noting this so it doesn't get lost.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants