@@ -1626,8 +1626,184 @@ static bool isFMUL(unsigned Opc) {
1626
1626
}
1627
1627
}
1628
1628
1629
+ bool RISCVInstrInfo::isVectorAssociativeAndCommutative (const MachineInstr &Inst,
1630
+ bool Invert) const {
1631
+ #define OPCODE_LMUL_CASE (OPC ) \
1632
+ case RISCV::OPC##_M1: \
1633
+ case RISCV::OPC##_M2: \
1634
+ case RISCV::OPC##_M4: \
1635
+ case RISCV::OPC##_M8: \
1636
+ case RISCV::OPC##_MF2: \
1637
+ case RISCV::OPC##_MF4: \
1638
+ case RISCV::OPC##_MF8
1639
+
1640
+ #define OPCODE_LMUL_MASK_CASE (OPC ) \
1641
+ case RISCV::OPC##_M1_MASK: \
1642
+ case RISCV::OPC##_M2_MASK: \
1643
+ case RISCV::OPC##_M4_MASK: \
1644
+ case RISCV::OPC##_M8_MASK: \
1645
+ case RISCV::OPC##_MF2_MASK: \
1646
+ case RISCV::OPC##_MF4_MASK: \
1647
+ case RISCV::OPC##_MF8_MASK
1648
+
1649
+ unsigned Opcode = Inst.getOpcode ();
1650
+ if (Invert) {
1651
+ if (auto InvOpcode = getInverseOpcode (Opcode))
1652
+ Opcode = *InvOpcode;
1653
+ else
1654
+ return false ;
1655
+ }
1656
+
1657
+ // clang-format off
1658
+ switch (Opcode) {
1659
+ default :
1660
+ return false ;
1661
+ OPCODE_LMUL_CASE (PseudoVADD_VV):
1662
+ OPCODE_LMUL_MASK_CASE (PseudoVADD_VV):
1663
+ OPCODE_LMUL_CASE (PseudoVMUL_VV):
1664
+ OPCODE_LMUL_MASK_CASE (PseudoVMUL_VV):
1665
+ OPCODE_LMUL_CASE (PseudoVMULH_VV):
1666
+ OPCODE_LMUL_MASK_CASE (PseudoVMULH_VV):
1667
+ OPCODE_LMUL_CASE (PseudoVMULHU_VV):
1668
+ OPCODE_LMUL_MASK_CASE (PseudoVMULHU_VV):
1669
+ return true ;
1670
+ }
1671
+ // clang-format on
1672
+
1673
+ #undef OPCODE_LMUL_MASK_CASE
1674
+ #undef OPCODE_LMUL_CASE
1675
+ }
1676
+
1677
+ bool RISCVInstrInfo::areRVVInstsReassociable (const MachineInstr &MI1,
1678
+ const MachineInstr &MI2) const {
1679
+ if (!areOpcodesEqualOrInverse (MI1.getOpcode (), MI2.getOpcode ()))
1680
+ return false ;
1681
+
1682
+ // Make sure vtype operands are also the same.
1683
+ const MCInstrDesc &Desc = get (MI1.getOpcode ());
1684
+ const uint64_t TSFlags = Desc.TSFlags ;
1685
+
1686
+ auto checkImmOperand = [&](unsigned OpIdx) {
1687
+ return MI1.getOperand (OpIdx).getImm () == MI2.getOperand (OpIdx).getImm ();
1688
+ };
1689
+
1690
+ auto checkRegOperand = [&](unsigned OpIdx) {
1691
+ return MI1.getOperand (OpIdx).getReg () == MI2.getOperand (OpIdx).getReg ();
1692
+ };
1693
+
1694
+ // PassThru
1695
+ if (!checkRegOperand (1 ))
1696
+ return false ;
1697
+
1698
+ // SEW
1699
+ if (RISCVII::hasSEWOp (TSFlags) &&
1700
+ !checkImmOperand (RISCVII::getSEWOpNum (Desc)))
1701
+ return false ;
1702
+
1703
+ // Mask
1704
+ // There might be more sophisticated ways to check equality of masks, but
1705
+ // right now we simply check if they're the same virtual register.
1706
+ if (RISCVII::usesMaskPolicy (TSFlags) && !checkRegOperand (4 ))
1707
+ return false ;
1708
+
1709
+ // Tail / Mask policies
1710
+ if (RISCVII::hasVecPolicyOp (TSFlags) &&
1711
+ !checkImmOperand (RISCVII::getVecPolicyOpNum (Desc)))
1712
+ return false ;
1713
+
1714
+ // VL
1715
+ if (RISCVII::hasVLOp (TSFlags)) {
1716
+ unsigned OpIdx = RISCVII::getVLOpNum (Desc);
1717
+ const MachineOperand &Op1 = MI1.getOperand (OpIdx);
1718
+ const MachineOperand &Op2 = MI2.getOperand (OpIdx);
1719
+ if (Op1.getType () != Op2.getType ())
1720
+ return false ;
1721
+ switch (Op1.getType ()) {
1722
+ case MachineOperand::MO_Register:
1723
+ if (Op1.getReg () != Op2.getReg ())
1724
+ return false ;
1725
+ break ;
1726
+ case MachineOperand::MO_Immediate:
1727
+ if (Op1.getImm () != Op2.getImm ())
1728
+ return false ;
1729
+ break ;
1730
+ default :
1731
+ llvm_unreachable (" Unrecognized VL operand type" );
1732
+ }
1733
+ }
1734
+
1735
+ // Rounding modes
1736
+ if (RISCVII::hasRoundModeOp (TSFlags) &&
1737
+ !checkImmOperand (RISCVII::getVLOpNum (Desc) - 1 ))
1738
+ return false ;
1739
+
1740
+ return true ;
1741
+ }
1742
+
1743
+ // Most of our RVV pseudo has passthru operand, so the real operands
1744
+ // start from index = 2.
1745
+ bool RISCVInstrInfo::hasReassociableVectorSibling (const MachineInstr &Inst,
1746
+ bool &Commuted) const {
1747
+ const MachineBasicBlock *MBB = Inst.getParent ();
1748
+ const MachineRegisterInfo &MRI = MBB->getParent ()->getRegInfo ();
1749
+ MachineInstr *MI1 = MRI.getUniqueVRegDef (Inst.getOperand (2 ).getReg ());
1750
+ MachineInstr *MI2 = MRI.getUniqueVRegDef (Inst.getOperand (3 ).getReg ());
1751
+
1752
+ // If only one operand has the same or inverse opcode and it's the second
1753
+ // source operand, the operands must be commuted.
1754
+ Commuted = !areRVVInstsReassociable (Inst, *MI1) &&
1755
+ areRVVInstsReassociable (Inst, *MI2);
1756
+ if (Commuted)
1757
+ std::swap (MI1, MI2);
1758
+
1759
+ return areRVVInstsReassociable (Inst, *MI1) &&
1760
+ (isVectorAssociativeAndCommutative (*MI1) ||
1761
+ isVectorAssociativeAndCommutative (*MI1, /* Invert */ true )) &&
1762
+ hasReassociableOperands (*MI1, MBB) &&
1763
+ MRI.hasOneNonDBGUse (MI1->getOperand (0 ).getReg ());
1764
+ }
1765
+
1766
+ bool RISCVInstrInfo::hasReassociableOperands (
1767
+ const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
1768
+ if (!isVectorAssociativeAndCommutative (Inst) &&
1769
+ !isVectorAssociativeAndCommutative (Inst, /* Invert=*/ true ))
1770
+ return TargetInstrInfo::hasReassociableOperands (Inst, MBB);
1771
+
1772
+ const MachineOperand &Op1 = Inst.getOperand (2 );
1773
+ const MachineOperand &Op2 = Inst.getOperand (3 );
1774
+ const MachineRegisterInfo &MRI = MBB->getParent ()->getRegInfo ();
1775
+
1776
+ // We need virtual register definitions for the operands that we will
1777
+ // reassociate.
1778
+ MachineInstr *MI1 = nullptr ;
1779
+ MachineInstr *MI2 = nullptr ;
1780
+ if (Op1.isReg () && Op1.getReg ().isVirtual ())
1781
+ MI1 = MRI.getUniqueVRegDef (Op1.getReg ());
1782
+ if (Op2.isReg () && Op2.getReg ().isVirtual ())
1783
+ MI2 = MRI.getUniqueVRegDef (Op2.getReg ());
1784
+
1785
+ // And at least one operand must be defined in MBB.
1786
+ return MI1 && MI2 && (MI1->getParent () == MBB || MI2->getParent () == MBB);
1787
+ }
1788
+
1789
+ void RISCVInstrInfo::getReassociateOperandIndices (
1790
+ const MachineInstr &Root, unsigned Pattern,
1791
+ std::array<unsigned , 5 > &OperandIndices) const {
1792
+ TargetInstrInfo::getReassociateOperandIndices (Root, Pattern, OperandIndices);
1793
+ if (isVectorAssociativeAndCommutative (Root) ||
1794
+ isVectorAssociativeAndCommutative (Root, /* Invert=*/ true )) {
1795
+ // Skip the passthrough operand, so add all indices by one.
1796
+ for (unsigned I = 0 ; I < 5 ; ++I)
1797
+ ++OperandIndices[I];
1798
+ }
1799
+ }
1800
+
1629
1801
bool RISCVInstrInfo::hasReassociableSibling (const MachineInstr &Inst,
1630
1802
bool &Commuted) const {
1803
+ if (isVectorAssociativeAndCommutative (Inst) ||
1804
+ isVectorAssociativeAndCommutative (Inst, /* Invert=*/ true ))
1805
+ return hasReassociableVectorSibling (Inst, Commuted);
1806
+
1631
1807
if (!TargetInstrInfo::hasReassociableSibling (Inst, Commuted))
1632
1808
return false ;
1633
1809
@@ -1647,6 +1823,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
1647
1823
1648
1824
bool RISCVInstrInfo::isAssociativeAndCommutative (const MachineInstr &Inst,
1649
1825
bool Invert) const {
1826
+ if (isVectorAssociativeAndCommutative (Inst, Invert))
1827
+ return true ;
1828
+
1650
1829
unsigned Opc = Inst.getOpcode ();
1651
1830
if (Invert) {
1652
1831
auto InverseOpcode = getInverseOpcode (Opc);
@@ -1699,6 +1878,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
1699
1878
1700
1879
std::optional<unsigned >
1701
1880
RISCVInstrInfo::getInverseOpcode (unsigned Opcode) const {
1881
+ #define RVV_OPC_LMUL_CASE (OPC, INV ) \
1882
+ case RISCV::OPC##_M1: \
1883
+ return RISCV::INV##_M1; \
1884
+ case RISCV::OPC##_M2: \
1885
+ return RISCV::INV##_M2; \
1886
+ case RISCV::OPC##_M4: \
1887
+ return RISCV::INV##_M4; \
1888
+ case RISCV::OPC##_M8: \
1889
+ return RISCV::INV##_M8; \
1890
+ case RISCV::OPC##_MF2: \
1891
+ return RISCV::INV##_MF2; \
1892
+ case RISCV::OPC##_MF4: \
1893
+ return RISCV::INV##_MF4; \
1894
+ case RISCV::OPC##_MF8: \
1895
+ return RISCV::INV##_MF8
1896
+
1897
+ #define RVV_OPC_LMUL_MASK_CASE (OPC, INV ) \
1898
+ case RISCV::OPC##_M1_MASK: \
1899
+ return RISCV::INV##_M1_MASK; \
1900
+ case RISCV::OPC##_M2_MASK: \
1901
+ return RISCV::INV##_M2_MASK; \
1902
+ case RISCV::OPC##_M4_MASK: \
1903
+ return RISCV::INV##_M4_MASK; \
1904
+ case RISCV::OPC##_M8_MASK: \
1905
+ return RISCV::INV##_M8_MASK; \
1906
+ case RISCV::OPC##_MF2_MASK: \
1907
+ return RISCV::INV##_MF2_MASK; \
1908
+ case RISCV::OPC##_MF4_MASK: \
1909
+ return RISCV::INV##_MF4_MASK; \
1910
+ case RISCV::OPC##_MF8_MASK: \
1911
+ return RISCV::INV##_MF8_MASK
1912
+
1702
1913
switch (Opcode) {
1703
1914
default :
1704
1915
return std::nullopt;
@@ -1722,7 +1933,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
1722
1933
return RISCV::SUBW;
1723
1934
case RISCV::SUBW:
1724
1935
return RISCV::ADDW;
1936
+ // clang-format off
1937
+ RVV_OPC_LMUL_CASE (PseudoVADD_VV, PseudoVSUB_VV);
1938
+ RVV_OPC_LMUL_MASK_CASE (PseudoVADD_VV, PseudoVSUB_VV);
1939
+ RVV_OPC_LMUL_CASE (PseudoVSUB_VV, PseudoVADD_VV);
1940
+ RVV_OPC_LMUL_MASK_CASE (PseudoVSUB_VV, PseudoVADD_VV);
1941
+ // clang-format on
1725
1942
}
1943
+
1944
+ #undef RVV_OPC_LMUL_MASK_CASE
1945
+ #undef RVV_OPC_LMUL_CASE
1726
1946
}
1727
1947
1728
1948
static bool canCombineFPFusedMultiply (const MachineInstr &Root,
0 commit comments