@@ -468,6 +468,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
468
468
}
469
469
}
470
470
471
+ /// Return true if `type` is the E5M2 variant of an 8-bit float that is
472
+ /// supported by the `_bf8` instructions on the given `chipset`.
473
+ static bool isNativeBf8 (Chipset chipset, Type type) {
474
+ return (chipset.isGfx940 () && isa<Float8E5M2FNUZType>(type)) ||
475
+ (chipset.hasOcpFp8 () && isa<Float8E5M2Type>(type));
476
+ }
477
+
478
+ /// Return true if `type` is the E4M3FN variant of an 8-bit float that is
479
+ /// supported by the `_fp8` instructions on the given `chipset`.
480
+ static bool isNativeFp8 (Chipset chipset, Type type) {
481
+ return (chipset.isGfx940 () && isa<Float8E4M3FNUZType>(type)) ||
482
+ (chipset.hasOcpFp8 () && isa<Float8E4M3FNType>(type));
483
+ }
484
+
471
485
/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
472
486
/// if one exists. This includes checking to ensure the intrinsic is supported
473
487
/// on the architecture you are compiling for.
@@ -564,42 +578,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
564
578
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
565
579
}
566
580
567
- if (destElem.isF32 () &&
568
- ((isa<Float8E5M2FNUZType>(sourceElem) && chipset >= kGfx942 ) ||
569
- (isa<Float8E5M2Type>(sourceElem) && chipset.hasOcpFp8 ()))) {
581
+ if (destElem.isF32 () && isNativeBf8 (chipset, sourceElem)) {
570
582
// Known to be correct because there are no scalar f8 instructions and
571
583
// because a length mismatch will have been caught by the verifier.
572
584
Type sourceBElem =
573
585
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
574
586
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
575
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
587
+ if (isNativeBf8 (chipset, sourceBElem))
576
588
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
577
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
589
+ if (isNativeFp8 (chipset, sourceBElem))
578
590
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
579
591
}
580
592
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
581
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
593
+ if (isNativeBf8 (chipset, sourceBElem))
582
594
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
583
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
595
+ if (isNativeFp8 (chipset, sourceBElem))
584
596
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
585
597
}
586
598
}
587
599
588
- if (destElem.isF32 () &&
589
- ((isa<Float8E4M3FNUZType>(sourceElem) && chipset >= kGfx942 ) ||
590
- (isa<Float8E4M3FNType>(sourceElem) && chipset.hasOcpFp8 ()))) {
600
+ if (destElem.isF32 () && isNativeFp8 (chipset, sourceElem)) {
591
601
Type sourceBElem =
592
602
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
593
603
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
594
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
604
+ if (isNativeBf8 (chipset, sourceBElem))
595
605
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
596
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
606
+ if (isNativeFp8 (chipset, sourceBElem))
597
607
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
598
608
}
599
609
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
600
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceBElem))
610
+ if (isNativeBf8 (chipset, sourceBElem))
601
611
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
602
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceBElem))
612
+ if (isNativeFp8 (chipset, sourceBElem))
603
613
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
604
614
}
605
615
}
@@ -807,10 +817,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
807
817
}
808
818
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
809
819
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
810
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( sourceElemType)) {
820
+ if (isNativeBf8 (chipset, sourceElemType)) {
811
821
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
812
822
wordSel);
813
- } else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( sourceElemType)) {
823
+ } else if (isNativeFp8 (chipset, sourceElemType)) {
814
824
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
815
825
wordSel);
816
826
}
@@ -842,10 +852,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
842
852
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
843
853
844
854
Value result;
845
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( resultElemType))
855
+ if (isNativeBf8 (chipset, resultElemType))
846
856
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
847
857
existing, wordSel);
848
- else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( resultElemType))
858
+ else if (isNativeFp8 (chipset, resultElemType))
849
859
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
850
860
existing, wordSel);
851
861
@@ -877,10 +887,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
877
887
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
878
888
879
889
Value result;
880
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>( resultElemType))
890
+ if (isNativeBf8 (chipset, resultElemType))
881
891
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
882
892
existing, byteSel);
883
- else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>( resultElemType))
893
+ else if (isNativeFp8 (chipset, resultElemType))
884
894
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
885
895
existing, byteSel);
886
896
0 commit comments