@@ -474,6 +474,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
474
474
}
475
475
}
476
476
477
+ // / Return true if `type` is the E5M2 variant of an 8-bit float that is
478
+ // / supported by the `_bf8` instructions on the given `chipset`.
479
+ static bool isNativeBf8 (Chipset chipset, Type type) {
480
+ return (chipset.isGfx940 () && isa<Float8E5M2FNUZType>(type)) ||
481
+ (chipset.hasOcpFp8 () && isa<Float8E5M2Type>(type));
482
+ }
483
+
484
+ // / Return true if `type` is the E4M3FN variant of an 8-bit float that is
485
+ // / supported by the `_fp8` instructions on the given `chipset`.
486
+ static bool isNativeFp8 (Chipset chipset, Type type) {
487
+ return (chipset.isGfx940 () && isa<Float8E4M3FNUZType>(type)) ||
488
+ (chipset.hasOcpFp8 () && isa<Float8E4M3FNType>(type));
489
+ }
490
+
477
491
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
478
492
/// if one exists. This includes checking to ensure the intrinsic is supported
479
493
/// on the architecture you are compiling for.
@@ -570,42 +584,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
570
584
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
571
585
}
572
586
573
- if (destElem.isF32 () &&
574
- ((isa<Float8E5M2FNUZType>(sourceElem) && chipset >= kGfx942 ) ||
575
- (isa<Float8E5M2Type>(sourceElem) && chipset.hasOcpFp8 ()))) {
587
+ if (destElem.isF32 () && isNativeBf8 (chipset, sourceElem)) {
576
588
// Known to be correct because there are no scalar f8 instructions and
577
589
// because a length mismatch will have been caught by the verifier.
578
590
Type sourceBElem =
579
591
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
580
592
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
581
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
593
+ if (isNativeBf8 (chipset, sourceBElem))
582
594
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
583
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
595
+ if (isNativeFp8 (chipset, sourceBElem))
584
596
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
585
597
}
586
598
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
587
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
599
+ if (isNativeBf8 (chipset, sourceBElem))
588
600
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
589
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
601
+ if (isNativeFp8 (chipset, sourceBElem))
590
602
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
591
603
}
592
604
}
593
605
594
- if (destElem.isF32 () &&
595
- ((isa<Float8E4M3FNUZType>(sourceElem) && chipset >= kGfx942 ) ||
596
- (isa<Float8E4M3FNType>(sourceElem) && chipset.hasOcpFp8 ()))) {
606
+ if (destElem.isF32 () && isNativeFp8 (chipset, sourceElem)) {
597
607
Type sourceBElem =
598
608
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
599
609
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
600
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
610
+ if (isNativeBf8 (chipset, sourceBElem))
601
611
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
602
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
612
+ if (isNativeFp8 (chipset, sourceBElem))
603
613
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
604
614
}
605
615
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
606
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
616
+ if (isNativeBf8 (chipset, sourceBElem))
607
617
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
608
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
618
+ if (isNativeFp8 (chipset, sourceBElem))
609
619
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
610
620
}
611
621
}
@@ -813,10 +823,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
813
823
}
814
824
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
815
825
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
816
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceElemType)) {
826
+ if (isNativeBf8 (chipset, sourceElemType)) {
817
827
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
818
828
wordSel);
819
- } else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceElemType)) {
829
+ } else if (isNativeFp8 (chipset, sourceElemType)) {
820
830
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
821
831
wordSel);
822
832
}
@@ -848,10 +858,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
848
858
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
849
859
850
860
Value result;
851
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( resultElemType))
861
+ if (isNativeBf8 (chipset, resultElemType))
852
862
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
853
863
existing, wordSel);
854
- else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( resultElemType))
864
+ else if (isNativeFp8 (chipset, resultElemType))
855
865
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
856
866
existing, wordSel);
857
867
@@ -883,10 +893,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
883
893
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
884
894
885
895
Value result;
886
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( resultElemType))
896
+ if (isNativeBf8 (chipset, resultElemType))
887
897
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
888
898
existing, byteSel);
889
- else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( resultElemType))
899
+ else if (isNativeFp8 (chipset, resultElemType))
890
900
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
891
901
existing, byteSel);
892
902
0 commit comments