Skip to content

Commit 6ac5a84

Browse files
committed
[RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions
This patch covers VADD_VV and VMUL_VV.
1 parent 69ed35f commit 6ac5a84

File tree

3 files changed

+241
-7
lines changed

3 files changed

+241
-7
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,8 +1626,184 @@ static bool isFMUL(unsigned Opc) {
16261626
}
16271627
}
16281628

1629+
bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
1630+
bool Invert) const {
1631+
#define OPCODE_LMUL_CASE(OPC) \
1632+
case RISCV::OPC##_M1: \
1633+
case RISCV::OPC##_M2: \
1634+
case RISCV::OPC##_M4: \
1635+
case RISCV::OPC##_M8: \
1636+
case RISCV::OPC##_MF2: \
1637+
case RISCV::OPC##_MF4: \
1638+
case RISCV::OPC##_MF8
1639+
1640+
#define OPCODE_LMUL_MASK_CASE(OPC) \
1641+
case RISCV::OPC##_M1_MASK: \
1642+
case RISCV::OPC##_M2_MASK: \
1643+
case RISCV::OPC##_M4_MASK: \
1644+
case RISCV::OPC##_M8_MASK: \
1645+
case RISCV::OPC##_MF2_MASK: \
1646+
case RISCV::OPC##_MF4_MASK: \
1647+
case RISCV::OPC##_MF8_MASK
1648+
1649+
unsigned Opcode = Inst.getOpcode();
1650+
if (Invert) {
1651+
if (auto InvOpcode = getInverseOpcode(Opcode))
1652+
Opcode = *InvOpcode;
1653+
else
1654+
return false;
1655+
}
1656+
1657+
// clang-format off
1658+
switch (Opcode) {
1659+
default:
1660+
return false;
1661+
OPCODE_LMUL_CASE(PseudoVADD_VV):
1662+
OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
1663+
OPCODE_LMUL_CASE(PseudoVMUL_VV):
1664+
OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
1665+
OPCODE_LMUL_CASE(PseudoVMULH_VV):
1666+
OPCODE_LMUL_MASK_CASE(PseudoVMULH_VV):
1667+
OPCODE_LMUL_CASE(PseudoVMULHU_VV):
1668+
OPCODE_LMUL_MASK_CASE(PseudoVMULHU_VV):
1669+
return true;
1670+
}
1671+
// clang-format on
1672+
1673+
#undef OPCODE_LMUL_MASK_CASE
1674+
#undef OPCODE_LMUL_CASE
1675+
}
1676+
1677+
bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &MI1,
1678+
const MachineInstr &MI2) const {
1679+
if (!areOpcodesEqualOrInverse(MI1.getOpcode(), MI2.getOpcode()))
1680+
return false;
1681+
1682+
// Make sure vtype operands are also the same.
1683+
const MCInstrDesc &Desc = get(MI1.getOpcode());
1684+
const uint64_t TSFlags = Desc.TSFlags;
1685+
1686+
auto checkImmOperand = [&](unsigned OpIdx) {
1687+
return MI1.getOperand(OpIdx).getImm() == MI2.getOperand(OpIdx).getImm();
1688+
};
1689+
1690+
auto checkRegOperand = [&](unsigned OpIdx) {
1691+
return MI1.getOperand(OpIdx).getReg() == MI2.getOperand(OpIdx).getReg();
1692+
};
1693+
1694+
// PassThru
1695+
if (!checkRegOperand(1))
1696+
return false;
1697+
1698+
// SEW
1699+
if (RISCVII::hasSEWOp(TSFlags) &&
1700+
!checkImmOperand(RISCVII::getSEWOpNum(Desc)))
1701+
return false;
1702+
1703+
// Mask
1704+
// There might be more sophisticated ways to check equality of masks, but
1705+
// right now we simply check if they're the same virtual register.
1706+
if (RISCVII::usesMaskPolicy(TSFlags) && !checkRegOperand(4))
1707+
return false;
1708+
1709+
// Tail / Mask policies
1710+
if (RISCVII::hasVecPolicyOp(TSFlags) &&
1711+
!checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
1712+
return false;
1713+
1714+
// VL
1715+
if (RISCVII::hasVLOp(TSFlags)) {
1716+
unsigned OpIdx = RISCVII::getVLOpNum(Desc);
1717+
const MachineOperand &Op1 = MI1.getOperand(OpIdx);
1718+
const MachineOperand &Op2 = MI2.getOperand(OpIdx);
1719+
if (Op1.getType() != Op2.getType())
1720+
return false;
1721+
switch (Op1.getType()) {
1722+
case MachineOperand::MO_Register:
1723+
if (Op1.getReg() != Op2.getReg())
1724+
return false;
1725+
break;
1726+
case MachineOperand::MO_Immediate:
1727+
if (Op1.getImm() != Op2.getImm())
1728+
return false;
1729+
break;
1730+
default:
1731+
llvm_unreachable("Unrecognized VL operand type");
1732+
}
1733+
}
1734+
1735+
// Rounding modes
1736+
if (RISCVII::hasRoundModeOp(TSFlags) &&
1737+
!checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
1738+
return false;
1739+
1740+
return true;
1741+
}
1742+
1743+
// Most of our RVV pseudo has passthru operand, so the real operands
1744+
// start from index = 2.
1745+
bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
1746+
bool &Commuted) const {
1747+
const MachineBasicBlock *MBB = Inst.getParent();
1748+
const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
1749+
MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
1750+
MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());
1751+
1752+
// If only one operand has the same or inverse opcode and it's the second
1753+
// source operand, the operands must be commuted.
1754+
Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
1755+
areRVVInstsReassociable(Inst, *MI2);
1756+
if (Commuted)
1757+
std::swap(MI1, MI2);
1758+
1759+
return areRVVInstsReassociable(Inst, *MI1) &&
1760+
(isVectorAssociativeAndCommutative(*MI1) ||
1761+
isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
1762+
hasReassociableOperands(*MI1, MBB) &&
1763+
MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
1764+
}
1765+
1766+
bool RISCVInstrInfo::hasReassociableOperands(
1767+
const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
1768+
if (!isVectorAssociativeAndCommutative(Inst) &&
1769+
!isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
1770+
return TargetInstrInfo::hasReassociableOperands(Inst, MBB);
1771+
1772+
const MachineOperand &Op1 = Inst.getOperand(2);
1773+
const MachineOperand &Op2 = Inst.getOperand(3);
1774+
const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
1775+
1776+
// We need virtual register definitions for the operands that we will
1777+
// reassociate.
1778+
MachineInstr *MI1 = nullptr;
1779+
MachineInstr *MI2 = nullptr;
1780+
if (Op1.isReg() && Op1.getReg().isVirtual())
1781+
MI1 = MRI.getUniqueVRegDef(Op1.getReg());
1782+
if (Op2.isReg() && Op2.getReg().isVirtual())
1783+
MI2 = MRI.getUniqueVRegDef(Op2.getReg());
1784+
1785+
// And at least one operand must be defined in MBB.
1786+
return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
1787+
}
1788+
1789+
void RISCVInstrInfo::getReassociateOperandIndices(
1790+
const MachineInstr &Root, unsigned Pattern,
1791+
std::array<unsigned, 5> &OperandIndices) const {
1792+
TargetInstrInfo::getReassociateOperandIndices(Root, Pattern, OperandIndices);
1793+
if (isVectorAssociativeAndCommutative(Root) ||
1794+
isVectorAssociativeAndCommutative(Root, /*Invert=*/true)) {
1795+
// Skip the passthrough operand, so add all indices by one.
1796+
for (unsigned I = 0; I < 5; ++I)
1797+
++OperandIndices[I];
1798+
}
1799+
}
1800+
16291801
bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
16301802
bool &Commuted) const {
1803+
if (isVectorAssociativeAndCommutative(Inst) ||
1804+
isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
1805+
return hasReassociableVectorSibling(Inst, Commuted);
1806+
16311807
if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
16321808
return false;
16331809

@@ -1647,6 +1823,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
16471823

16481824
bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
16491825
bool Invert) const {
1826+
if (isVectorAssociativeAndCommutative(Inst, Invert))
1827+
return true;
1828+
16501829
unsigned Opc = Inst.getOpcode();
16511830
if (Invert) {
16521831
auto InverseOpcode = getInverseOpcode(Opc);
@@ -1699,6 +1878,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
16991878

17001879
std::optional<unsigned>
17011880
RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
1881+
#define RVV_OPC_LMUL_CASE(OPC, INV) \
1882+
case RISCV::OPC##_M1: \
1883+
return RISCV::INV##_M1; \
1884+
case RISCV::OPC##_M2: \
1885+
return RISCV::INV##_M2; \
1886+
case RISCV::OPC##_M4: \
1887+
return RISCV::INV##_M4; \
1888+
case RISCV::OPC##_M8: \
1889+
return RISCV::INV##_M8; \
1890+
case RISCV::OPC##_MF2: \
1891+
return RISCV::INV##_MF2; \
1892+
case RISCV::OPC##_MF4: \
1893+
return RISCV::INV##_MF4; \
1894+
case RISCV::OPC##_MF8: \
1895+
return RISCV::INV##_MF8
1896+
1897+
#define RVV_OPC_LMUL_MASK_CASE(OPC, INV) \
1898+
case RISCV::OPC##_M1_MASK: \
1899+
return RISCV::INV##_M1_MASK; \
1900+
case RISCV::OPC##_M2_MASK: \
1901+
return RISCV::INV##_M2_MASK; \
1902+
case RISCV::OPC##_M4_MASK: \
1903+
return RISCV::INV##_M4_MASK; \
1904+
case RISCV::OPC##_M8_MASK: \
1905+
return RISCV::INV##_M8_MASK; \
1906+
case RISCV::OPC##_MF2_MASK: \
1907+
return RISCV::INV##_MF2_MASK; \
1908+
case RISCV::OPC##_MF4_MASK: \
1909+
return RISCV::INV##_MF4_MASK; \
1910+
case RISCV::OPC##_MF8_MASK: \
1911+
return RISCV::INV##_MF8_MASK
1912+
17021913
switch (Opcode) {
17031914
default:
17041915
return std::nullopt;
@@ -1722,7 +1933,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
17221933
return RISCV::SUBW;
17231934
case RISCV::SUBW:
17241935
return RISCV::ADDW;
1936+
// clang-format off
1937+
RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
1938+
RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
1939+
RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
1940+
RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
1941+
// clang-format on
17251942
}
1943+
1944+
#undef RVV_OPC_LMUL_MASK_CASE
1945+
#undef RVV_OPC_LMUL_CASE
17261946
}
17271947

17281948
static bool canCombineFPFusedMultiply(const MachineInstr &Root,

llvm/lib/Target/RISCV/RISCVInstrInfo.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
266266
SmallVectorImpl<MachineInstr *> &DelInstrs,
267267
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
268268

269+
bool hasReassociableOperands(const MachineInstr &Inst,
270+
const MachineBasicBlock *MBB) const override;
271+
269272
bool hasReassociableSibling(const MachineInstr &Inst,
270273
bool &Commuted) const override;
271274

@@ -274,6 +277,10 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
274277

275278
std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
276279

280+
void getReassociateOperandIndices(
281+
const MachineInstr &Root, unsigned Pattern,
282+
std::array<unsigned, 5> &OperandIndices) const override;
283+
277284
ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
278285
getSerializableMachineMemOperandTargetFlags() const override;
279286

@@ -297,6 +304,13 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
297304

298305
private:
299306
unsigned getInstBundleLength(const MachineInstr &MI) const;
307+
308+
bool isVectorAssociativeAndCommutative(const MachineInstr &MI,
309+
bool Invert = false) const;
310+
bool areRVVInstsReassociable(const MachineInstr &MI1,
311+
const MachineInstr &MI2) const;
312+
bool hasReassociableVectorSibling(const MachineInstr &Inst,
313+
bool &Commuted) const;
300314
};
301315

302316
namespace RISCV {

llvm/test/CodeGen/RISCV/rvv/vector-reassociations.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ define <vscale x 1 x i8> @simple_vadd_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
3131
; CHECK: # %bb.0: # %entry
3232
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
3333
; CHECK-NEXT: vadd.vv v9, v8, v9
34-
; CHECK-NEXT: vadd.vv v9, v8, v9
34+
; CHECK-NEXT: vadd.vv v8, v8, v8
3535
; CHECK-NEXT: vadd.vv v8, v8, v9
3636
; CHECK-NEXT: ret
3737
entry:
@@ -61,7 +61,7 @@ define <vscale x 1 x i8> @simple_vadd_vsub_vv(<vscale x 1 x i8> %0, <vscale x 1
6161
; CHECK: # %bb.0: # %entry
6262
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
6363
; CHECK-NEXT: vsub.vv v9, v8, v9
64-
; CHECK-NEXT: vadd.vv v9, v8, v9
64+
; CHECK-NEXT: vadd.vv v8, v8, v8
6565
; CHECK-NEXT: vadd.vv v8, v8, v9
6666
; CHECK-NEXT: ret
6767
entry:
@@ -91,7 +91,7 @@ define <vscale x 1 x i8> @simple_vmul_vv(<vscale x 1 x i8> %0, <vscale x 1 x i8>
9191
; CHECK: # %bb.0: # %entry
9292
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
9393
; CHECK-NEXT: vmul.vv v9, v8, v9
94-
; CHECK-NEXT: vmul.vv v9, v8, v9
94+
; CHECK-NEXT: vmul.vv v8, v8, v8
9595
; CHECK-NEXT: vmul.vv v8, v8, v9
9696
; CHECK-NEXT: ret
9797
entry:
@@ -124,8 +124,8 @@ define <vscale x 1 x i8> @vadd_vv_passthru(<vscale x 1 x i8> %0, <vscale x 1 x i
124124
; CHECK-NEXT: vmv1r.v v10, v8
125125
; CHECK-NEXT: vadd.vv v10, v8, v9
126126
; CHECK-NEXT: vmv1r.v v9, v8
127-
; CHECK-NEXT: vadd.vv v9, v8, v10
128-
; CHECK-NEXT: vadd.vv v8, v8, v9
127+
; CHECK-NEXT: vadd.vv v9, v8, v8
128+
; CHECK-NEXT: vadd.vv v8, v9, v10
129129
; CHECK-NEXT: ret
130130
entry:
131131
%a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
@@ -187,8 +187,8 @@ define <vscale x 1 x i8> @vadd_vv_mask(<vscale x 1 x i8> %0, <vscale x 1 x i8> %
187187
; CHECK-NEXT: vmv1r.v v10, v8
188188
; CHECK-NEXT: vadd.vv v10, v8, v9, v0.t
189189
; CHECK-NEXT: vmv1r.v v9, v8
190-
; CHECK-NEXT: vadd.vv v9, v8, v10, v0.t
191-
; CHECK-NEXT: vadd.vv v8, v8, v9, v0.t
190+
; CHECK-NEXT: vadd.vv v9, v8, v8, v0.t
191+
; CHECK-NEXT: vadd.vv v8, v9, v10, v0.t
192192
; CHECK-NEXT: ret
193193
entry:
194194
%a = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8(

0 commit comments

Comments
 (0)