@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
   case X86::MMX_MOVD64rm:
   case X86::MMX_MOVQ64rm:
   // AVX-512
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
   case X86::VMOVSSZrm:
   case X86::VMOVSSZrm_alt:
   case X86::VMOVSDZrm:
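
For context: isReallyTriviallyReMaterializable tells the register allocator which definitions it may re-execute at a use instead of spilling. The hunk above opts the AVX-512 broadcast loads into that set; a load qualifies only when re-running it is guaranteed to produce the same value, a condition the function checks after this switch (outside the hunk). A minimal standalone sketch of that policy, with hypothetical names rather than LLVM's:

struct LoadDesc {
  bool AddressIsConstant; // no live register inputs feed the address
  bool MemoryIsInvariant; // e.g. a constant-pool entry
  bool HasSideEffects;    // volatile or otherwise unsafe to replay
};

// Re-executing the load beats a spill/reload pair only if it is a pure
// read of memory that cannot change between the def and the use.
static bool isTriviallyRematerializable(const LoadDesc &L) {
  return L.AddressIsConstant && L.MemoryIsInvariant && !L.HasSideEffects;
}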
@@ -8063,6 +8085,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
     MOs.push_back(MachineOperand::CreateReg(0, false));
     break;
   }
+  case X86::VPBROADCASTBZ128rm:
+  case X86::VPBROADCASTBZ256rm:
+  case X86::VPBROADCASTBZrm:
+  case X86::VBROADCASTF32X2Z256rm:
+  case X86::VBROADCASTF32X2Zrm:
+  case X86::VBROADCASTI32X2Z128rm:
+  case X86::VBROADCASTI32X2Z256rm:
+  case X86::VBROADCASTI32X2Zrm:
+    // No instructions currently fuse with 8bits or 32bits x 2.
+    return nullptr;
+
+#define FOLD_BROADCAST(SIZE)                                                   \
+  MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,          \
+             LoadMI.operands_begin() + NumOps);                                \
+  return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE,     \
+                             Alignment, /*AllowCommute=*/true);
+  case X86::VPBROADCASTWZ128rm:
+  case X86::VPBROADCASTWZ256rm:
+  case X86::VPBROADCASTWZrm:
+    FOLD_BROADCAST(16);
+  case X86::VPBROADCASTDZ128rm:
+  case X86::VPBROADCASTDZ256rm:
+  case X86::VPBROADCASTDZrm:
+  case X86::VBROADCASTSSZ128rm:
+  case X86::VBROADCASTSSZ256rm:
+  case X86::VBROADCASTSSZrm:
+    FOLD_BROADCAST(32);
+  case X86::VPBROADCASTQZ128rm:
+  case X86::VPBROADCASTQZ256rm:
+  case X86::VPBROADCASTQZrm:
+  case X86::VBROADCASTSDZ256rm:
+  case X86::VBROADCASTSDZrm:
+    FOLD_BROADCAST(64);
   default: {
     if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
       return nullptr;
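
A note on the FOLD_BROADCAST macro above: its body both appends the load's address operands to MOs and returns from foldMemoryOperandImpl itself, so it cannot be a helper function (a callee cannot return on its caller's behalf); the macro keeps each element-width group down to one line. A standalone sketch of the same control-flow shape, with all names hypothetical:

#include <cstdio>

static bool tryFold(unsigned BitsSize) {
  std::printf("trying %u-bit broadcast fold\n", BitsSize);
  return BitsSize >= 32; // stand-in for a fold-table lookup
}

// Like FOLD_BROADCAST: the trailing `return` exits the function that
// contains the switch, which only a macro can express for its user.
#define TRY_FOLD(SIZE)                                                         \
  std::printf("appending address operands\n");                                 \
  return tryFold(SIZE);

static bool dispatch(unsigned WidthClass) {
  switch (WidthClass) {
  case 16:
    TRY_FOLD(16);
  case 32:
    TRY_FOLD(32);
  default:
    return false;
  }
}

int main() { return dispatch(32) ? 0 : 1; }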
@@ -8077,6 +8132,78 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
                                /*Size=*/0, Alignment, /*AllowCommute=*/true);
 }
 
+MachineInstr *X86InstrInfo::foldMemoryBroadcast(
+    MachineFunction &MF, MachineInstr &MI, unsigned OpNum,
+    ArrayRef<MachineOperand> MOs, MachineBasicBlock::iterator InsertPt,
+    unsigned BitsSize, Align Alignment, bool AllowCommute) const {
+
+  if (auto *I = lookupBroadcastFoldTable(MI.getOpcode(), OpNum))
+    return matchBroadcastSize(*I, BitsSize)
+               ? FuseInst(MF, I->DstOp, OpNum, MOs, InsertPt, MI, *this)
+               : nullptr;
+
+  // TODO: Share code with foldMemoryOperandImpl for the commute
+  if (AllowCommute) {
+    unsigned CommuteOpIdx1 = OpNum, CommuteOpIdx2 = CommuteAnyOperandIndex;
+    if (findCommutedOpIndices(MI, CommuteOpIdx1, CommuteOpIdx2)) {
+      bool HasDef = MI.getDesc().getNumDefs();
+      Register Reg0 = HasDef ? MI.getOperand(0).getReg() : Register();
+      Register Reg1 = MI.getOperand(CommuteOpIdx1).getReg();
+      Register Reg2 = MI.getOperand(CommuteOpIdx2).getReg();
+      bool Tied1 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx1, MCOI::TIED_TO);
+      bool Tied2 =
+          0 == MI.getDesc().getOperandConstraint(CommuteOpIdx2, MCOI::TIED_TO);
+
+      // If either of the commutable operands are tied to the destination
+      // then we can not commute + fold.
+      if ((HasDef && Reg0 == Reg1 && Tied1) ||
+          (HasDef && Reg0 == Reg2 && Tied2))
+        return nullptr;
+
+      MachineInstr *CommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!CommutedMI) {
+        // Unable to commute.
+        return nullptr;
+      }
+      if (CommutedMI != &MI) {
+        // New instruction. We can't fold from this.
+        CommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Attempt to fold with the commuted version of the instruction.
+      MachineInstr *NewMI = foldMemoryBroadcast(MF, MI, CommuteOpIdx2, MOs,
+                                                InsertPt, BitsSize, Alignment,
+                                                /*AllowCommute=*/false);
+      if (NewMI)
+        return NewMI;
+
+      // Folding failed again - undo the commute before returning.
+      MachineInstr *UncommutedMI =
+          commuteInstruction(MI, false, CommuteOpIdx1, CommuteOpIdx2);
+      if (!UncommutedMI) {
+        // Unable to commute.
+        return nullptr;
+      }
+      if (UncommutedMI != &MI) {
+        // New instruction. It doesn't need to be kept.
+        UncommutedMI->eraseFromParent();
+        return nullptr;
+      }
+
+      // Return here to prevent duplicate fuse failure report.
+      return nullptr;
+    }
+  }
+
+  // No fusion
+  if (PrintFailedFusing && !MI.isCopy())
+    dbgs() << "We failed to fuse operand " << OpNum << " in " << MI;
+  return nullptr;
+}
+
 static SmallVector<MachineMemOperand *, 2>
 extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
   SmallVector<MachineMemOperand *, 2> LoadMMOs;
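
The commute path in foldMemoryBroadcast deliberately mirrors foldMemoryOperandImpl (hence the TODO about sharing code): when the broadcast fold table has no entry for the requested operand, it commutes the two commutable operands in place, retries exactly once with AllowCommute=false so the recursion terminates, and undoes the commute if that retry also fails, leaving MI unchanged. A minimal standalone model of that control flow, with hypothetical types rather than LLVM's:

#include <optional>
#include <utility>

struct Instr {
  int Ops[2]; // two commutable source operands
};

// Stand-in for the fold-table lookup: pretend only operand 1 can fold.
static std::optional<Instr> tryFold(const Instr &I, unsigned OpNum) {
  if (OpNum == 1)
    return I; // the "folded" instruction
  return std::nullopt;
}

static std::optional<Instr> foldWithCommute(Instr &I, unsigned OpNum,
                                            bool AllowCommute) {
  if (auto Folded = tryFold(I, OpNum))
    return Folded;
  if (!AllowCommute)
    return std::nullopt; // bounds the recursion at depth two

  std::swap(I.Ops[0], I.Ops[1]); // commute in place
  if (auto Folded = foldWithCommute(I, 1 - OpNum, /*AllowCommute=*/false))
    return Folded;
  std::swap(I.Ops[0], I.Ops[1]); // fold failed: restore operand order
  return std::nullopt;
}

One difference worth noting: the real code must additionally reject the commute when a commutable operand is tied to the destination register, since commuting would change which value the tied def reuses.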