[RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions #88307

Merged: 8 commits merged into llvm:main, Apr 25, 2024

Conversation

@mshockwave (Member) commented Apr 10, 2024

This patch covers VADD_VV and VMUL_VV.


This PR is stacked on top of #88306 (specifically, 8e67ee2)

I also put the pre-commit test in a separate commit so that it's easier to see the differences. I will either squash it or push it separately before merging.
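
For context, here is a minimal scalar C++ sketch of what MachineCombiner reassociation buys (an illustration only, not code from this patch; the patch applies the same rewrite to RVV pseudos such as PseudoVADD_VV):

```cpp
// Reassociation shortens the dependency chain of a reduction-like
// expression. Both forms compute the same value for associative ops.
int chainSum(int a, int b, int c, int d) {
  // Serial form: each add depends on the previous one,
  // so the critical path is 3 dependent adds.
  return ((a + b) + c) + d;
}

int treeSum(int a, int b, int c, int d) {
  // Reassociated form: (a + b) and (c + d) are independent and can
  // execute in parallel, so the critical path is 2 dependent adds.
  return (a + b) + (c + d);
}
```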

OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
OPCODE_LMUL_CASE(PseudoVMUL_VV):
OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
OPCODE_LMUL_CASE(PseudoVMULH_VV):
Collaborator:

I don't think VMULH(U) is associative
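
For illustration, a minimal 8-bit sketch (not the RVV implementation): the high-half multiply drops low-order bits, so regrouping changes how much gets discarded, which breaks associativity.

```cpp
#include <cassert>
#include <cstdint>

// 8-bit unsigned "high-half multiply", analogous to what vmulhu.vv does
// per element (the real instruction operates on SEW-bit elements).
static uint8_t mulhu8(uint8_t a, uint8_t b) {
  return static_cast<uint8_t>((static_cast<uint16_t>(a) * b) >> 8);
}

int main() {
  uint8_t a = 255, b = 255, c = 2;
  // (a *H b) *H c = 254 *H 2 = 1
  // a *H (b *H c) = 255 *H 1 = 0
  assert(mulhu8(mulhu8(a, b), c) != mulhu8(a, mulhu8(b, c)));
  return 0;
}
```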

Member Author:

You're right. It's fixed now.

@llvmbot (Member) commented Apr 10, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Min-Yih Hsu (mshockwave)

Changes

This patch covers VADD_VV, VMUL_VV, VMULH_VV, and VMULHU_VV.


This PR is stacked on top of #88306 (specifically, 8e67ee2)

I also put the pre-commit test in a separate commit so that it's easier to see the differences. I will either squash it or push it separately before merging.


Patch is 29.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88307.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetInstrInfo.h (+11)
  • (modified) llvm/lib/CodeGen/TargetInstrInfo.cpp (+100-45)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (+224-4)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.h (+14)
  • (added) llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll (+254)
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index 9fd0ebe6956fbe..82c952b227557d 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -30,6 +30,7 @@
 #include "llvm/MC/MCInstrInfo.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <array>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -1268,12 +1269,22 @@ class TargetInstrInfo : public MCInstrInfo {
     return true;
   }
 
+  /// The returned array encodes the operand index for each parameter because
+  /// the operands may be commuted; the operand indices for associative
+  /// operations might also be target-specific. Each element specifies the index
+  /// of {Prev, A, B, X, Y}.
+  virtual void
+  getReassociateOperandIdx(const MachineInstr &Root,
+                           MachineCombinerPattern Pattern,
+                           std::array<unsigned, 5> &OperandIndices) const;
+
   /// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
   /// reduce critical path length.
   void reassociateOps(MachineInstr &Root, MachineInstr &Prev,
                       MachineCombinerPattern Pattern,
                       SmallVectorImpl<MachineInstr *> &InsInstrs,
                       SmallVectorImpl<MachineInstr *> &DelInstrs,
+                      ArrayRef<unsigned> OperandIndices,
                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
 
   /// Reassociation of some instructions requires inverse operations (e.g.
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp
index 9fbd516acea8e1..488922e3c1b720 100644
--- a/llvm/lib/CodeGen/TargetInstrInfo.cpp
+++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp
@@ -1051,13 +1051,34 @@ static std::pair<bool, bool> mustSwapOperands(MachineCombinerPattern Pattern) {
   }
 }
 
+void TargetInstrInfo::getReassociateOperandIdx(
+    const MachineInstr &Root, MachineCombinerPattern Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  switch (Pattern) {
+  case MachineCombinerPattern::REASSOC_AX_BY:
+    OperandIndices = {1, 1, 1, 2, 2};
+    break;
+  case MachineCombinerPattern::REASSOC_AX_YB:
+    OperandIndices = {2, 1, 2, 2, 1};
+    break;
+  case MachineCombinerPattern::REASSOC_XA_BY:
+    OperandIndices = {1, 2, 1, 1, 2};
+    break;
+  case MachineCombinerPattern::REASSOC_XA_YB:
+    OperandIndices = {2, 2, 2, 1, 1};
+    break;
+  default:
+    llvm_unreachable("unexpected MachineCombinerPattern");
+  }
+}
+
 /// Attempt the reassociation transformation to reduce critical path length.
 /// See the above comments before getMachineCombinerPatterns().
 void TargetInstrInfo::reassociateOps(
-    MachineInstr &Root, MachineInstr &Prev,
-    MachineCombinerPattern Pattern,
+    MachineInstr &Root, MachineInstr &Prev, MachineCombinerPattern Pattern,
     SmallVectorImpl<MachineInstr *> &InsInstrs,
     SmallVectorImpl<MachineInstr *> &DelInstrs,
+    ArrayRef<unsigned> OperandIndices,
     DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
   MachineFunction *MF = Root.getMF();
   MachineRegisterInfo &MRI = MF->getRegInfo();
@@ -1065,29 +1086,10 @@ void TargetInstrInfo::reassociateOps(
   const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
   const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
 
-  // This array encodes the operand index for each parameter because the
-  // operands may be commuted. Each row corresponds to a pattern value,
-  // and each column specifies the index of A, B, X, Y.
-  unsigned OpIdx[4][4] = {
-    { 1, 1, 2, 2 },
-    { 1, 2, 2, 1 },
-    { 2, 1, 1, 2 },
-    { 2, 2, 1, 1 }
-  };
-
-  int Row;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
-  case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
-  case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
-  case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
-  default: llvm_unreachable("unexpected MachineCombinerPattern");
-  }
-
-  MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
-  MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
-  MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
-  MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
+  MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
+  MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
+  MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
+  MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
   MachineOperand &OpC = Root.getOperand(0);
 
   Register RegA = OpA.getReg();
@@ -1126,11 +1128,62 @@ void TargetInstrInfo::reassociateOps(
     std::swap(KillX, KillY);
   }
 
+  unsigned PrevFirstOpIdx, PrevSecondOpIdx;
+  unsigned RootFirstOpIdx, RootSecondOpIdx;
+  switch (Pattern) {
+  case MachineCombinerPattern::REASSOC_AX_BY:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_AX_YB:
+    PrevFirstOpIdx = OperandIndices[1];
+    PrevSecondOpIdx = OperandIndices[3];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_BY:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[2];
+    RootSecondOpIdx = OperandIndices[4];
+    break;
+  case MachineCombinerPattern::REASSOC_XA_YB:
+    PrevFirstOpIdx = OperandIndices[3];
+    PrevSecondOpIdx = OperandIndices[1];
+    RootFirstOpIdx = OperandIndices[4];
+    RootSecondOpIdx = OperandIndices[2];
+    break;
+  default:
+    llvm_unreachable("unexpected MachineCombinerPattern");
+  }
+
+  // Basically BuildMI but doesn't add implicit operands by default.
+  auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
+                              const MCInstrDesc &MCID, Register DestReg) {
+    return MachineInstrBuilder(
+               MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
+        .setPCSections(MIMD.getPCSections())
+        .addReg(DestReg, RegState::Define);
+  };
+
   // Create new instructions for insertion.
   MachineInstrBuilder MIB1 =
-      BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
-          .addReg(RegX, getKillRegState(KillX))
-          .addReg(RegY, getKillRegState(KillY));
+      buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
+  for (const auto &MO : Prev.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand we'd already added.
+    if (Idx == 0)
+      continue;
+    if (Idx == PrevFirstOpIdx)
+      MIB1.addReg(RegX, getKillRegState(KillX));
+    else if (Idx == PrevSecondOpIdx)
+      MIB1.addReg(RegY, getKillRegState(KillY));
+    else
+      MIB1.add(MO);
+  }
+  MIB1.copyImplicitOps(Prev);
 
   if (SwapRootOperands) {
     std::swap(RegA, NewVR);
@@ -1138,9 +1191,20 @@ void TargetInstrInfo::reassociateOps(
   }
 
   MachineInstrBuilder MIB2 =
-      BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
-          .addReg(RegA, getKillRegState(KillA))
-          .addReg(NewVR, getKillRegState(KillNewVR));
+      buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
+  for (const auto &MO : Root.explicit_operands()) {
+    unsigned Idx = MO.getOperandNo();
+    // Skip the result operand.
+    if (Idx == 0)
+      continue;
+    if (Idx == RootFirstOpIdx)
+      MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
+    else if (Idx == RootSecondOpIdx)
+      MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
+    else
+      MIB2 = MIB2.add(MO);
+  }
+  MIB2.copyImplicitOps(Root);
 
   // Propagate FP flags from the original instructions.
   // But clear poison-generating flags because those may not be valid now.
@@ -1184,25 +1248,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
   MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
 
   // Select the previous instruction in the sequence based on the input pattern.
-  MachineInstr *Prev = nullptr;
-  switch (Pattern) {
-  case MachineCombinerPattern::REASSOC_AX_BY:
-  case MachineCombinerPattern::REASSOC_XA_BY:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
-    break;
-  case MachineCombinerPattern::REASSOC_AX_YB:
-  case MachineCombinerPattern::REASSOC_XA_YB:
-    Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
-    break;
-  default:
-    llvm_unreachable("Unknown pattern for machine combiner");
-  }
+  std::array<unsigned, 5> OpIdx;
+  getReassociateOperandIdx(Root, Pattern, OpIdx);
+  MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
 
   // Don't reassociate if Prev and Root are in different blocks.
   if (Prev->getParent() != Root.getParent())
     return;
 
-  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
+  reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
+                 InstIdxForVirtReg);
 }
 
 MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 6b75efe684d913..d427842317881c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1575,10 +1575,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
   MachineFunction &MF = *Root.getMF();
 
   for (auto *NewMI : InsInstrs) {
-    assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
-               NewMI->getOpcode(), RISCV::OpName::frm)) ==
-               NewMI->getNumOperands() &&
-           "Instruction has unexpected number of operands");
+    // We'd already added the FRM operand.
+    if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
+            NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
+      continue;
     MachineInstrBuilder MIB(MF, NewMI);
     MIB.add(FRM);
     if (FRM.getImm() == RISCVFPRndMode::DYN)
@@ -1619,8 +1619,184 @@ static bool isFMUL(unsigned Opc) {
   }
 }
 
+bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
+                                                       bool Invert) const {
+#define OPCODE_LMUL_CASE(OPC)                                                  \
+  case RISCV::OPC##_M1:                                                        \
+  case RISCV::OPC##_M2:                                                        \
+  case RISCV::OPC##_M4:                                                        \
+  case RISCV::OPC##_M8:                                                        \
+  case RISCV::OPC##_MF2:                                                       \
+  case RISCV::OPC##_MF4:                                                       \
+  case RISCV::OPC##_MF8
+
+#define OPCODE_LMUL_MASK_CASE(OPC)                                             \
+  case RISCV::OPC##_M1_MASK:                                                   \
+  case RISCV::OPC##_M2_MASK:                                                   \
+  case RISCV::OPC##_M4_MASK:                                                   \
+  case RISCV::OPC##_M8_MASK:                                                   \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+  case RISCV::OPC##_MF8_MASK
+
+  unsigned Opcode = Inst.getOpcode();
+  if (Invert) {
+    if (auto InvOpcode = getInverseOpcode(Opcode))
+      Opcode = *InvOpcode;
+    else
+      return false;
+  }
+
+  // clang-format off
+  switch (Opcode) {
+  default:
+    return false;
+  OPCODE_LMUL_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
+  OPCODE_LMUL_CASE(PseudoVMUL_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
+  OPCODE_LMUL_CASE(PseudoVMULH_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMULH_VV):
+  OPCODE_LMUL_CASE(PseudoVMULHU_VV):
+  OPCODE_LMUL_MASK_CASE(PseudoVMULHU_VV):
+    return true;
+  }
+  // clang-format on
+
+#undef OPCODE_LMUL_MASK_CASE
+#undef OPCODE_LMUL_CASE
+}
+
+bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
+                                             const MachineInstr &MI2) const {
+  if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
+    return false;
+
+  // Make sure vtype operands are also the same.
+  const MCInstrDesc &Desc = get(MI1.getOpcode());
+  const uint64_t TSFlags = Desc.TSFlags;
+
+  auto checkImmOperand = [&](unsigned OpIdx) {
+    return MI1.getOperand(OpIdx).getImm() == MI2.getOperand(OpIdx).getImm();
+  };
+
+  auto checkRegOperand = [&](unsigned OpIdx) {
+    return MI1.getOperand(OpIdx).getReg() == MI2.getOperand(OpIdx).getReg();
+  };
+
+  // PassThru
+  if (!checkRegOperand(1))
+    return false;
+
+  // SEW
+  if (RISCVII::hasSEWOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
+    return false;
+
+  // Mask
+  // There might be more sophisticated ways to check equality of masks, but
+  // right now we simply check if they're the same virtual register.
+  if (RISCVII::usesMaskPolicy(TSFlags) && !checkRegOperand(4))
+    return false;
+
+  // Tail / Mask policies
+  if (RISCVII::hasVecPolicyOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
+    return false;
+
+  // VL
+  if (RISCVII::hasVLOp(TSFlags)) {
+    unsigned OpIdx = RISCVII::getVLOpNum(Desc);
+    const MachineOperand &Op1 = MI1.getOperand(OpIdx);
+    const MachineOperand &Op2 = MI2.getOperand(OpIdx);
+    if (Op1.getType() != Op2.getType())
+      return false;
+    switch (Op1.getType()) {
+    case MachineOperand::MO_Register:
+      if (Op1.getReg() != Op2.getReg())
+        return false;
+      break;
+    case MachineOperand::MO_Immediate:
+      if (Op1.getImm() != Op2.getImm())
+        return false;
+      break;
+    default:
+      llvm_unreachable("Unrecognized VL operand type");
+    }
+  }
+
+  // Rounding modes
+  if (RISCVII::hasRoundModeOp(TSFlags) &&
+      !checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
+    return false;
+
+  return true;
+}
+
+// Most of our RVV pseudo has passthru operand, so the real operands
+// start from index = 2.
+bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
+                                                  bool &Commuted) const {
+  const MachineBasicBlock *MBB = Inst.getParent();
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+  MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
+  MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());
+
+  // If only one operand has the same or inverse opcode and it's the second
+  // source operand, the operands must be commuted.
+  Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
+             areRVVInstsReassociable(Inst, *MI2);
+  if (Commuted)
+    std::swap(MI1, MI2);
+
+  return areRVVInstsReassociable(Inst, *MI1) &&
+         (isVectorAssociativeAndCommutative(*MI1) ||
+          isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
+         hasReassociableOperands(*MI1, MBB) &&
+         MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
+}
+
+bool RISCVInstrInfo::hasReassociableOperands(
+    const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
+  if (!isVectorAssociativeAndCommutative(Inst) &&
+      !isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+    return TargetInstrInfo::hasReassociableOperands(Inst, MBB);
+
+  const MachineOperand &Op1 = Inst.getOperand(2);
+  const MachineOperand &Op2 = Inst.getOperand(3);
+  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+  // We need virtual register definitions for the operands that we will
+  // reassociate.
+  MachineInstr *MI1 = nullptr;
+  MachineInstr *MI2 = nullptr;
+  if (Op1.isReg() && Op1.getReg().isVirtual())
+    MI1 = MRI.getUniqueVRegDef(Op1.getReg());
+  if (Op2.isReg() && Op2.getReg().isVirtual())
+    MI2 = MRI.getUniqueVRegDef(Op2.getReg());
+
+  // And at least one operand must be defined in MBB.
+  return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
+}
+
+void RISCVInstrInfo::getReassociateOperandIdx(
+    const MachineInstr &Root, MachineCombinerPattern Pattern,
+    std::array<unsigned, 5> &OperandIndices) const {
+  TargetInstrInfo::getReassociateOperandIdx(Root, Pattern, OperandIndices);
+  if (isVectorAssociativeAndCommutative(Root) ||
+      isVectorAssociativeAndCommutative(Root, /*Invert=*/true)) {
+    // Skip the passthrough operand, so add all indices by one.
+    for (unsigned I = 0; I < 5; ++I)
+      ++OperandIndices[I];
+  }
+}
+
 bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
                                             bool &Commuted) const {
+  if (isVectorAssociativeAndCommutative(Inst) ||
+      isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
+    return hasReassociableVectorSibling(Inst, Commuted);
+
   if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
     return false;
 
@@ -1640,6 +1816,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
 
 bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
                                                  bool Invert) const {
+  if (isVectorAssociativeAndCommutative(Inst, Invert))
+    return true;
+
   unsigned Opc = Inst.getOpcode();
   if (Invert) {
     auto InverseOpcode = getInverseOpcode(Opc);
@@ -1692,6 +1871,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
 
 std::optional<unsigned>
 RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
+#define RVV_OPC_LMUL_CASE(OPC, INV)                                            \
+  case RISCV::OPC##_M1:                                                        \
+    return RISCV::INV##_M1;                                                    \
+  case RISCV::OPC##_M2:                                                        \
+    return RISCV::INV##_M2;                                                    \
+  case RISCV::OPC##_M4:                                                        \
+    return RISCV::INV##_M4;                                                    \
+  case RISCV::OPC##_M8:                                                        \
+    return RISCV::INV##_M8;                                                    \
+  case RISCV::OPC##_MF2:                                                       \
+    return RISCV::INV##_MF2;                                                   \
+  case RISCV::OPC##_MF4:                                                       \
+    return RISCV::INV##_MF4;                                                   \
+  case RISCV::OPC##_MF8:                                                       \
+    return RISCV::INV##_MF8
+
+#define RVV_OPC_LMUL_MASK_CASE(OPC, INV)                                       \
+  case RISCV::OPC##_M1_MASK:                                                   \
+    return RISCV::INV##_M1_MASK;                                               \
+  case RISCV::OPC##_M2_MASK:                                                   \
+    return RISCV::INV##_M2_MASK;                                               \
+  case RISCV::OPC##_M4_MASK:                                                   \
+    return RISCV::INV##_M4_MASK;                                               \
+  case RISCV::OPC##_M8_MASK:                                                   \
+    return RISCV::INV##_M8_MASK;                                               \
+  case RISCV::OPC##_MF2_MASK:                                                  \
+    return RISCV::INV##_MF2_MASK;                                              \
+  case RISCV::OPC##_MF4_MASK:                                                  \
+    return RISCV::INV##_MF4_MASK;                                              \
+  case RISCV::OPC##_MF8_MASK:                                                  \
+    return RISCV::INV##_MF8_MASK
+
   switch (Opcode) {
   default:
     return std::nullopt;
@@ -1715,7 +1926,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
     return RISCV::SUBW;
   case RISCV::SUBW:
     return RISCV::ADDW;
+    // clang-format off
+  RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoV...
[truncated]

return true;
}

// Most of our RVV pseudo has passthru operand, so the real operands
Collaborator:

"pseudos have"

Member Author:

Fixed.

TargetInstrInfo::getReassociateOperandIdx(Root, Pattern, OperandIndices);
if (isVectorAssociativeAndCommutative(Root) ||
isVectorAssociativeAndCommutative(Root, /*Invert=*/true)) {
// Skip the passthrough operand, so add all indices by one.
@topperc (Collaborator) commented Apr 10, 2024:

"increment all indices by one" or "add one to all indices"

Member Author:

Fixed.

@mshockwave (Member, Author):

ping.

@lukel97 (Contributor) left a comment:

Does this need to be rebased over dc25a72?

@mshockwave (Member, Author):

Does this need to be rebased over dc25a72?

Yes, I have rebased this PR against main since #88306 is now merged.

@lukel97 (Contributor) left a comment:

Other than that definesRegister comment, this looks correct to me. Just some questions about structuring; I'm not sure if they're possible, so feel free to ignore them.

@mshockwave (Member, Author):

Other than that definesRegister comment, this looks correct to me. Just some questions about structuring; I'm not sure if they're possible, so feel free to ignore them.

I think you asked in several places whether we could use RISCV::getRVVMCOpcode instead of checking against a list of (reassociable) RVV opcodes. I would like to point out that the table lookup employed by RISCV::getRVVMCOpcode is actually not cheap in terms of time complexity (logarithmic, versus nearly constant time for a switch over opcodes) if we call it on every instruction. I've seen a significant compile-time regression when I tried to do something similar in another of my projects (we probably ought to add some sort of cache for RISCV::getRVVMCOpcode and its friends to really solve this problem). That's why I didn't use RISCV::getRVVMCOpcode in the first place; besides, there are only two MC opcodes we support for reassociation, and I don't feel that would justify risking compilation speed.
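
A rough sketch of the cost difference (a generic illustration with placeholder opcode values, not the actual TableGen-generated tables or the real RISCV::getRVVMCOpcode implementation):

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>

// Hypothetical pseudo -> MC opcode pair, mimicking a TableGen-generated table.
struct PseudoToMCEntry {
  uint16_t PseudoOpc;
  uint16_t MCOpc;
};

// Sorted by PseudoOpc; each lookup is a binary search: O(log n) per query.
static const PseudoToMCEntry Table[] = {
    {100, 10}, {205, 11}, {371, 12}, {498, 13}, /* ...hundreds more... */
};

static int lookupMCOpcode(uint16_t PseudoOpc) {
  const auto *It = std::lower_bound(
      std::begin(Table), std::end(Table), PseudoOpc,
      [](const PseudoToMCEntry &E, uint16_t Opc) { return E.PseudoOpc < Opc; });
  if (It != std::end(Table) && It->PseudoOpc == PseudoOpc)
    return It->MCOpc;
  return -1;
}

// A switch over the handful of opcodes we care about typically lowers to a
// jump table or a short compare chain: effectively constant time per query.
static bool isReassociableOpcode(uint16_t Opc) {
  switch (Opc) {
  case 100: // e.g. a VADD_VV pseudo (placeholder value)
  case 205: // e.g. a VMUL_VV pseudo (placeholder value)
    return true;
  default:
    return false;
  }
}
```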

};

// PassThru
// TODO: Potentially we can loosen the condition to consider Root (MI1) to be
Collaborator:

Can we name them Root and Prev instead of MI1 and MI2?

Member Author:

Done.

mshockwave added a commit that referenced this pull request Apr 25, 2024
@mshockwave (Member, Author):

Force-pushed to rebase (to account for the pre-commit test).

@mshockwave mshockwave merged commit 5f67ce5 into llvm:main Apr 25, 2024
@mshockwave mshockwave deleted the patch/rvv-reassoc branch April 25, 2024 23:36