@@ -64,6 +64,7 @@ def Write8PassMAI : SchedWrite;
64
64
def Write16PassMAI : SchedWrite;
65
65
def Write4PassDGEMM : SchedWrite;
66
66
def Write8PassDGEMM : SchedWrite;
67
+ def Write16PassDGEMM : SchedWrite; // Assigned to V_MFMA_.64_16X16X in the SIDPGFX950FullSpeedModel block.
67
68
68
69
// Scalar float instructions
69
70
def WriteSFPU : SchedWrite;
@@ -94,6 +95,7 @@ def SIFullSpeedModel : SISchedMachineModel;
94
95
def SIQuarterSpeedModel : SISchedMachineModel;
95
96
def SIDPFullSpeedModel : SISchedMachineModel;
96
97
def SIDPGFX940FullSpeedModel : SISchedMachineModel;
98
+ def SIDPGFX950FullSpeedModel : SISchedMachineModel; // Populated by the let SchedModel = SIDPGFX950FullSpeedModel block.
97
99
def GFX10SpeedModel : SISchedMachineModel;
98
100
def GFX11SpeedModel : SISchedMachineModel;
99
101
def GFX12SpeedModel : SISchedMachineModel;
@@ -169,6 +171,8 @@ multiclass SICommonWriteRes {
169
171
def : HWVALUWriteRes<Write4PassDGEMM, 4>;
170
172
let ReleaseAtCycles = [8] in
171
173
def : HWVALUWriteRes<Write8PassDGEMM, 8>;
174
+ let ReleaseAtCycles = [16] in
175
+ def : HWVALUWriteRes<Write16PassDGEMM, 16>; // Mirrors the 8-pass entry: ReleaseAtCycles equals the second template argument.
172
176
173
177
let ReleaseAtCycles = [2] in
174
178
def : HWWriteRes<Write2PassMAI, [HWXDL], 2>; // Unlike the DGEMM entries, names its resource list ([HWXDL]) explicitly.
@@ -201,6 +205,13 @@ def WriteCopy : SchedWriteVariant<[
201
205
SchedVar<PredIsVGPR64Copy, [Write64Bit]>,
202
206
SchedVar<NoSchedPred, [WriteSALU]>]>;
203
207
208
+ // True if either matrix input of an f8f6f4 MFMA is interpreted as f8, i.e.
209
+ // the cbsz or blgp format operand is at most MFMAScaleFormats::FP8_E5M2.
210
+ def PredIsF8_MFMA_SCALE : SchedPredicate<[{
211
+ TII->getNamedOperand(*MI, AMDGPU::OpName::cbsz)->getImm() <= AMDGPU::MFMAScaleFormats::FP8_E5M2 ||
212
+ TII->getNamedOperand(*MI, AMDGPU::OpName::blgp)->getImm() <= AMDGPU::MFMAScaleFormats::FP8_E5M2
213
+ }]>;
214
+
204
215
let SchedModel = SIFullSpeedModel in {
205
216
206
217
defm : SICommonWriteRes;
@@ -299,6 +310,58 @@ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_SMFMAC_.32_32X32X")>;
299
310
300
311
} // End SchedModel = SIDPGFX940FullSpeedModel
301
312
313
+
314
+ let SchedModel = SIDPGFX950FullSpeedModel in { // Every record in this block gets SchedModel = SIDPGFX950FullSpeedModel.
315
+ defm : SICommonWriteRes; // Shared write resources, including the Write*PassDGEMM and Write*PassMAI entries.
316
+
317
+ def : HWVALUWriteRes<WriteFloatFMA, 1>; // NOTE(review): second argument is presumably the latency in cycles - confirm against HWVALUWriteRes.
318
+ def : HWVALUWriteRes<WriteDouble, 1>;
319
+ def : HWVALUWriteRes<WriteDoubleAdd, 1>;
320
+ def : HWVALUWriteRes<WriteDoubleCvt, 1>;
321
+ def : HWVALUWriteRes<WriteTrans64, 4>;
322
+ def : HWVALUWriteRes<WriteIntMul, 1>;
323
+ def : HWVALUWriteRes<Write64Bit, 1>;
324
+
325
+ def : InstRW<[WriteCopy], (instrs COPY)>;
326
+ def : InstRW<[Write64Bit], (instregex "^V_ACCVGPR_WRITE_B32_e64$")>;
327
+ def : InstRW<[Write2PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_4X4X")>;
328
+
329
+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X8X")>;
330
+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X16")>;
331
+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X32")>;
332
+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X64")>;
333
+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_16X16X[14][FBI]")>; // K is the single digit 1 or 4 followed by F, B or I, so 16X16X16 is not matched here.
334
+
335
+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X4XF")>;
336
+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X8")>;
337
+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X16")>;
338
+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X32_")>; // The trailing '_' requires an underscore directly after the K value 32.
339
+ def : InstRW<[Write16PassMAI, MIMFMARead], (instregex "^V_MFMA_.32_32X32X[124][FBI]")>; // K is the single digit 1, 2 or 4 followed by F, B or I.
340
+
341
+ def : InstRW<[Write4PassDGEMM, MIMFMARead], (instregex "^V_MFMA_.64_4X4X")>;
342
+ def : InstRW<[Write16PassDGEMM, MIMFMARead], (instregex "^V_MFMA_.64_16X16X")>;
343
+
344
+ def : InstRW<[Write4PassMAI, MIMFMARead], (instregex "^V_SMFMAC_.32_16X16X")>;
345
+ def : InstRW<[Write8PassMAI, MIMFMARead], (instregex "^V_SMFMAC_.32_32X32X")>;
346
+
347
+
348
+ // If either matrix format is f8 (PredIsF8_MFMA_SCALE), the instruction takes
349
+ // 2x as many cycles (8/16 passes vs. 4/8). TODO: This isn't reflected in MCA.
350
+ def WriteMFMAScale_16X16X128_F8F6F4 : SchedWriteVariant<[
351
+ SchedVar<PredIsF8_MFMA_SCALE, [Write8PassMAI]>,
352
+ SchedVar<NoSchedPred, [Write4PassMAI]>]>;
353
+ def WriteMFMAScale_32X32X64_F8F6F4 : SchedWriteVariant<[
354
+ SchedVar<PredIsF8_MFMA_SCALE, [Write16PassMAI]>,
355
+ SchedVar<NoSchedPred, [Write8PassMAI]>]>;
356
+
357
+ def : InstRW<[WriteMFMAScale_16X16X128_F8F6F4, MIMFMARead],
358
+ (instregex "^V_MFMA(_SCALE)?_.32_16X16X128_F8F6F4")>; // (_SCALE)? matches both the V_MFMA_ and V_MFMA_SCALE_ opcode names.
359
+ def : InstRW<[WriteMFMAScale_32X32X64_F8F6F4, MIMFMARead],
360
+ (instregex "^V_MFMA(_SCALE)?_.32_32X32X64_F8F6F4")>; // (_SCALE)? matches both the V_MFMA_ and V_MFMA_SCALE_ opcode names.
361
+
362
+ } // End SchedModel = SIDPGFX950FullSpeedModel
363
+
364
+
302
365
let SchedModel = GFX10SpeedModel in {
303
366
304
367
// The latency values are 1 / (operations / cycle).
0 commit comments