Skip to content

[RISCV][VLOpt] Consolidate EMUL=SEW/EEW*LMUL logic [NFC] #122021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 8, 2025

Conversation

preames
Copy link
Collaborator

@preames preames commented Jan 8, 2025

All but one of the cases in tree today have EMUL=SEW/EEW*LMUL. Repeating this each time is verbose and introduces oppurtunity for error. (For instance, the comment associated with vwmul.vv was out of sync with the code for same.)

Introduce getOperandLog2EEW and move most complexity to it. Then introduce getOperandInfo as a wrapper around previous, and special case the one case which requires it.

All but one of the cases in tree today have EMUL=SEW/EEW*LMUL.  Repeating
this each time is verbose and introduces oppurtunity for error.  (For instance,
the comment associated with vwmul.vv was out of sync with the code for same.)

Introduce getOperandLog2EEW and move most complexity to it.  Then introduce
getOperandInfo as a wrapper around previous, and special case the one case
which requires it.
@llvmbot
Copy link
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes

All but one of the cases in tree today have EMUL=SEW/EEW*LMUL. Repeating this each time is verbose and introduces oppurtunity for error. (For instance, the comment associated with vwmul.vv was out of sync with the code for same.)

Introduce getOperandLog2EEW and move most complexity to it. Then introduce getOperandInfo as a wrapper around previous, and special case the one case which requires it.


Full diff: https://github.com/llvm/llvm-project/pull/122021.diff

1 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp (+103-84)
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index a1b078fa678d65..4f25272ad0307b 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -168,24 +168,22 @@ getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
 } // end namespace RISCVVType
 } // end namespace llvm
 
-/// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
-/// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI.
-static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor,
-                                                  const MachineInstr &MI,
-                                                  const MachineOperand &MO) {
-  RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags);
+/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
+/// SEW comes from TSFlags of MI.
+static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
+                                              const MachineInstr &MI,
+                                              const MachineOperand &MO) {
   unsigned MILog2SEW =
       MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
 
   if (MO.getOperandNo() == 0)
-    return OperandInfo(MIVLMul, MILog2SEW);
+    return MILog2SEW;
 
   unsigned MISEW = 1 << MILog2SEW;
   unsigned EEW = MISEW / Factor;
   unsigned Log2EEW = Log2_32(EEW);
 
-  return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI),
-                     Log2EEW);
+  return Log2EEW;
 }
 
 /// Check whether MO is a mask operand of MI.
@@ -199,18 +197,15 @@ static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
   return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID;
 }
 
-/// Return the OperandInfo for MO.
-static OperandInfo getOperandInfo(const MachineOperand &MO,
-                                  const MachineRegisterInfo *MRI) {
+static std::optional<unsigned>
+getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
   const MachineInstr &MI = *MO.getParent();
   const RISCVVPseudosTable::PseudoInfo *RVV =
       RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
   assert(RVV && "Could not find MI in PseudoTable");
 
-  // MI has a VLMUL and SEW associated with it. The RVV specification defines
-  // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and
-  // MI.SEW.
-  RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags);
+  // MI has a SEW associated with it. The RVV specification defines
+  // the EEW of each operand and definition in relation to MI.SEW.
   unsigned MILog2SEW =
       MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
 
@@ -221,13 +216,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   // since they must preserve the entire register content.
   if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() &&
       (MO.getReg() != RISCV::NoRegister))
-    return {};
+    return std::nullopt;
 
   bool IsMODef = MO.getOperandNo() == 0;
 
-  // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL
+  // All mask operands have EEW=1
   if (isMaskOperand(MI, MO, MRI))
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+    return 0;
 
   // switch against BaseInstr to reduce number of cases that need to be
   // considered.
@@ -244,63 +239,62 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   // Vector Loads and Stores
   // Vector Unit-Stride Instructions
   // Vector Strided Instructions
-  /// Dest EEW encoded in the instruction and EMUL=(EEW/SEW)*LMUL
+  /// Dest EEW encoded in the instruction
   case RISCV::VLE8_V:
   case RISCV::VSE8_V:
   case RISCV::VLSE8_V:
   case RISCV::VSSE8_V:
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3);
+    return 3;
   case RISCV::VLE16_V:
   case RISCV::VSE16_V:
   case RISCV::VLSE16_V:
   case RISCV::VSSE16_V:
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4);
+    return 4;
   case RISCV::VLE32_V:
   case RISCV::VSE32_V:
   case RISCV::VLSE32_V:
   case RISCV::VSSE32_V:
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5);
+    return 5;
   case RISCV::VLE64_V:
   case RISCV::VSE64_V:
   case RISCV::VLSE64_V:
   case RISCV::VSSE64_V:
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6);
+    return 6;
 
   // Vector Indexed Instructions
   // vs(o|u)xei<eew>.v
-  // Dest/Data (operand 0) EEW=SEW, EMUL=LMUL. Source EEW=<eew> and
-  // EMUL=(EEW/SEW)*LMUL.
+  // Dest/Data (operand 0) EEW=SEW.  Source EEW=<eew>.
   case RISCV::VLUXEI8_V:
   case RISCV::VLOXEI8_V:
   case RISCV::VSUXEI8_V:
   case RISCV::VSOXEI8_V: {
     if (MO.getOperandNo() == 0)
-      return OperandInfo(MIVLMul, MILog2SEW);
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3);
+      return MILog2SEW;
+    return 3;
   }
   case RISCV::VLUXEI16_V:
   case RISCV::VLOXEI16_V:
   case RISCV::VSUXEI16_V:
   case RISCV::VSOXEI16_V: {
     if (MO.getOperandNo() == 0)
-      return OperandInfo(MIVLMul, MILog2SEW);
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4);
+      return MILog2SEW;
+    return 4;
   }
   case RISCV::VLUXEI32_V:
   case RISCV::VLOXEI32_V:
   case RISCV::VSUXEI32_V:
   case RISCV::VSOXEI32_V: {
     if (MO.getOperandNo() == 0)
-      return OperandInfo(MIVLMul, MILog2SEW);
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5);
+      return MILog2SEW;
+    return 5;
   }
   case RISCV::VLUXEI64_V:
   case RISCV::VLOXEI64_V:
   case RISCV::VSUXEI64_V:
   case RISCV::VSOXEI64_V: {
     if (MO.getOperandNo() == 0)
-      return OperandInfo(MIVLMul, MILog2SEW);
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6);
+      return MILog2SEW;
+    return 6;
   }
 
   // Vector Integer Arithmetic Instructions
@@ -314,7 +308,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VRSUB_VX:
   // Vector Bitwise Logical Instructions
   // Vector Single-Width Shift Instructions
-  // EEW=SEW. EMUL=LMUL.
+  // EEW=SEW.
   case RISCV::VAND_VI:
   case RISCV::VAND_VV:
   case RISCV::VAND_VX:
@@ -334,7 +328,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VSRA_VV:
   case RISCV::VSRA_VX:
   // Vector Integer Min/Max Instructions
-  // EEW=SEW. EMUL=LMUL.
+  // EEW=SEW.
   case RISCV::VMINU_VV:
   case RISCV::VMINU_VX:
   case RISCV::VMIN_VV:
@@ -344,7 +338,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VMAX_VV:
   case RISCV::VMAX_VX:
   // Vector Single-Width Integer Multiply Instructions
-  // Source and Dest EEW=SEW and EMUL=LMUL.
+  // Source and Dest EEW=SEW.
   case RISCV::VMUL_VV:
   case RISCV::VMUL_VX:
   case RISCV::VMULH_VV:
@@ -354,7 +348,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VMULHSU_VV:
   case RISCV::VMULHSU_VX:
   // Vector Integer Divide Instructions
-  // EEW=SEW. EMUL=LMUL.
+  // EEW=SEW.
   case RISCV::VDIVU_VV:
   case RISCV::VDIVU_VX:
   case RISCV::VDIV_VV:
@@ -364,7 +358,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VREM_VV:
   case RISCV::VREM_VX:
   // Vector Single-Width Integer Multiply-Add Instructions
-  // EEW=SEW. EMUL=LMUL.
+  // EEW=SEW.
   case RISCV::VMACC_VV:
   case RISCV::VMACC_VX:
   case RISCV::VNMSAC_VV:
@@ -375,8 +369,8 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VNMSUB_VX:
   // Vector Integer Merge Instructions
   // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
-  // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL=
-  // (EEW/SEW)*LMUL. Mask operand is handled before this switch.
+  // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
+  // before this switch.
   case RISCV::VMERGE_VIM:
   case RISCV::VMERGE_VVM:
   case RISCV::VMERGE_VXM:
@@ -389,7 +383,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   // Vector Fixed-Point Arithmetic Instructions
   // Vector Single-Width Saturating Add and Subtract
   // Vector Single-Width Averaging Add and Subtract
-  // EEW=SEW. EMUL=LMUL.
+  // EEW=SEW.
   case RISCV::VMV_V_I:
   case RISCV::VMV_V_V:
   case RISCV::VMV_V_X:
@@ -412,12 +406,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VASUB_VV:
   case RISCV::VASUB_VX:
   // Vector Single-Width Fractional Multiply with Rounding and Saturation
-  // EEW=SEW. EMUL=LMUL. The instruction produces 2*SEW product internally but
+  // EEW=SEW. The instruction produces 2*SEW product internally but
   // saturates to fit into SEW bits.
   case RISCV::VSMUL_VV:
   case RISCV::VSMUL_VX:
   // Vector Single-Width Scaling Shift Instructions
-  // EEW=SEW. EMUL=LMUL.
+  // EEW=SEW.
   case RISCV::VSSRL_VI:
   case RISCV::VSSRL_VV:
   case RISCV::VSSRL_VX:
@@ -427,13 +421,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   // Vector Permutation Instructions
   // Integer Scalar Move Instructions
   // Floating-Point Scalar Move Instructions
-  // EMUL=LMUL. EEW=SEW.
+  // EEW=SEW.
   case RISCV::VMV_X_S:
   case RISCV::VMV_S_X:
   case RISCV::VFMV_F_S:
   case RISCV::VFMV_S_F:
   // Vector Slide Instructions
-  // EMUL=LMUL. EEW=SEW.
+  // EEW=SEW.
   case RISCV::VSLIDEUP_VI:
   case RISCV::VSLIDEUP_VX:
   case RISCV::VSLIDEDOWN_VI:
@@ -443,12 +437,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VSLIDE1DOWN_VX:
   case RISCV::VFSLIDE1DOWN_VF:
   // Vector Register Gather Instructions
-  // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1.
+  // EEW=SEW. For mask operand, EEW=1.
   case RISCV::VRGATHER_VI:
   case RISCV::VRGATHER_VV:
   case RISCV::VRGATHER_VX:
   // Vector Compress Instruction
-  // EMUL=LMUL. EEW=SEW.
+  // EEW=SEW.
   case RISCV::VCOMPRESS_VM:
   // Vector Element Index Instruction
   case RISCV::VID_V:
@@ -495,10 +489,10 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VFCVT_F_X_V:
   // Vector Floating-Point Merge Instruction
   case RISCV::VFMERGE_VFM:
-    return OperandInfo(MIVLMul, MILog2SEW);
+    return MILog2SEW;
 
   // Vector Widening Integer Add/Subtract
-  // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL.
+  // Def uses EEW=2*SEW . Operands use EEW=SEW.
   case RISCV::VWADDU_VV:
   case RISCV::VWADDU_VX:
   case RISCV::VWSUBU_VV:
@@ -509,7 +503,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VWSUB_VX:
   case RISCV::VWSLL_VI:
   // Vector Widening Integer Multiply Instructions
-  // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW.
+  // Destination EEW=2*SEW. Source EEW=SEW.
   case RISCV::VWMUL_VV:
   case RISCV::VWMUL_VX:
   case RISCV::VWMULSU_VV:
@@ -517,7 +511,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VWMULU_VV:
   case RISCV::VWMULU_VX:
   // Vector Widening Integer Multiply-Add Instructions
-  // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL.
+  // Destination EEW=2*SEW. Source EEW=SEW.
   // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
   // is then added to the 2*SEW-bit Dest. These instructions never have a
   // passthru operand.
@@ -538,7 +532,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VFWNMSAC_VF:
   case RISCV::VFWNMSAC_VV:
   // Vector Widening Floating-Point Add/Subtract Instructions
-  // Dest EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL.
+  // Dest EEW=2*SEW. Source EEW=SEW.
   case RISCV::VFWADD_VV:
   case RISCV::VFWADD_VF:
   case RISCV::VFWSUB_VV:
@@ -555,11 +549,10 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VFWCVT_F_X_V:
   case RISCV::VFWCVT_F_F_V: {
     unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW;
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI),
-                       Log2EEW);
+    return Log2EEW;
   }
 
-  // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL
+  // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
   case RISCV::VWADDU_WV:
   case RISCV::VWADDU_WX:
   case RISCV::VWSUBU_WV:
@@ -576,24 +569,22 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
     bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
     bool TwoTimes = IsMODef || IsOp1;
     unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW;
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI),
-                       Log2EEW);
+    return Log2EEW;
   }
 
   // Vector Integer Extension
   case RISCV::VZEXT_VF2:
   case RISCV::VSEXT_VF2:
-    return getIntegerExtensionOperandInfo(2, MI, MO);
+    return getIntegerExtensionOperandEEW(2, MI, MO);
   case RISCV::VZEXT_VF4:
   case RISCV::VSEXT_VF4:
-    return getIntegerExtensionOperandInfo(4, MI, MO);
+    return getIntegerExtensionOperandEEW(4, MI, MO);
   case RISCV::VZEXT_VF8:
   case RISCV::VSEXT_VF8:
-    return getIntegerExtensionOperandInfo(8, MI, MO);
+    return getIntegerExtensionOperandEEW(8, MI, MO);
 
   // Vector Narrowing Integer Right Shift Instructions
-  // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has
-  // EEW=SEW EMUL=LMUL.
+  // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 hasEEW=SEW
   case RISCV::VNSRL_WX:
   case RISCV::VNSRL_WI:
   case RISCV::VNSRL_WV:
@@ -601,7 +592,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VNSRA_WV:
   case RISCV::VNSRA_WX:
   // Vector Narrowing Fixed-Point Clip Instructions
-  // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL
+  // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
   case RISCV::VNCLIPU_WI:
   case RISCV::VNCLIPU_WV:
   case RISCV::VNCLIPU_WX:
@@ -620,8 +611,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
     bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
     bool TwoTimes = IsOp1;
     unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW;
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI),
-                       Log2EEW);
+    return Log2EEW;
   }
 
   // Vector Mask Instructions
@@ -629,7 +619,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   // vmsbf.m set-before-first mask bit
   // vmsif.m set-including-first mask bit
   // vmsof.m set-only-first mask bit
-  // EEW=1 and EMUL=(EEW/SEW)*LMUL
+  // EEW=1
   // We handle the cases when operand is a v0 mask operand above the switch,
   // but these instructions may use non-v0 mask operands and need to be handled
   // specifically.
@@ -644,20 +634,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VMSBF_M:
   case RISCV::VMSIF_M:
   case RISCV::VMSOF_M: {
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+    return 0;
   }
 
   // Vector Iota Instruction
-  // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL=
-  // (EEW/SEW)*LMUL. Mask operand is not handled before this switch.
+  // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
+  // before this switch.
   case RISCV::VIOTA_M: {
     if (IsMODef || MO.getOperandNo() == 1)
-      return OperandInfo(MIVLMul, MILog2SEW);
-    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+      return MILog2SEW;
+    return 0;
   }
 
   // Vector Integer Compare Instructions
-  // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL.
+  // Dest EEW=1. Source EEW=SEW.
   case RISCV::VMSEQ_VI:
   case RISCV::VMSEQ_VV:
   case RISCV::VMSEQ_VX:
@@ -679,21 +669,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VMSGT_VI:
   case RISCV::VMSGT_VX:
   // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
-  // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. Mask
-  // source operand handled above this switch.
+  // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
   case RISCV::VMADC_VIM:
   case RISCV::VMADC_VVM:
   case RISCV::VMADC_VXM:
   case RISCV::VMSBC_VVM:
   case RISCV::VMSBC_VXM:
-  // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL.
+  // Dest EEW=1. Source EEW=SEW.
   case RISCV::VMADC_VV:
   case RISCV::VMADC_VI:
   case RISCV::VMADC_VX:
   case RISCV::VMSBC_VV:
   case RISCV::VMSBC_VX:
   // 13.13. Vector Floating-Point Compare Instructions
-  // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW EMUL=LMUL.
+  // Dest EEW=1. Source EEW=SEW
   case RISCV::VMFEQ_VF:
   case RISCV::VMFEQ_VV:
   case RISCV::VMFNE_VF:
@@ -705,14 +694,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VMFGT_VF:
   case RISCV::VMFGE_VF: {
     if (IsMODef)
-      return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
-    return OperandInfo(MIVLMul, MILog2SEW);
+      return 0;
+    return MILog2SEW;
   }
 
   // Vector Reduction Operations
   // Vector Single-Width Integer Reduction Instructions
-  // The Dest and VS1 only read element 0 of the vector register. Return just
-  // the EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
   case RISCV::VREDAND_VS:
   case RISCV::VREDMAX_VS:
   case RISCV::VREDMAXU_VS:
@@ -721,9 +708,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   case RISCV::VREDOR_VS:
   case RISCV::VREDSUM_VS:
   case RISCV::VREDXOR_VS: {
-    if (MO.getOperandNo() == 2)
-      return OperandInfo(MIVLMul, MILog2SEW);
-    return OperandInfo(MILog2SEW);
+    return MILog2SEW;
   }
 
   default:
@@ -731,6 +716,40 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
   }
 }
 
+static OperandInfo getOperandInfo(const MachineOperand &MO,
+                                  const MachineRegisterInfo *MRI) {
+  const MachineInstr &MI = *MO.getParent();
+  const RISCVVPseudosTable::PseudoInfo *RVV =
+      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
+  assert(RVV && "Could not find MI in PseudoTable");
+
+  std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
+  if (!Log2EEW)
+    return {};
+
+  switch (RVV->BaseInstr) {
+  // Vector Reduction Operations
+  // Vector Single-Width Integer Reduction Instructions
+  // The Dest and VS1 only read element 0 of the vector register. Return just
+  // the EEW for these.
+  case RISCV::VREDAND_VS:
+  case RISCV::VREDMAX_VS:
+  case RISCV::VREDMAXU_VS:
+  case RISCV::VREDMIN_VS:
+  case RISCV::VREDMINU_VS:
+  case RISCV::VREDOR_VS:
+  case RISCV::VREDSUM_VS:
+  case RISCV::VREDXOR_VS:
+    if (MO.getOperandNo() != 2)
+      return OperandInfo(*Log2EEW);
+    break;
+  };
+
+  // All others have EMUL=EEW/SEW*LMUL
+  return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI),
+                     *Log2EEW);
+}
+
 /// Return true if this optimization should consider MI for VL reduction. This
 /// white-list approach simplifies this optimization for instructions that may
 /// have more complex semantics with relation to how it uses VL.

@preames
Copy link
Collaborator Author

preames commented Jan 8, 2025

@michaelmaitland - Not sure if you have planned changes which are incompatible with this API. If so, feel free to speak up. This isn't strongly motivated except by making the code slightly easier to read.

@michaelmaitland
Copy link
Contributor

I think this is a good direction, considering what additional work I have planned. I will give a more through review tomorrow.

case RISCV::VREDMINU_VS:
case RISCV::VREDOR_VS:
case RISCV::VREDSUM_VS:
case RISCV::VREDXOR_VS:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is non-NFC so not for this PR. Do VMV_S_X and VMV_X_S need to be handled here?

And eventually floating point reductions too? I think we can consolidate this with isVectorOpUsedAsScalarOp

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Recording this just for future reference, no action needed.)

Hm, the more I think about this, the more I'm wondering if the existing code is correct. This codepath causes us to ignore the EMUL entirely when checking whether a user is compatible with a def. I'm not sure that's right.

The bit I'm worried about is not the result of the reduction (i.e. the single element), it's the passthru operand. If the instruction is TU, we have to preserve the tail elements through the reduction. If the reduction vector source is also equal to the passthru, we can end up with a situation where reducing VL changes the tail elements of the definition and that's visible through the reduction.

Ah, but we're okay. We'd reject the use which was the reduction pass thru entirely. So as a result, we only reduce the VL of def if it can't reach the passthru operand.

I think you are right that all "scalar in vector" arguments probably should have consistent handling, but I think the current code is merely conservative, not incorrect for these.

RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
assert(RVV && "Could not find MI in PseudoTable");

std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just an observation, it's a shame we have to do another pseudo table lookup. We could plumb RVV->BaseInstr through to getOperandLog2EEW but that introduces an invariant.

Copy link
Contributor

@michaelmaitland michaelmaitland left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@preames preames merged commit 983a957 into llvm:main Jan 8, 2025
5 of 7 checks passed
@preames preames deleted the pr-riscv-vlopt-cleanup-2 branch January 8, 2025 18:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants