Skip to content

[RISCV][MachineCombiner] Add reassociation optimizations for RVV instructions #88307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 266 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1633,8 +1633,230 @@ static bool isFMUL(unsigned Opc) {
}
}

bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
bool Invert) const {
#define OPCODE_LMUL_CASE(OPC) \
case RISCV::OPC##_M1: \
case RISCV::OPC##_M2: \
case RISCV::OPC##_M4: \
case RISCV::OPC##_M8: \
case RISCV::OPC##_MF2: \
case RISCV::OPC##_MF4: \
case RISCV::OPC##_MF8

#define OPCODE_LMUL_MASK_CASE(OPC) \
case RISCV::OPC##_M1_MASK: \
case RISCV::OPC##_M2_MASK: \
case RISCV::OPC##_M4_MASK: \
case RISCV::OPC##_M8_MASK: \
case RISCV::OPC##_MF2_MASK: \
case RISCV::OPC##_MF4_MASK: \
case RISCV::OPC##_MF8_MASK

unsigned Opcode = Inst.getOpcode();
if (Invert) {
if (auto InvOpcode = getInverseOpcode(Opcode))
Opcode = *InvOpcode;
else
return false;
}

// clang-format off
switch (Opcode) {
default:
return false;
OPCODE_LMUL_CASE(PseudoVADD_VV):
OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
OPCODE_LMUL_CASE(PseudoVMUL_VV):
OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
return true;
}
// clang-format on

#undef OPCODE_LMUL_MASK_CASE
#undef OPCODE_LMUL_CASE
}

bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &Root,
const MachineInstr &Prev) const {
if (!areOpcodesEqualOrInverse(Root.getOpcode(), Prev.getOpcode()))
return false;

assert(Root.getMF() == Prev.getMF());
const MachineRegisterInfo *MRI = &Root.getMF()->getRegInfo();
const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();

// Make sure vtype operands are also the same.
const MCInstrDesc &Desc = get(Root.getOpcode());
const uint64_t TSFlags = Desc.TSFlags;

auto checkImmOperand = [&](unsigned OpIdx) {
return Root.getOperand(OpIdx).getImm() == Prev.getOperand(OpIdx).getImm();
};

auto checkRegOperand = [&](unsigned OpIdx) {
return Root.getOperand(OpIdx).getReg() == Prev.getOperand(OpIdx).getReg();
};

// PassThru
// TODO: Potentially we can loosen the condition to consider Root to be
// associable with Prev if Root has NoReg as passthru. In which case we
// also need to loosen the condition on vector policies between these.
if (!checkRegOperand(1))
return false;

// SEW
if (RISCVII::hasSEWOp(TSFlags) &&
!checkImmOperand(RISCVII::getSEWOpNum(Desc)))
return false;

// Mask
if (RISCVII::usesMaskPolicy(TSFlags)) {
const MachineBasicBlock *MBB = Root.getParent();
const MachineBasicBlock::const_reverse_iterator It1(&Root);
const MachineBasicBlock::const_reverse_iterator It2(&Prev);
Register MI1VReg;

bool SeenMI2 = false;
for (auto End = MBB->rend(), It = It1; It != End; ++It) {
if (It == It2) {
SeenMI2 = true;
if (!MI1VReg.isValid())
// There is no V0 def between Root and Prev; they're sharing the
// same V0.
break;
}

if (It->modifiesRegister(RISCV::V0, TRI)) {
Register SrcReg = It->getOperand(1).getReg();
// If it's not VReg it'll be more difficult to track its defs, so
// bailing out here just to be safe.
if (!SrcReg.isVirtual())
return false;

if (!MI1VReg.isValid()) {
// This is the V0 def for Root.
MI1VReg = SrcReg;
continue;
}

// Some random mask updates.
if (!SeenMI2)
continue;

// This is the V0 def for Prev; check if it's the same as that of
// Root.
if (MI1VReg != SrcReg)
return false;
else
break;
}
}

// If we haven't encountered Prev, it's likely that this function was
// called in a wrong way (e.g. Root is before Prev).
assert(SeenMI2 && "Prev is expected to appear before Root");
}

// Tail / Mask policies
if (RISCVII::hasVecPolicyOp(TSFlags) &&
!checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
return false;

// VL
if (RISCVII::hasVLOp(TSFlags)) {
unsigned OpIdx = RISCVII::getVLOpNum(Desc);
const MachineOperand &Op1 = Root.getOperand(OpIdx);
const MachineOperand &Op2 = Prev.getOperand(OpIdx);
if (Op1.getType() != Op2.getType())
return false;
switch (Op1.getType()) {
case MachineOperand::MO_Register:
if (Op1.getReg() != Op2.getReg())
return false;
break;
case MachineOperand::MO_Immediate:
if (Op1.getImm() != Op2.getImm())
return false;
break;
default:
llvm_unreachable("Unrecognized VL operand type");
}
}

// Rounding modes
if (RISCVII::hasRoundModeOp(TSFlags) &&
!checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
return false;

return true;
}

// Most of our RVV pseudos have passthru operand, so the real operands
// start from index = 2.
bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
bool &Commuted) const {
const MachineBasicBlock *MBB = Inst.getParent();
const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
assert(RISCVII::isFirstDefTiedToFirstUse(get(Inst.getOpcode())) &&
"Expect the present of passthrough operand.");
MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());

// If only one operand has the same or inverse opcode and it's the second
// source operand, the operands must be commuted.
Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
areRVVInstsReassociable(Inst, *MI2);
if (Commuted)
std::swap(MI1, MI2);

return areRVVInstsReassociable(Inst, *MI1) &&
(isVectorAssociativeAndCommutative(*MI1) ||
isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
hasReassociableOperands(*MI1, MBB) &&
MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
}

bool RISCVInstrInfo::hasReassociableOperands(
const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
if (!isVectorAssociativeAndCommutative(Inst) &&
!isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
return TargetInstrInfo::hasReassociableOperands(Inst, MBB);

const MachineOperand &Op1 = Inst.getOperand(2);
const MachineOperand &Op2 = Inst.getOperand(3);
const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();

// We need virtual register definitions for the operands that we will
// reassociate.
MachineInstr *MI1 = nullptr;
MachineInstr *MI2 = nullptr;
if (Op1.isReg() && Op1.getReg().isVirtual())
MI1 = MRI.getUniqueVRegDef(Op1.getReg());
if (Op2.isReg() && Op2.getReg().isVirtual())
MI2 = MRI.getUniqueVRegDef(Op2.getReg());

// And at least one operand must be defined in MBB.
return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
}

void RISCVInstrInfo::getReassociateOperandIndices(
const MachineInstr &Root, unsigned Pattern,
std::array<unsigned, 5> &OperandIndices) const {
TargetInstrInfo::getReassociateOperandIndices(Root, Pattern, OperandIndices);
if (RISCV::getRVVMCOpcode(Root.getOpcode())) {
// Skip the passthrough operand, so increment all indices by one.
for (unsigned I = 0; I < 5; ++I)
++OperandIndices[I];
}
}

bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
bool &Commuted) const {
if (isVectorAssociativeAndCommutative(Inst) ||
isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
return hasReassociableVectorSibling(Inst, Commuted);

if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
return false;

Expand All @@ -1654,6 +1876,9 @@ bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,

bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
bool Invert) const {
if (isVectorAssociativeAndCommutative(Inst, Invert))
return true;

unsigned Opc = Inst.getOpcode();
if (Invert) {
auto InverseOpcode = getInverseOpcode(Opc);
Expand Down Expand Up @@ -1706,6 +1931,38 @@ bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,

std::optional<unsigned>
RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
#define RVV_OPC_LMUL_CASE(OPC, INV) \
case RISCV::OPC##_M1: \
return RISCV::INV##_M1; \
case RISCV::OPC##_M2: \
return RISCV::INV##_M2; \
case RISCV::OPC##_M4: \
return RISCV::INV##_M4; \
case RISCV::OPC##_M8: \
return RISCV::INV##_M8; \
case RISCV::OPC##_MF2: \
return RISCV::INV##_MF2; \
case RISCV::OPC##_MF4: \
return RISCV::INV##_MF4; \
case RISCV::OPC##_MF8: \
return RISCV::INV##_MF8

#define RVV_OPC_LMUL_MASK_CASE(OPC, INV) \
case RISCV::OPC##_M1_MASK: \
return RISCV::INV##_M1_MASK; \
case RISCV::OPC##_M2_MASK: \
return RISCV::INV##_M2_MASK; \
case RISCV::OPC##_M4_MASK: \
return RISCV::INV##_M4_MASK; \
case RISCV::OPC##_M8_MASK: \
return RISCV::INV##_M8_MASK; \
case RISCV::OPC##_MF2_MASK: \
return RISCV::INV##_MF2_MASK; \
case RISCV::OPC##_MF4_MASK: \
return RISCV::INV##_MF4_MASK; \
case RISCV::OPC##_MF8_MASK: \
return RISCV::INV##_MF8_MASK

switch (Opcode) {
default:
return std::nullopt;
Expand All @@ -1729,7 +1986,16 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
return RISCV::SUBW;
case RISCV::SUBW:
return RISCV::ADDW;
// clang-format off
RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
// clang-format on
}

#undef RVV_OPC_LMUL_MASK_CASE
#undef RVV_OPC_LMUL_CASE
}

static bool canCombineFPFusedMultiply(const MachineInstr &Root,
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
SmallVectorImpl<MachineInstr *> &DelInstrs,
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;

bool hasReassociableOperands(const MachineInstr &Inst,
const MachineBasicBlock *MBB) const override;

bool hasReassociableSibling(const MachineInstr &Inst,
bool &Commuted) const override;

Expand All @@ -274,6 +277,10 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {

std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;

void getReassociateOperandIndices(
const MachineInstr &Root, unsigned Pattern,
std::array<unsigned, 5> &OperandIndices) const override;

ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
getSerializableMachineMemOperandTargetFlags() const override;

Expand All @@ -297,6 +304,13 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {

private:
unsigned getInstBundleLength(const MachineInstr &MI) const;

bool isVectorAssociativeAndCommutative(const MachineInstr &MI,
bool Invert = false) const;
bool areRVVInstsReassociable(const MachineInstr &MI1,
const MachineInstr &MI2) const;
bool hasReassociableVectorSibling(const MachineInstr &Inst,
bool &Commuted) const;
};

namespace RISCV {
Expand Down
Loading