Commit e770fe0

[X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl
1 parent 2960656 commit e770fe0

8 files changed: +1761 −1616 lines
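
In effect, this change lets the x86 backend fold a standalone AVX-512 broadcast load into its consumer as an embedded-broadcast memory operand, so a pair such as vpbroadcastd (%rdi), %zmm1 feeding vpaddd %zmm1, %zmm0, %zmm0 can be rewritten to vpaddd (%rdi){1to16}, %zmm0, %zmm0. The sketch below is a hypothetical, simplified restatement of the decision the patch adds; the helper name broadcastFoldTarget and its standalone shape are illustrative, while the opcode-to-width mapping and the table query mirror the code in foldMemoryOperandImpl and foldMemoryBroadcast further down.

// Hypothetical sketch, not code from the patch: given the opcode of a
// broadcast load and the consumer's opcode plus the operand it feeds, report
// the embedded-broadcast opcode the consumer could be rewritten to, or 0.
// Assumes the X86 target headers (X86InstrInfo.h, X86InstrFoldTables.h) are
// available; the widths mirror the switch added to foldMemoryOperandImpl.
static unsigned broadcastFoldTarget(unsigned LoadOpc, unsigned UserOpc,
                                    unsigned OpNum) {
  unsigned Bits = 0;
  switch (LoadOpc) {
  case llvm::X86::VPBROADCASTWZ128rm:
  case llvm::X86::VPBROADCASTWZ256rm:
  case llvm::X86::VPBROADCASTWZrm:
    Bits = 16;
    break;
  case llvm::X86::VPBROADCASTDZ128rm:
  case llvm::X86::VPBROADCASTDZ256rm:
  case llvm::X86::VPBROADCASTDZrm:
  case llvm::X86::VBROADCASTSSZ128rm:
  case llvm::X86::VBROADCASTSSZ256rm:
  case llvm::X86::VBROADCASTSSZrm:
    Bits = 32;
    break;
  case llvm::X86::VPBROADCASTQZ128rm:
  case llvm::X86::VPBROADCASTQZ256rm:
  case llvm::X86::VPBROADCASTQZrm:
  case llvm::X86::VBROADCASTSDZ256rm:
  case llvm::X86::VBROADCASTSDZrm:
    Bits = 64;
    break;
  default:
    return 0; // 8-bit and 32x2 broadcasts have no embedded-broadcast consumers.
  }
  // The broadcast fold tables map a consumer's register form to its
  // broadcast-memory form (e.g. an *rmb opcode) when the element widths agree.
  if (const llvm::X86FoldTableEntry *E =
          llvm::lookupBroadcastFoldTable(UserOpc, OpNum))
    if (llvm::matchBroadcastSize(*E, Bits))
      return E->DstOp;
  return 0;
}

For example, broadcastFoldTarget(llvm::X86::VPBROADCASTDZrm, llvm::X86::VPADDDZrr, 2) would be expected to name the broadcast-memory form of VPADDD, assuming the current table contents.
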

llvm/lib/Target/X86/X86InstrAVX512.td
Lines changed: 1 addition & 1 deletion

@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
                             MaskInfo.RC:$src0))],
                   DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;
 
-  let hasSideEffects = 0, mayLoad = 1 in
+  let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
   def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
                     (ins SrcInfo.ScalarMemOp:$src),
                     !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
llvm/lib/Target/X86/X86InstrFoldTables.cpp
Lines changed: 19 additions & 2 deletions

@@ -143,6 +143,23 @@ const X86FoldTableEntry *llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
   return lookupFoldTableImpl(FoldTable, RegOp);
 }
 
+const X86FoldTableEntry *
+llvm::lookupBroadcastFoldTable(unsigned RegOp, unsigned OpNum) {
+  ArrayRef<X86FoldTableEntry> FoldTable;
+  if (OpNum == 1)
+    FoldTable = ArrayRef(BroadcastTable1);
+  else if (OpNum == 2)
+    FoldTable = ArrayRef(BroadcastTable2);
+  else if (OpNum == 3)
+    FoldTable = ArrayRef(BroadcastTable3);
+  else if (OpNum == 4)
+    FoldTable = ArrayRef(BroadcastTable4);
+  else
+    return nullptr;
+
+  return lookupFoldTableImpl(FoldTable, RegOp);
+}
+
 namespace {
 
 // This class stores the memory unfolding tables. It is instantiated as a

@@ -288,8 +305,8 @@ struct X86BroadcastFoldTable {
 };
 } // namespace
 
-static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
-                               unsigned BroadcastBits) {
+bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
+                              unsigned BroadcastBits) {
   switch (Entry.Flags & TB_BCAST_MASK) {
   case TB_BCAST_W:
   case TB_BCAST_SH:
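
matchBroadcastSize was previously a file-local helper; exporting it lets X86InstrInfo.cpp check that a table entry's TB_BCAST_* flag agrees with the element width of the broadcast load being folded (W/SH correspond to 16 bits, D/SS to 32, Q/SD to 64). A minimal, hypothetical check, assuming a hand-built entry rather than a real table row:

#include <cassert>

// Illustrative only: an entry tagged TB_BCAST_D describes a 32-bit element
// broadcast, so it accepts a VPBROADCASTD-style load but not a 16-bit one.
static void matchBroadcastSizeExample() {
  llvm::X86FoldTableEntry E{};   // hypothetical, hand-built entry
  E.Flags = llvm::TB_BCAST_D;
  assert(llvm::matchBroadcastSize(E, /*BroadcastBits=*/32));
  assert(!llvm::matchBroadcastSize(E, /*BroadcastBits=*/16));
}
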

llvm/lib/Target/X86/X86InstrFoldTables.h
Lines changed: 6 additions & 0 deletions

@@ -44,6 +44,11 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
 // operand OpNum.
 const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);
 
+// Look up the broadcast folding table entry for folding a broadcast with
+// operand OpNum.
+const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
+                                                  unsigned OpNum);
+
 // Look up the memory unfolding table entry for this instruction.
 const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
 
@@ -52,6 +57,7 @@ const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
 const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
                                                         unsigned BroadcastBits);
 
+bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
 } // namespace llvm
 
 #endif

llvm/lib/Target/X86/X86InstrInfo.cpp
Lines changed: 127 additions & 0 deletions

@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
   case X86::MMX_MOVD64rm:
   case X86::MMX_MOVQ64rm:
   // AVX-512
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
   case X86::VMOVSSZrm:
   case X86::VMOVSSZrm_alt:
   case X86::VMOVSDZrm:
@@ -8067,6 +8089,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
     MOs.push_back(MachineOperand::CreateReg(0, false));
     break;
   }
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+    // No instructions currently fuse with 8bits or 32bits x 2.
+    return nullptr;
+
+#define FOLD_BROADCAST(SIZE)                                                   \
+  MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,          \
+             LoadMI.operands_begin() + NumOps);                                \
+  return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE,     \
+                             Alignment, /*AllowCommute=*/true);
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+    FOLD_BROADCAST(16);
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+    FOLD_BROADCAST(32);
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
+    FOLD_BROADCAST(64);
   default: {
     if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
       return nullptr;
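
Each FOLD_BROADCAST case packs two steps into one macro: copy the broadcast load's five x86 address operands into MOs, then hand off to foldMemoryBroadcast with the element width implied by the opcode. Expanded by hand for one 32-bit case (purely a restatement of the macro above, not extra code in the patch), the dispatch reads:

  case X86::VPBROADCASTDZrm:  // likewise for the other 32-bit broadcast loads
    // Append LoadMI's address operands (base, scale, index, disp, segment).
    MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,
               LoadMI.operands_begin() + NumOps);
    // Rewrite MI's operand Ops[0] to read the scalar via an embedded broadcast.
    return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/32,
                               Alignment, /*AllowCommute=*/true);
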
@@ -8081,6 +8136,78 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
                                 /*Size=*/0, Alignment, /*AllowCommute=*/true);
 }
 
+MachineInstr *X86InstrInfo::foldMemoryBroadcast(
+    MachineFunction &MF, MachineInstr &MI, unsigned OpNum,
+    ArrayRef<MachineOperand> MOs, MachineBasicBlock::iterator InsertPt,
+    unsigned BitsSize, Align Alignment, bool AllowCommute) const {
+
+  if (auto *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum))
+    return matchBroadcastSize(*I, BitsSize)
+               ? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
+               : nullptr;
+
+  // TODO: Share code with foldMemoryOperandImpl for the commute
+  if (AllowCommute) {
+    unsigned CommuteOpIdx1 = OpNum, CommuteOpIdx2 = CommuteAnyOperandIndex;
+    if (findCommutedOpIndices(MI, CommuteOpIdx1, CommuteOpIdx2)) {
+      bool HasDef = MI.getDesc().getNumDefs();
+      Register Reg0 = HasDef ? MI.getOperand(0).getReg() : Register();
+      Register Reg1 = MI.getOperand(CommuteOpIdx1).getReg();
+      Register Reg2 = MI.getOperand(CommuteOpIdx2).getReg();
+      bool Tied1 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx1, MCOI::TIED_TO);
+      bool Tied2 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx2, MCOI::TIED_TO);
+
+      // If either of the commutable operands are tied to the destination
+      // then we can not commute + fold.
+      if ((HasDef && Reg0 == Reg1 && Tied1) ||
+          (HasDef && Reg0 == Reg2 && Tied2))
+        return nullptr;
+
+      MachineInstr *CommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!CommutedMI) {
+        // Unable to commute.
+        return nullptr;
+      }
+      if (CommutedMI != &MI) {
+        // New instruction. We can't fold from this.
+        CommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Attempt to fold with the commuted version of the instruction.
+      MachineInstr *NewMI = foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs,
+                                                InsertPt, BitsSize, Alignment,
+                                                /*AllowCommute=*/false);
+      if (NewMI)
+        return NewMI;
+
+      // Folding failed again - undo the commute before returning.
+      MachineInstr *UncommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!UncommutedMI) {
+        // Unable to commute.
+        return nullptr;
+      }
+      if (UncommutedMI != &MI) {
+        // New instruction. It doesn't need to be kept.
+        UncommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Return here to prevent duplicate fuse failure report.
+      return nullptr;
+    }
+  }
+
+  // No fusion
+  if (PrintFailedFusing && !MI.isCopy())
+    dbgs() << "We failed to fuse operand " << OpNum << " in " << MI;
+  return nullptr;
+}
+
 static SmallVector<MachineMemOperand *, 2>
 extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
   SmallVector<MachineMemOperand *, 2> LoadMMOs;
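
foldMemoryBroadcast tries a direct table hit first; when the table has no entry for folding at OpNum and commuting is allowed, it swaps the two commutable sources, retries the fold at the other index, and undoes the commute if that also fails. The fallback matters because the broadcast tables describe the position the memory operand occupies in the encoded form, typically the last source. A hypothetical probe of the tables for a commutative consumer illustrates this; the "likely"/"expected" outcomes are assumptions about today's table contents, not something the patch guarantees:

// Hypothetical probe (illustrative assumptions about table contents): for a
// commutative consumer such as VPADDDZrr, the broadcast-memory form keeps the
// memory operand in the second source slot, so a broadcast feeding the first
// source is only foldable after commuting the two sources.
static void commuteFallbackExample() {
  const llvm::X86FoldTableEntry *Direct =
      llvm::lookupBroadcastFoldTable(llvm::X86::VPADDDZrr, /*OpNum=*/1);
  const llvm::X86FoldTableEntry *Swapped =
      llvm::lookupBroadcastFoldTable(llvm::X86::VPADDDZrr, /*OpNum=*/2);
  (void)Direct;  // likely null: no operand-1 broadcast fold is listed
  (void)Swapped; // expected non-null: the operand-2 broadcast fold exists
}
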

llvm/lib/Target/X86/X86InstrInfo.h
Lines changed: 7 additions & 0 deletions

@@ -643,6 +643,13 @@ class X86InstrInfo final : public X86GenInstrInfo {
                                       MachineBasicBlock::iterator InsertPt,
                                       unsigned Size, Align Alignment) const;
 
+  MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
+                                    unsigned OpNum,
+                                    ArrayRef<MachineOperand> MOs,
+                                    MachineBasicBlock::iterator InsertPt,
+                                    unsigned BitsSize, Align Alignment,
+                                    bool AllowCommute) const;
+
   /// isFrameOperand - Return true and the FrameIndex if the specified
   /// operand and follow operands form a reference to the stack frame.
   bool isFrameOperand(const MachineInstr &MI, unsigned int Op,
