
Commit bc071ff

[X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl

1 parent 02a275c
8 files changed: +1761 -1616 lines

llvm/lib/Target/X86/X86InstrAVX512.td (1 addition & 1 deletion)

@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
                               MaskInfo.RC:$src0))],
                     DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;
 
-  let hasSideEffects = 0, mayLoad = 1 in
+  let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
   def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
                     (ins SrcInfo.ScalarMemOp:$src),
                     !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
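
Note on the flags added above: mayLoad and hasSideEffects were already set; the new isReMaterializable and canFoldAsLoad bits are what let the generic rematerialization and load-folding machinery (and the X86 hooks changed below) treat these broadcast-from-memory instructions as plain re-executable loads. A minimal sketch of how such flags surface on the MachineInstr side; the three queries are existing MachineInstr APIs, while the helper itself is hypothetical and not part of this patch:

// Sketch only (not from this patch): observing the instruction flags that the
// .td change above sets on the broadcast load.
#include "llvm/CodeGen/MachineInstr.h"

static bool isFoldableBroadcastLoad(const llvm::MachineInstr &LoadMI) {
  return LoadMI.canFoldAsLoad() &&          // canFoldAsLoad = 1
         LoadMI.mayLoad() &&                // mayLoad = 1
         !LoadMI.hasUnmodeledSideEffects(); // hasSideEffects = 0
}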

llvm/lib/Target/X86/X86InstrFoldTables.cpp (19 additions & 2 deletions)

@@ -143,6 +143,23 @@ const X86FoldTableEntry *llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
   return lookupFoldTableImpl(FoldTable, RegOp);
 }
 
+const X86FoldTableEntry *
+llvm::lookupBroadcastFoldTable(unsigned RegOp, unsigned OpNum) {
+  ArrayRef<X86FoldTableEntry> FoldTable;
+  if (OpNum == 1)
+    FoldTable = ArrayRef(BroadcastTable1);
+  else if (OpNum == 2)
+    FoldTable = ArrayRef(BroadcastTable2);
+  else if (OpNum == 3)
+    FoldTable = ArrayRef(BroadcastTable3);
+  else if (OpNum == 4)
+    FoldTable = ArrayRef(BroadcastTable4);
+  else
+    return nullptr;
+
+  return lookupFoldTableImpl(FoldTable, RegOp);
+}
+
 namespace {
 
 // This class stores the memory unfolding tables. It is instantiated as a
@@ -288,8 +305,8 @@ struct X86BroadcastFoldTable {
 };
 } // namespace
 
-static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
-                               unsigned BroadcastBits) {
+bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
+                              unsigned BroadcastBits) {
   switch (Entry.Flags & TB_BCAST_MASK) {
   case TB_BCAST_W:
   case TB_BCAST_SH:
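
The two symbols exported here are meant to be used together; a small usage sketch (it mirrors the call made in X86InstrInfo::foldMemoryBroadcast further down, with Opcode, OpNum and BroadcastBits assumed to be supplied by the caller):

// Sketch only: find the broadcast fold table entry for operand OpNum of the
// register-form instruction, then check that the entry's TB_BCAST_* kind
// matches the broadcast element width (16, 32 or 64 bits).
if (const X86FoldTableEntry *Entry = lookupBroadcastFoldTable(Opcode, OpNum))
  if (matchBroadcastSize(*Entry, BroadcastBits)) {
    // Safe to rewrite the instruction to Entry->DstOp, replacing the register
    // operand with the broadcast ({1toN}) memory operand.
  }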

llvm/lib/Target/X86/X86InstrFoldTables.h (6 additions & 0 deletions)

@@ -44,6 +44,11 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
 // operand OpNum.
 const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);
 
+// Look up the broadcast folding table entry for folding a broadcast with
+// operand OpNum.
+const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
+                                                  unsigned OpNum);
+
 // Look up the memory unfolding table entry for this instruction.
 const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
 
@@ -52,6 +57,7 @@ const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
 const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
                                                         unsigned BroadcastBits);
 
+bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
 } // namespace llvm
 
 #endif

llvm/lib/Target/X86/X86InstrInfo.cpp (127 additions & 0 deletions)

@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
   case X86::MMX_MOVD64rm:
   case X86::MMX_MOVQ64rm:
   // AVX-512
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
   case X86::VMOVSSZrm:
   case X86::VMOVSSZrm_alt:
   case X86::VMOVSDZrm:
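
Listing the AVX-512 broadcast loads in this switch makes them eligible for rematerialization, so the register allocator can re-issue the broadcast next to a use instead of spilling its vector result. A minimal illustration of the query side; TII and BroadcastMI are illustrative names assumed from a caller's context, not from this patch:

// Sketch only: the allocator asks the target whether a def is trivially
// rematerializable; the switch above now answers "yes" for these broadcasts.
if (TII->isTriviallyReMaterializable(BroadcastMI)) {
  // Re-emitting the broadcast load at the use is cheaper than a spill and
  // reload of a full vector register.
}
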
@@ -8063,6 +8085,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
     MOs.push_back(MachineOperand::CreateReg(0, false));
     break;
   }
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+    // No instructions currently fuse with 8bits or 32bits x 2.
+    return nullptr;
+
+#define FOLD_BROADCAST(SIZE)                                                   \
+  MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,         \
+             LoadMI.operands_begin() + NumOps);                               \
+  return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE,    \
+                             Alignment, /*AllowCommute=*/true);
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+    FOLD_BROADCAST(16);
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+    FOLD_BROADCAST(32);
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
+    FOLD_BROADCAST(64);
   default: {
     if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
       return nullptr;
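
For reference, the FOLD_BROADCAST macro above only does two things: it appends the broadcast load's address operands to MOs and forwards the broadcast element width. In the VBROADCASTSS/VPBROADCASTD cases, FOLD_BROADCAST(32) expands roughly to:

// Approximate expansion of FOLD_BROADCAST(32) inside foldMemoryOperandImpl.
MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,
           LoadMI.operands_begin() + NumOps);
return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/32,
                           Alignment, /*AllowCommute=*/true);
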
@@ -8077,6 +8132,78 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
                                /*Size=*/0, Alignment, /*AllowCommute=*/true);
 }
 
+MachineInstr *X86InstrInfo::foldMemoryBroadcast(
+    MachineFunction &MF, MachineInstr &MI, unsigned OpNum,
+    ArrayRef<MachineOperand> MOs, MachineBasicBlock::iterator InsertPt,
+    unsigned BitsSize, Align Alignment, bool AllowCommute) const {
+
+  if (auto *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum))
+    return matchBroadcastSize(*I, BitsSize)
+               ? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
+               : nullptr;
+
+  // TODO: Share code with foldMemoryOperandImpl for the commute
+  if (AllowCommute) {
+    unsigned CommuteOpIdx1 = OpNum, CommuteOpIdx2 = CommuteAnyOperandIndex;
+    if (findCommutedOpIndices(MI, CommuteOpIdx1, CommuteOpIdx2)) {
+      bool HasDef = MI.getDesc().getNumDefs();
+      Register Reg0 = HasDef ? MI.getOperand(0).getReg() : Register();
+      Register Reg1 = MI.getOperand(CommuteOpIdx1).getReg();
+      Register Reg2 = MI.getOperand(CommuteOpIdx2).getReg();
+      bool Tied1 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx1, MCOI::TIED_TO);
+      bool Tied2 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx2, MCOI::TIED_TO);
+
+      // If either of the commutable operands are tied to the destination
+      // then we can not commute + fold.
+      if ((HasDef && Reg0 == Reg1 && Tied1) ||
+          (HasDef && Reg0 == Reg2 && Tied2))
+        return nullptr;
+
+      MachineInstr *CommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!CommutedMI) {
+        // Unable to commute.
+        return nullptr;
+      }
+      if (CommutedMI != &MI) {
+        // New instruction. We can't fold from this.
+        CommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Attempt to fold with the commuted version of the instruction.
+      MachineInstr *NewMI = foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs,
+                                                InsertPt, BitsSize, Alignment,
+                                                /*AllowCommute=*/false);
+      if (NewMI)
+        return NewMI;
+
+      // Folding failed again - undo the commute before returning.
+      MachineInstr *UncommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!UncommutedMI) {
+        // Unable to commute.
+        return nullptr;
+      }
+      if (UncommutedMI != &MI) {
+        // New instruction. It doesn't need to be kept.
+        UncommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Return here to prevent duplicate fuse failure report.
+      return nullptr;
+    }
+  }
+
+  // No fusion
+  if (PrintFailedFusing && !MI.isCopy())
+    dbgs() << "We failed to fuse operand " << OpNum << " in " << MI;
+  return nullptr;
+}
+
 static SmallVector<MachineMemOperand *, 2>
 extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
   SmallVector<MachineMemOperand *, 2> LoadMMOs;
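
foldMemoryBroadcast is only reached through the generic folding entry point. If the operand the broadcast feeds has no broadcast-table entry, the AllowCommute path above swaps the commutable operands, retries once with AllowCommute=false, and restores the original operand order if that also fails. A hedged sketch of a typical call site; UseMI, UseOpIdx, LoadMI, TII and LIS are illustrative names, not from this patch:

// Sketch only: a folding pass asks TargetInstrInfo to fold the broadcast load
// LoadMI into operand UseOpIdx of UseMI.  On x86 this dispatches through
// foldMemoryOperandImpl and, for broadcast loads, foldMemoryBroadcast.
if (MachineInstr *Folded =
        TII->foldMemoryOperand(UseMI, /*Ops=*/{UseOpIdx}, LoadMI, LIS)) {
  // A new instruction was inserted that reads the broadcast directly from
  // memory (an EVEX embedded-broadcast operand); the caller then erases UseMI
  // and, if it became dead, LoadMI.
}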

llvm/lib/Target/X86/X86InstrInfo.h (7 additions & 0 deletions)

@@ -643,6 +643,13 @@ class X86InstrInfo final : public X86GenInstrInfo {
                                       MachineBasicBlock::iterator InsertPt,
                                       unsigned Size, Align Alignment) const;
 
+  MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
+                                    unsigned OpNum,
+                                    ArrayRef<MachineOperand> MOs,
+                                    MachineBasicBlock::iterator InsertPt,
+                                    unsigned BitsSize, Align Alignment,
+                                    bool AllowCommute) const;
+
   /// isFrameOperand - Return true and the FrameIndex if the specified
   /// operand and follow operands form a reference to the stack frame.
   bool isFrameOperand(const MachineInstr &MI, unsigned int Op,
