@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
862
862
case X86::MMX_MOVD64rm:
863
863
case X86::MMX_MOVQ64rm:
864
864
// AVX-512
865
+ case X86::VPBROADCASTBZ128rm:
866
+ case X86::VPBROADCASTBZ256rm:
867
+ case X86::VPBROADCASTBZrm:
868
+ case X86::VBROADCASTF32X2Z256rm:
869
+ case X86::VBROADCASTF32X2Zrm:
870
+ case X86::VBROADCASTI32X2Z128rm:
871
+ case X86::VBROADCASTI32X2Z256rm:
872
+ case X86::VBROADCASTI32X2Zrm:
873
+ case X86::VPBROADCASTWZ128rm:
874
+ case X86::VPBROADCASTWZ256rm:
875
+ case X86::VPBROADCASTWZrm:
876
+ case X86::VPBROADCASTDZ128rm:
877
+ case X86::VPBROADCASTDZ256rm:
878
+ case X86::VPBROADCASTDZrm:
879
+ case X86::VBROADCASTSSZ128rm:
880
+ case X86::VBROADCASTSSZ256rm:
881
+ case X86::VBROADCASTSSZrm:
882
+ case X86::VPBROADCASTQZ128rm:
883
+ case X86::VPBROADCASTQZ256rm:
884
+ case X86::VPBROADCASTQZrm:
885
+ case X86::VBROADCASTSDZ256rm:
886
+ case X86::VBROADCASTSDZrm:
865
887
case X86::VMOVSSZrm:
866
888
case X86::VMOVSSZrm_alt:
867
889
case X86::VMOVSDZrm:
@@ -8063,6 +8085,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
8063
8085
MOs.push_back (MachineOperand::CreateReg (0 , false ));
8064
8086
break ;
8065
8087
}
8088
+ case X86::VPBROADCASTBZ128rm:
8089
+ case X86::VPBROADCASTBZ256rm:
8090
+ case X86::VPBROADCASTBZrm:
8091
+ case X86::VBROADCASTF32X2Z256rm:
8092
+ case X86::VBROADCASTF32X2Zrm:
8093
+ case X86::VBROADCASTI32X2Z128rm:
8094
+ case X86::VBROADCASTI32X2Z256rm:
8095
+ case X86::VBROADCASTI32X2Zrm:
8096
+ // No instructions currently fuse with 8bits or 32bits x 2.
8097
+ return nullptr ;
8098
+
8099
+ #define FOLD_BROADCAST (SIZE ) \
8100
+ MOs.append (LoadMI.operands_begin () + NumOps - X86::AddrNumOperands, \
8101
+ LoadMI.operands_begin () + NumOps); \
8102
+ return foldMemoryBroadcast (MF, MI, Ops[0 ], MOs, InsertPt, /* Size=*/ SIZE, \
8103
+ Alignment, /* AllowCommute=*/ true );
8104
+ case X86::VPBROADCASTWZ128rm:
8105
+ case X86::VPBROADCASTWZ256rm:
8106
+ case X86::VPBROADCASTWZrm:
8107
+ FOLD_BROADCAST (16 );
8108
+ case X86::VPBROADCASTDZ128rm:
8109
+ case X86::VPBROADCASTDZ256rm:
8110
+ case X86::VPBROADCASTDZrm:
8111
+ case X86::VBROADCASTSSZ128rm:
8112
+ case X86::VBROADCASTSSZ256rm:
8113
+ case X86::VBROADCASTSSZrm:
8114
+ FOLD_BROADCAST (32 );
8115
+ case X86::VPBROADCASTQZ128rm:
8116
+ case X86::VPBROADCASTQZ256rm:
8117
+ case X86::VPBROADCASTQZrm:
8118
+ case X86::VBROADCASTSDZ256rm:
8119
+ case X86::VBROADCASTSDZrm:
8120
+ FOLD_BROADCAST (64 );
8066
8121
default : {
8067
8122
if (isNonFoldablePartialRegisterLoad (LoadMI, MI, MF))
8068
8123
return nullptr ;
@@ -8077,6 +8132,80 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
8077
8132
/* Size=*/ 0 , Alignment, /* AllowCommute=*/ true );
8078
8133
}
8079
8134
8135
+ MachineInstr *X86InstrInfo::foldMemoryBroadcast (
8136
+ MachineFunction &MF, MachineInstr &MI, unsigned OpNum,
8137
+ ArrayRef<MachineOperand> MOs, MachineBasicBlock::iterator InsertPt,
8138
+ unsigned BitsSize, Align Alignment, bool AllowCommute) const {
8139
+
8140
+ const X86FoldTableEntry *I = lookupBroadcastFoldTable (MI.getOpcode (), OpNum);
8141
+
8142
+ if (I)
8143
+ return matchBroadcastSize (*I, BitsSize)
8144
+ ? FuseInst (MF, I->DstOp , OpNum, MOs, InsertPt, MI, *this )
8145
+ : nullptr ;
8146
+
8147
+ // TODO: Share code with foldMemoryOperandImpl for the commute
8148
+ if (AllowCommute) {
8149
+ unsigned CommuteOpIdx1 = OpNum, CommuteOpIdx2 = CommuteAnyOperandIndex;
8150
+ if (findCommutedOpIndices (MI, CommuteOpIdx1, CommuteOpIdx2)) {
8151
+ bool HasDef = MI.getDesc ().getNumDefs ();
8152
+ Register Reg0 = HasDef ? MI.getOperand (0 ).getReg () : Register ();
8153
+ Register Reg1 = MI.getOperand (CommuteOpIdx1).getReg ();
8154
+ Register Reg2 = MI.getOperand (CommuteOpIdx2).getReg ();
8155
+ bool Tied1 =
8156
+ 0 == MI.getDesc ().getOperandConstraint (CommuteOpIdx1, MCOI::TIED_TO);
8157
+ bool Tied2 =
8158
+ 0 == MI.getDesc ().getOperandConstraint (CommuteOpIdx2, MCOI::TIED_TO);
8159
+
8160
+ // If either of the commutable operands are tied to the destination
8161
+ // then we can not commute + fold.
8162
+ if ((HasDef && Reg0 == Reg1 && Tied1) ||
8163
+ (HasDef && Reg0 == Reg2 && Tied2))
8164
+ return nullptr ;
8165
+
8166
+ MachineInstr *CommutedMI =
8167
+ commuteInstruction (MI, false , CommuteOpIdx1, CommuteOpIdx2);
8168
+ if (!CommutedMI) {
8169
+ // Unable to commute.
8170
+ return nullptr ;
8171
+ }
8172
+ if (CommutedMI != &MI) {
8173
+ // New instruction. We can't fold from this.
8174
+ CommutedMI->eraseFromParent ();
8175
+ return nullptr ;
8176
+ }
8177
+
8178
+ // Attempt to fold with the commuted version of the instruction.
8179
+ MachineInstr *NewMI = foldMemoryBroadcast (MF, MI, CommuteOpIdx2, MOs,
8180
+ InsertPt, BitsSize, Alignment,
8181
+ /* AllowCommute=*/ false );
8182
+ if (NewMI)
8183
+ return NewMI;
8184
+
8185
+ // Folding failed again - undo the commute before returning.
8186
+ MachineInstr *UncommutedMI =
8187
+ commuteInstruction (MI, false , CommuteOpIdx1, CommuteOpIdx2);
8188
+ if (!UncommutedMI) {
8189
+ // Unable to commute.
8190
+ return nullptr ;
8191
+ }
8192
+ if (UncommutedMI != &MI) {
8193
+ // New instruction. It doesn't need to be kept.
8194
+ UncommutedMI->eraseFromParent ();
8195
+ return nullptr ;
8196
+ }
8197
+
8198
+ // Return here to prevent duplicate fuse failure report.
8199
+ return nullptr ;
8200
+ }
8201
+ }
8202
+
8203
+ // No fusion
8204
+ if (PrintFailedFusing && !MI.isCopy ())
8205
+ dbgs () << " We failed to fuse operand " << OpNum << " in " << MI;
8206
+ return nullptr ;
8207
+ }
8208
+
8080
8209
static SmallVector<MachineMemOperand *, 2 >
8081
8210
extractLoadMMOs (ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
8082
8211
SmallVector<MachineMemOperand *, 2 > LoadMMOs;
@@ -8130,6 +8259,18 @@ static unsigned getBroadcastOpcode(const X86FoldTableEntry *I,
8130
8259
switch (I->Flags & TB_BCAST_MASK) {
8131
8260
default :
8132
8261
llvm_unreachable (" Unexpected broadcast type!" );
8262
+ case TB_BCAST_W:
8263
+ switch (SpillSize) {
8264
+ default :
8265
+ llvm_unreachable (" Unknown spill size" );
8266
+ case 16 :
8267
+ return X86::VPBROADCASTWZ128rm;
8268
+ case 32 :
8269
+ return X86::VPBROADCASTWZ256rm;
8270
+ case 64 :
8271
+ return X86::VPBROADCASTWZrm;
8272
+ }
8273
+ break ;
8133
8274
case TB_BCAST_D:
8134
8275
switch (SpillSize) {
8135
8276
default :
@@ -8191,7 +8332,11 @@ bool X86InstrInfo::unfoldMemoryOperand(
8191
8332
unsigned Index = I->Flags & TB_INDEX_MASK;
8192
8333
bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
8193
8334
bool FoldedStore = I->Flags & TB_FOLDED_STORE;
8194
- bool FoldedBCast = I->Flags & TB_FOLDED_BCAST;
8335
+ unsigned BCastType = I->Flags & TB_FOLDED_BCAST;
8336
+ // FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
8337
+ if (BCastType == TB_BCAST_SH)
8338
+ return false ;
8339
+
8195
8340
if (UnfoldLoad && !FoldedLoad)
8196
8341
return false ;
8197
8342
UnfoldLoad &= FoldedLoad;
@@ -8231,7 +8376,7 @@ bool X86InstrInfo::unfoldMemoryOperand(
8231
8376
auto MMOs = extractLoadMMOs (MI.memoperands (), MF);
8232
8377
8233
8378
unsigned Opc;
8234
- if (FoldedBCast ) {
8379
+ if (BCastType ) {
8235
8380
Opc = getBroadcastOpcode (I, RC, Subtarget);
8236
8381
} else {
8237
8382
unsigned Alignment = std::max<uint32_t >(TRI.getSpillSize (*RC), 16 );
@@ -8341,7 +8486,10 @@ bool X86InstrInfo::unfoldMemoryOperand(
8341
8486
unsigned Index = I->Flags & TB_INDEX_MASK;
8342
8487
bool FoldedLoad = I->Flags & TB_FOLDED_LOAD;
8343
8488
bool FoldedStore = I->Flags & TB_FOLDED_STORE;
8344
- bool FoldedBCast = I->Flags & TB_FOLDED_BCAST;
8489
+ unsigned BCastType = I->Flags & TB_FOLDED_BCAST;
8490
+ // FIXME: Support TB_BCAST_SH in getBroadcastOpcode?
8491
+ if (BCastType == TB_BCAST_SH)
8492
+ return false ;
8345
8493
const MCInstrDesc &MCID = get (Opc);
8346
8494
MachineFunction &MF = DAG.getMachineFunction ();
8347
8495
const TargetRegisterInfo &TRI = *MF.getSubtarget ().getRegisterInfo ();
@@ -8377,7 +8525,7 @@ bool X86InstrInfo::unfoldMemoryOperand(
8377
8525
// memory access is slow above.
8378
8526
8379
8527
unsigned Opc;
8380
- if (FoldedBCast ) {
8528
+ if (BCastType ) {
8381
8529
Opc = getBroadcastOpcode (I, RC, Subtarget);
8382
8530
} else {
8383
8531
unsigned Alignment = std::max<uint32_t >(TRI.getSpillSize (*RC), 16 );
0 commit comments