[X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl #79761

Merged (2 commits, Jan 31, 2024)
2 changes: 1 addition & 1 deletion llvm/lib/Target/X86/X86InstrAVX512.td
@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
MaskInfo.RC:$src0))],
DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;

let hasSideEffects = 0, mayLoad = 1 in
let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
(ins SrcInfo.ScalarMemOp:$src),
!strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
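The two bits added above are surfaced through MCInstrDesc, which is how generic CodeGen passes (such as rematerialization in the register allocator and load folding in the peephole optimizer) get to treat these broadcast loads specially. A minimal sketch, not part of this patch, of how such a query looks from target-independent code; the helper name is hypothetical:

#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/TargetInstrInfo.h"

// Hypothetical helper: is this broadcast-load def a candidate for
// rematerialization or for folding into its use?
static bool isRematOrFoldableLoad(const llvm::MachineInstr &MI,
                                  const llvm::TargetInstrInfo &TII) {
  // isReMaterializable = 1 sets MCID::Rematerializable; the target hook then
  // confirms this particular instance is safe to recompute at the use.
  if (MI.isRematerializable() && TII.isTriviallyReMaterializable(MI))
    return true;
  // canFoldAsLoad = 1 marks the def as a load that foldMemoryOperand may
  // merge into a memory-operand form of its user.
  return MI.canFoldAsLoad();
}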
21 changes: 19 additions & 2 deletions llvm/lib/Target/X86/X86InstrFoldTables.cpp
@@ -143,6 +143,23 @@ const X86FoldTableEntry *llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
  return lookupFoldTableImpl(FoldTable, RegOp);
}

const X86FoldTableEntry *llvm::lookupBroadcastFoldTable(unsigned RegOp,
                                                        unsigned OpNum) {
  ArrayRef<X86FoldTableEntry> FoldTable;
  if (OpNum == 1)
    FoldTable = ArrayRef(BroadcastTable1);
  else if (OpNum == 2)
    FoldTable = ArrayRef(BroadcastTable2);
  else if (OpNum == 3)
    FoldTable = ArrayRef(BroadcastTable3);
  else if (OpNum == 4)
    FoldTable = ArrayRef(BroadcastTable4);
  else
    return nullptr;

  return lookupFoldTableImpl(FoldTable, RegOp);
}

namespace {

// This class stores the memory unfolding tables. It is instantiated as a
@@ -288,8 +305,8 @@ struct X86BroadcastFoldTable {
};
} // namespace

static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
                               unsigned BroadcastBits) {
bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
                              unsigned BroadcastBits) {
  switch (Entry.Flags & TB_BCAST_MASK) {
  case TB_BCAST_W:
  case TB_BCAST_SH:
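lookupBroadcastFoldTable mirrors the existing lookupFoldTable: it picks the per-operand broadcast table and defers to the shared lookupFoldTableImpl helper, which is not shown in this diff. A self-contained sketch of that lookup scheme, under the assumption that each table is sorted by the register-form opcode:

#include <algorithm>
#include <cstdint>

// Field names mirror X86FoldTableEntry; values are illustrative only.
struct FoldTableEntry {
  unsigned KeyOp;  // register-form opcode, the search key
  unsigned DstOp;  // memory/broadcast-form opcode to rewrite to
  uint16_t Flags;  // TB_* flags, e.g. the broadcast element width
};

// Binary search over a table sorted by KeyOp, as lookupFoldTableImpl is
// assumed to do.
static const FoldTableEntry *lookupImpl(const FoldTableEntry *First,
                                        const FoldTableEntry *Last,
                                        unsigned RegOp) {
  const FoldTableEntry *I = std::lower_bound(
      First, Last, RegOp,
      [](const FoldTableEntry &E, unsigned Op) { return E.KeyOp < Op; });
  return (I != Last && I->KeyOp == RegOp) ? I : nullptr;
}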
6 changes: 6 additions & 0 deletions llvm/lib/Target/X86/X86InstrFoldTables.h
@@ -44,6 +44,11 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
// operand OpNum.
const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);

// Look up the broadcast folding table entry for folding a broadcast with
// operand OpNum.
const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
                                                  unsigned OpNum);

// Look up the memory unfolding table entry for this instruction.
const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);

@@ -52,6 +57,7 @@ const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
                                                        unsigned BroadcastBits);

bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
} // namespace llvm

#endif
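Exporting matchBroadcastSize alongside lookupBroadcastFoldTable lets a caller that knows the register opcode, the operand index, and the broadcast element width validate a table hit before rewriting, which is how foldMemoryBroadcast below uses the pair. A hedged usage sketch; the operand index, element width, and the existence of a matching table entry are assumptions, not taken from this patch:

#include "X86InstrFoldTables.h"

using namespace llvm;

// Assumed example: ask whether RegOpc can fold a 32-bit broadcast into its
// second operand, and if so, which broadcast-memory opcode to rewrite to.
static unsigned broadcastFoldedOpcode(unsigned RegOpc) {
  if (const X86FoldTableEntry *E =
          lookupBroadcastFoldTable(RegOpc, /*OpNum=*/2))
    if (matchBroadcastSize(*E, /*BroadcastBits=*/32))
      return E->DstOp; // the *rmb-style broadcast-memory form
  return 0; // no fold available
}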
86 changes: 86 additions & 0 deletions llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
  case X86::MMX_MOVD64rm:
  case X86::MMX_MOVQ64rm:
  // AVX-512
  case X86::VPBROADCASTBZ128rm:
  case X86::VPBROADCASTBZ256rm:
  case X86::VPBROADCASTBZrm:
  case X86::VBROADCASTF32X2Z256rm:
  case X86::VBROADCASTF32X2Zrm:
  case X86::VBROADCASTI32X2Z128rm:
  case X86::VBROADCASTI32X2Z256rm:
  case X86::VBROADCASTI32X2Zrm:
  case X86::VPBROADCASTWZ128rm:
  case X86::VPBROADCASTWZ256rm:
  case X86::VPBROADCASTWZrm:
  case X86::VPBROADCASTDZ128rm:
  case X86::VPBROADCASTDZ256rm:
  case X86::VPBROADCASTDZrm:
  case X86::VBROADCASTSSZ128rm:
  case X86::VBROADCASTSSZ256rm:
  case X86::VBROADCASTSSZrm:
  case X86::VPBROADCASTQZ128rm:
  case X86::VPBROADCASTQZ256rm:
  case X86::VPBROADCASTQZrm:
  case X86::VBROADCASTSDZ256rm:
  case X86::VBROADCASTSDZrm:
  case X86::VMOVSSZrm:
  case X86::VMOVSSZrm_alt:
  case X86::VMOVSDZrm:
@@ -8067,6 +8089,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
    MOs.push_back(MachineOperand::CreateReg(0, false));
    break;
  }
  case X86::VPBROADCASTBZ128rm:
  case X86::VPBROADCASTBZ256rm:
  case X86::VPBROADCASTBZrm:
  case X86::VBROADCASTF32X2Z256rm:
  case X86::VBROADCASTF32X2Zrm:
  case X86::VBROADCASTI32X2Z128rm:
  case X86::VBROADCASTI32X2Z256rm:
  case X86::VBROADCASTI32X2Zrm:
    // No instructions currently fuse with 8-bit or 32-bit x 2 broadcasts.
    return nullptr;

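// Common tail for the broadcast cases below: append the load's address
// operands, then let foldMemoryBroadcast consult the broadcast fold tables
// with the broadcast element size in bits.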
#define FOLD_BROADCAST(SIZE)                                                   \
  MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,          \
             LoadMI.operands_begin() + NumOps);                                \
  return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE,     \
                             /*AllowCommute=*/true);
  case X86::VPBROADCASTWZ128rm:
  case X86::VPBROADCASTWZ256rm:
  case X86::VPBROADCASTWZrm:
    FOLD_BROADCAST(16);
  case X86::VPBROADCASTDZ128rm:
  case X86::VPBROADCASTDZ256rm:
  case X86::VPBROADCASTDZrm:
  case X86::VBROADCASTSSZ128rm:
  case X86::VBROADCASTSSZ256rm:
  case X86::VBROADCASTSSZrm:
    FOLD_BROADCAST(32);
  case X86::VPBROADCASTQZ128rm:
  case X86::VPBROADCASTQZ256rm:
  case X86::VPBROADCASTQZrm:
  case X86::VBROADCASTSDZ256rm:
  case X86::VBROADCASTSDZrm:
    FOLD_BROADCAST(64);
  default: {
    if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
      return nullptr;
@@ -8081,6 +8136,37 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
                               /*Size=*/0, Alignment, /*AllowCommute=*/true);
}

MachineInstr *
X86InstrInfo::foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
                                  unsigned OpNum, ArrayRef<MachineOperand> MOs,
                                  MachineBasicBlock::iterator InsertPt,
                                  unsigned BitsSize, bool AllowCommute) const {

  if (auto *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum))
    return matchBroadcastSize(*I, BitsSize)
               ? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
               : nullptr;

  if (AllowCommute) {
    // If the instruction and target operand are commutable, commute the
    // instruction and try again.
    unsigned CommuteOpIdx2 = commuteOperandsForFold(MI, OpNum);
    if (CommuteOpIdx2 == OpNum) {
      printFailMsgforFold(MI, OpNum);
      return nullptr;
    }
    MachineInstr *NewMI =
        foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs, InsertPt, BitsSize,
                            /*AllowCommute=*/false);
    if (NewMI)
      return NewMI;
    UndoCommuteForFold(MI, OpNum, CommuteOpIdx2);
  }

  printFailMsgforFold(MI, OpNum);
  return nullptr;
}

static SmallVector<MachineMemOperand *, 2>
extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
  SmallVector<MachineMemOperand *, 2> LoadMMOs;
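End to end, the new path lets a broadcast load and its single vector use be collapsed into one instruction with an embedded broadcast memory operand. A simplified sketch, not taken from the patch, of the call that reaches this code; the opcodes in the comments are only illustrative:

#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/TargetInstrInfo.h"

// Try to fold the broadcast load feeding operand FoldOpIdx of User. The
// generic entry point dispatches to X86InstrInfo::foldMemoryOperandImpl,
// which (with this patch) routes VPBROADCAST*/VBROADCAST* loads through
// foldMemoryBroadcast.
static llvm::MachineInstr *
tryFoldBroadcastLoad(llvm::MachineInstr &User, unsigned FoldOpIdx,
                     llvm::MachineInstr &BcastLoad,
                     const llvm::TargetInstrInfo &TII) {
  // Illustrative shape of the fold:
  //   %b = VPBROADCASTDZrm <addr>   (broadcast load)
  //   %r = VPADDDZrr %a, %b         (vector use)
  // becomes a single broadcast-memory instruction such as
  //   %r = VPADDDZrmb %a, <addr>    (embedded {1to16} broadcast)
  return TII.foldMemoryOperand(User, {FoldOpIdx}, BcastLoad);
}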
6 changes: 6 additions & 0 deletions llvm/lib/Target/X86/X86InstrInfo.h
@@ -643,6 +643,12 @@ class X86InstrInfo final : public X86GenInstrInfo {
                                      MachineBasicBlock::iterator InsertPt,
                                      unsigned Size, Align Alignment) const;

  MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
                                    unsigned OpNum,
                                    ArrayRef<MachineOperand> MOs,
                                    MachineBasicBlock::iterator InsertPt,
                                    unsigned BitsSize, bool AllowCommute) const;

  /// isFrameOperand - Return true and the FrameIndex if the specified
  /// operand and follow operands form a reference to the stack frame.
  bool isFrameOperand(const MachineInstr &MI, unsigned int Op,