Skip to content

Commit 062fa4d

Browse files
committed
[X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl
1 parent 6d7c8a6 commit 062fa4d

9 files changed

+1790
-1624
lines changed

llvm/lib/Target/X86/X86FixupVectorConstants.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,14 +406,14 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
406406
unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
407407
if (OpSrc32) {
408408
if (const X86FoldTableEntry *Mem2Bcst =
409-
llvm::lookupBroadcastFoldTable(OpSrc32, 32)) {
409+
llvm::lookupBroadcastFoldTableBySize(OpSrc32, 32)) {
410410
OpBcst32 = Mem2Bcst->DstOp;
411411
OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
412412
}
413413
}
414414
if (OpSrc64) {
415415
if (const X86FoldTableEntry *Mem2Bcst =
416-
llvm::lookupBroadcastFoldTable(OpSrc64, 64)) {
416+
llvm::lookupBroadcastFoldTableBySize(OpSrc64, 64)) {
417417
OpBcst64 = Mem2Bcst->DstOp;
418418
OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
419419
}

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
10671067
MaskInfo.RC:$src0))],
10681068
DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;
10691069

1070-
let hasSideEffects = 0, mayLoad = 1 in
1070+
let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
10711071
def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
10721072
(ins SrcInfo.ScalarMemOp:$src),
10731073
!strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),

llvm/lib/Target/X86/X86InstrFoldTables.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,23 @@ llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
145145
return lookupFoldTableImpl(FoldTable, RegOp);
146146
}
147147

148+
// Look up the broadcast-fold table entry for register-form opcode RegOp when
// the operand being folded is operand number OpNum.
//
// OpNum selects one of BroadcastTable1..BroadcastTable4; any other operand
// index has no table here and yields nullptr. Otherwise the result is whatever
// lookupFoldTableImpl returns for RegOp in the selected table (the caller in
// X86InstrInfo::foldMemoryBroadcast treats a null result as "no entry").
const X86FoldTableEntry *
149+
llvm::lookupBroadcastFoldTable(unsigned RegOp, unsigned OpNum) {
150+
ArrayRef<X86FoldTableEntry> FoldTable;
151+
// Pick the per-operand table; only operand indices 1 through 4 are covered.
if (OpNum == 1)
152+
FoldTable = ArrayRef(BroadcastTable1);
153+
else if (OpNum == 2)
154+
FoldTable = ArrayRef(BroadcastTable2);
155+
else if (OpNum == 3)
156+
FoldTable = ArrayRef(BroadcastTable3);
157+
else if (OpNum == 4)
158+
FoldTable = ArrayRef(BroadcastTable4);
159+
else
160+
return nullptr;
161+
162+
// Same final lookup step that lookupFoldTable uses, keyed on RegOp.
return lookupFoldTableImpl(FoldTable, RegOp);
163+
}
164+
148165
namespace {
149166

150167
// This class stores the memory unfolding tables. It is instantiated as a
@@ -288,8 +305,8 @@ struct X86BroadcastFoldTable {
288305
};
289306
} // namespace
290307

291-
static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
292-
unsigned BroadcastBits) {
308+
bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
309+
unsigned BroadcastBits) {
293310
switch (Entry.Flags & TB_BCAST_MASK) {
294311
case TB_BCAST_W:
295312
case TB_BCAST_SH:
@@ -305,7 +322,7 @@ static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
305322
}
306323

307324
const X86FoldTableEntry *
308-
llvm::lookupBroadcastFoldTable(unsigned MemOp, unsigned BroadcastBits) {
325+
llvm::lookupBroadcastFoldTableBySize(unsigned MemOp, unsigned BroadcastBits) {
309326
static X86BroadcastFoldTable BroadcastFoldTable;
310327
auto &Table = BroadcastFoldTable.Table;
311328
for (auto I = llvm::lower_bound(Table, MemOp);

llvm/lib/Target/X86/X86InstrFoldTables.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,20 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
4444
// operand OpNum.
4545
const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);
4646

47+
// Look up the broadcast folding table entry for folding a broadcast with
48+
// operand OpNum.
49+
const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
50+
unsigned OpNum);
51+
4752
// Look up the memory unfolding table entry for this instruction.
4853
const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
4954

5055
// Look up the broadcast folding table entry for this instruction from
5156
// the regular memory instruction.
52-
const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned MemOp,
57+
const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
5358
unsigned BroadcastBits);
5459

60+
// Check a fold table entry's broadcast type (Entry.Flags & TB_BCAST_MASK)
// against BroadcastBits.
// NOTE(review): presumably returns true when the element width implied by the
// TB_BCAST_* type equals BroadcastBits -- confirm against the definition in
// X86InstrFoldTables.cpp.
bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
5561
} // namespace llvm
5662

5763
#endif

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
862862
case X86::MMX_MOVD64rm:
863863
case X86::MMX_MOVQ64rm:
864864
// AVX-512
865+
case X86::VPBROADCASTBZ128rm:
866+
case X86::VPBROADCASTBZ256rm:
867+
case X86::VPBROADCASTBZrm:
868+
case X86::VBROADCASTF32X2Z256rm:
869+
case X86::VBROADCASTF32X2Zrm:
870+
case X86::VBROADCASTI32X2Z128rm:
871+
case X86::VBROADCASTI32X2Z256rm:
872+
case X86::VBROADCASTI32X2Zrm:
873+
case X86::VPBROADCASTWZ128rm:
874+
case X86::VPBROADCASTWZ256rm:
875+
case X86::VPBROADCASTWZrm:
876+
case X86::VPBROADCASTDZ128rm:
877+
case X86::VPBROADCASTDZ256rm:
878+
case X86::VPBROADCASTDZrm:
879+
case X86::VBROADCASTSSZ128rm:
880+
case X86::VBROADCASTSSZ256rm:
881+
case X86::VBROADCASTSSZrm:
882+
case X86::VPBROADCASTQZ128rm:
883+
case X86::VPBROADCASTQZ256rm:
884+
case X86::VPBROADCASTQZrm:
885+
case X86::VBROADCASTSDZ256rm:
886+
case X86::VBROADCASTSDZrm:
865887
case X86::VMOVSSZrm:
866888
case X86::VMOVSSZrm_alt:
867889
case X86::VMOVSDZrm:
@@ -8063,6 +8085,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
80638085
MOs.push_back(MachineOperand::CreateReg(0, false));
80648086
break;
80658087
}
8088+
case X86::VPBROADCASTBZ128rm:
8089+
case X86::VPBROADCASTBZ256rm:
8090+
case X86::VPBROADCASTBZrm:
8091+
case X86::VBROADCASTF32X2Z256rm:
8092+
case X86::VBROADCASTF32X2Zrm:
8093+
case X86::VBROADCASTI32X2Z128rm:
8094+
case X86::VBROADCASTI32X2Z256rm:
8095+
case X86::VBROADCASTI32X2Zrm:
8096+
// No instructions currently fuse with 8bits or 32bits x 2.
8097+
return nullptr;
8098+
8099+
#define FOLD_BROADCAST(SIZE) \
8100+
MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands, \
8101+
LoadMI.operands_begin() + NumOps); \
8102+
return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE, \
8103+
Alignment, /*AllowCommute=*/true);
8104+
case X86::VPBROADCASTWZ128rm:
8105+
case X86::VPBROADCASTWZ256rm:
8106+
case X86::VPBROADCASTWZrm:
8107+
FOLD_BROADCAST(16);
8108+
case X86::VPBROADCASTDZ128rm:
8109+
case X86::VPBROADCASTDZ256rm:
8110+
case X86::VPBROADCASTDZrm:
8111+
case X86::VBROADCASTSSZ128rm:
8112+
case X86::VBROADCASTSSZ256rm:
8113+
case X86::VBROADCASTSSZrm:
8114+
FOLD_BROADCAST(32);
8115+
case X86::VPBROADCASTQZ128rm:
8116+
case X86::VPBROADCASTQZ256rm:
8117+
case X86::VPBROADCASTQZrm:
8118+
case X86::VBROADCASTSDZ256rm:
8119+
case X86::VBROADCASTSDZrm:
8120+
FOLD_BROADCAST(64);
80668121
default: {
80678122
if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
80688123
return nullptr;
@@ -8077,6 +8132,80 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
80778132
/*Size=*/0, Alignment, /*AllowCommute=*/true);
80788133
}
80798134

8135+
// Try to fold a broadcast load into operand OpNum of MI, using the broadcast
// fold tables.
//
// \param MOs          operands passed through unchanged to FuseInst.
// \param InsertPt     passed through unchanged to FuseInst.
// \param BitsSize     broadcast width checked against the table entry via
//                     matchBroadcastSize; the caller passes 16, 32 or 64.
// \param Alignment    not read here; only forwarded to the recursive call.
// \param AllowCommute when true and MI has no table entry for OpNum, commute
//                     MI once and retry on the other commuted operand.
// \returns the instruction produced by FuseInst, or nullptr if nothing was
//          folded.
MachineInstr *X86InstrInfo::foldMemoryBroadcast(
8136+
MachineFunction &MF, MachineInstr &MI, unsigned OpNum,
8137+
ArrayRef<MachineOperand> MOs, MachineBasicBlock::iterator InsertPt,
8138+
unsigned BitsSize, Align Alignment, bool AllowCommute) const {
8139+
8140+
const X86FoldTableEntry *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum);
8141+
8142+
// A table entry exists for this opcode/operand: fuse only if its broadcast
// type matches BitsSize. On a size mismatch we return immediately -- the
// commute path below is not attempted.
if (I)
8143+
return matchBroadcastSize(*I, BitsSize)
8144+
? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
8145+
: nullptr;
8146+
8147+
// TODO: Share code with foldMemoryOperandImpl for the commute
8148+
if (AllowCommute) {
8149+
// Idx1 is fixed to the operand we want to fold; Idx2 starts as
// CommuteAnyOperandIndex (presumably so findCommutedOpIndices chooses the
// partner operand -- confirm against its documentation).
unsigned CommuteOpIdx1 = OpNum, CommuteOpIdx2 = CommuteAnyOperandIndex;
8150+
if (findCommutedOpIndices(MI, CommuteOpIdx1, CommuteOpIdx2)) {
8151+
bool HasDef = MI.getDesc().getNumDefs();
8152+
Register Reg0 = HasDef ? MI.getOperand(0).getReg() : Register();
8153+
Register Reg1 = MI.getOperand(CommuteOpIdx1).getReg();
8154+
Register Reg2 = MI.getOperand(CommuteOpIdx2).getReg();
8155+
// A TIED_TO constraint value of 0 means the operand is tied to operand 0,
// which is the destination when the instruction has a def.
bool Tied1 =
8156+
0 == MI.getDesc().getOperandConstraint(CommuteOpIdx1, MCOI::TIED_TO);
8157+
bool Tied2 =
8158+
0 == MI.getDesc().getOperandConstraint(CommuteOpIdx2, MCOI::TIED_TO);
8159+
8160+
// If either of the commutable operands are tied to the destination
8161+
// then we can not commute + fold.
8162+
if ((HasDef && Reg0 == Reg1 && Tied1) ||
8163+
(HasDef && Reg0 == Reg2 && Tied2))
8164+
return nullptr;
8165+
8166+
// Commute with NewMI == false; a result other than &MI is treated below
// as an unexpected new instruction and discarded.
MachineInstr *CommutedMI =
8167+
commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
8168+
if (!CommutedMI) {
8169+
// Unable to commute.
8170+
return nullptr;
8171+
}
8172+
if (CommutedMI != &MI) {
8173+
// New instruction. We can't fold from this.
8174+
CommutedMI->eraseFromParent();
8175+
return nullptr;
8176+
}
8177+
8178+
// Attempt to fold with the commuted version of the instruction.
// AllowCommute is false on this call, so the recursion is one level deep.
8179+
MachineInstr *NewMI = foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs,
8180+
InsertPt, BitsSize, Alignment,
8181+
/*AllowCommute=*/false);
8182+
if (NewMI)
8183+
return NewMI;
8184+
8185+
// Folding failed again - undo the commute before returning.
8186+
MachineInstr *UncommutedMI =
8187+
commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
8188+
if (!UncommutedMI) {
8189+
// Unable to commute.
8190+
return nullptr;
8191+
}
8192+
if (UncommutedMI != &MI) {
8193+
// New instruction. It doesn't need to be kept.
8194+
UncommutedMI->eraseFromParent();
8195+
return nullptr;
8196+
}
8197+
8198+
// Return here to prevent duplicate fuse failure report.
8199+
return nullptr;
8200+
}
8201+
}
8202+
8203+
// No fusion
// Reached only when there is no table entry and either commuting was not
// allowed or findCommutedOpIndices found nothing to commute.
8204+
if (PrintFailedFusing && !MI.isCopy())
8205+
dbgs() << "We failed to fuse operand " << OpNum << " in " << MI;
8206+
return nullptr;
8207+
}
8208+
80808209
static SmallVector<MachineMemOperand *, 2>
80818210
extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
80828211
SmallVector<MachineMemOperand *, 2> LoadMMOs;
@@ -8130,6 +8259,18 @@ static unsigned getBroadcastOpcode(const X86FoldTableEntry *I,
81308259
switch (I->Flags & TB_BCAST_MASK) {
81318260
default:
81328261
llvm_unreachable("Unexpected broadcast type!");
8262+
case TB_BCAST_W:
8263+
switch (SpillSize) {
8264+
default:
8265+
llvm_unreachable("Unknown spill size");
8266+
case 16:
8267+
return X86::VPBROADCASTWZ128rm;
8268+
case 32:
8269+
return X86::VPBROADCASTWZ256rm;
8270+
case 64:
8271+
return X86::VPBROADCASTWZrm;
8272+
}
8273+
break;
81338274
case TB_BCAST_D:
81348275
switch (SpillSize) {
81358276
default:
@@ -8191,7 +8332,11 @@ bool X86InstrInfo::unfoldMemoryOperand(
81918332
unsigned Index = I->Flags & TB_INDEX_MASK;
81928333
bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
81938334
bool FoldedStore = I->Flags & TB_FOLDED_STORE;
8194-
bool FoldedBCast = I->Flags & TB_FOLDED_BCAST;
8335+
unsigned BCastType = I->Flags & TB_FOLDED_BCAST;
8336+
// FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
8337+
if (BCastType == TB_BCAST_SH)
8338+
return false;
8339+
81958340
if (UnfoldLoad && !FoldedLoad)
81968341
return false;
81978342
UnfoldLoad &= FoldedLoad;
@@ -8231,7 +8376,7 @@ bool X86InstrInfo::unfoldMemoryOperand(
82318376
auto MMOs = extractLoadMMOs(MI.memoperands(), MF);
82328377

82338378
unsigned Opc;
8234-
if (FoldedBCast) {
8379+
if (BCastType) {
82358380
Opc = getBroadcastOpcode(I, RC, Subtarget);
82368381
} else {
82378382
unsigned Alignment = std::max<uint32_t>(TRI.getSpillSize(*RC), 16);
@@ -8341,7 +8486,10 @@ bool X86InstrInfo::unfoldMemoryOperand(
83418486
unsigned Index = I->Flags & TB_INDEX_MASK;
83428487
bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
83438488
bool FoldedStore = I->Flags & TB_FOLDED_STORE;
8344-
bool FoldedBCast = I->Flags & TB_FOLDED_BCAST;
8489+
unsigned BCastType = I->Flags & TB_FOLDED_BCAST;
8490+
// FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
8491+
if (BCastType == TB_BCAST_SH)
8492+
return false;
83458493
const MCInstrDesc &MCID = get(Opc);
83468494
MachineFunction &MF = DAG.getMachineFunction();
83478495
const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
@@ -8377,7 +8525,7 @@ bool X86InstrInfo::unfoldMemoryOperand(
83778525
// memory access is slow above.
83788526

83798527
unsigned Opc;
8380-
if (FoldedBCast) {
8528+
if (BCastType) {
83818529
Opc = getBroadcastOpcode(I, RC, Subtarget);
83828530
} else {
83838531
unsigned Alignment = std::max<uint32_t>(TRI.getSpillSize(*RC), 16);

llvm/lib/Target/X86/X86InstrInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,13 @@ class X86InstrInfo final : public X86GenInstrInfo {
643643
MachineBasicBlock::iterator InsertPt,
644644
unsigned Size, Align Alignment) const;
645645

646+
/// Try to fold a broadcast load into operand \p OpNum of \p MI using the
/// broadcast fold tables. \p BitsSize is the broadcast width that the table
/// entry must match; \p MOs and \p InsertPt are handed to the fusing step.
/// When \p AllowCommute is true and no table entry exists for \p OpNum, the
/// implementation commutes \p MI once and retries on the other operand.
/// Returns the fused instruction, or nullptr if no fold was performed.
MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
647+
unsigned OpNum,
648+
ArrayRef<MachineOperand> MOs,
649+
MachineBasicBlock::iterator InsertPt,
650+
unsigned BitsSize, Align Alignment,
651+
bool AllowCommute) const;
652+
646653
/// isFrameOperand - Return true and the FrameIndex if the specified
647654
/// operand and follow operands form a reference to the stack frame.
648655
bool isFrameOperand(const MachineInstr &MI, unsigned int Op,

0 commit comments

Comments
 (0)