@@ -570,40 +570,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
570
570
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
571
571
}
572
572
573
- if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32 () &&
574
- chipset >= kGfx942 ) {
573
+ if (destElem.isF32 () &&
574
+ ((isa<Float8E5M2FNUZType>(sourceElem) && chipset >= kGfx942 ) ||
575
+ (isa<Float8E5M2Type>(sourceElem) && chipset.hasOcpFp8 ()))) {
575
576
// Known to be correct because there are no scalar f8 instructions and
576
577
// because a length mismatch will have been caught by the verifier.
577
578
Type sourceBElem =
578
579
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
579
580
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
580
- if (isa<Float8E5M2FNUZType>(sourceBElem))
581
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type >(sourceBElem))
581
582
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
582
- if (isa<Float8E4M3FNUZType>(sourceBElem))
583
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType >(sourceBElem))
583
584
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
584
585
}
585
586
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
586
- if (isa<Float8E5M2FNUZType>(sourceBElem))
587
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type >(sourceBElem))
587
588
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
588
- if (isa<Float8E4M3FNUZType>(sourceBElem))
589
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType >(sourceBElem))
589
590
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
590
591
}
591
592
}
592
593
593
- if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32 () &&
594
- chipset >= kGfx942 ) {
594
+ if (destElem.isF32 () &&
595
+ ((isa<Float8E4M3FNUZType>(sourceElem) && chipset >= kGfx942 ) ||
596
+ (isa<Float8E4M3FNType>(sourceElem) && chipset.hasOcpFp8 ()))) {
595
597
Type sourceBElem =
596
598
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
597
599
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
598
- if (isa<Float8E5M2FNUZType>(sourceBElem))
600
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type >(sourceBElem))
599
601
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
600
- if (isa<Float8E4M3FNUZType>(sourceBElem))
602
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType >(sourceBElem))
601
603
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
602
604
}
603
605
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
604
- if (isa<Float8E5M2FNUZType>(sourceBElem))
606
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type >(sourceBElem))
605
607
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
606
- if (isa<Float8E4M3FNUZType>(sourceBElem))
608
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType >(sourceBElem))
607
609
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
608
610
}
609
611
}
@@ -811,10 +813,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
811
813
}
812
814
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
813
815
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
814
- if (isa<Float8E5M2FNUZType>(sourceElemType)) {
816
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type >(sourceElemType)) {
815
817
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
816
818
wordSel);
817
- } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
819
+ } else if (isa<Float8E4M3FNUZType, Float8E4M3FNType >(sourceElemType)) {
818
820
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
819
821
wordSel);
820
822
}
@@ -846,10 +848,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
846
848
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
847
849
848
850
Value result;
849
- if (isa<Float8E5M2FNUZType>(resultElemType))
851
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type >(resultElemType))
850
852
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
851
853
existing, wordSel);
852
- else if (isa<Float8E4M3FNUZType>(resultElemType))
854
+ else if (isa<Float8E4M3FNUZType, Float8E4M3FNType >(resultElemType))
853
855
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
854
856
existing, wordSel);
855
857
@@ -881,10 +883,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
881
883
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
882
884
883
885
Value result;
884
- if (isa<Float8E5M2FNUZType>(resultElemType))
886
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type >(resultElemType))
885
887
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
886
888
existing, byteSel);
887
- else if (isa<Float8E4M3FNUZType>(resultElemType))
889
+ else if (isa<Float8E4M3FNUZType, Float8E4M3FNType >(resultElemType))
888
890
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
889
891
existing, byteSel);
890
892
0 commit comments