@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
454
454
}
455
455
}
456
456
457
+ // / Return true if `type` is the E5M2 variant of an 8-bit float that is
458
+ // / supported by the `_bf8` instructions on the given `chipset`.
459
+ static bool isNativeBf8 (Chipset chipset, Type type) {
460
+ return (chipset.isGfx940 () && type.isFloat8E5M2FNUZ ()) ||
461
+ (chipset.hasOcpFp8 () && type.isFloat8E5M2 ());
462
+ }
463
+
464
+ // / Return true if `type` is the E4M3FN variant of an 8-bit float that is
465
+ // / supported by the `_fp8` instructions on the given `chipset`.
466
+ static bool isNativeFp8 (Chipset chipset, Type type) {
467
+ return (chipset.isGfx940 () && type.isFloat8E4M3FNUZ ()) ||
468
+ (chipset.hasOcpFp8 () && type.isFloat8E4M3FN ());
469
+ }
470
+
457
471
// / Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
458
472
// / if one exists. This includes checking to ensure the intrinsic is supported
459
473
// / on the architecture you are compiling for.
@@ -550,42 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
550
564
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
551
565
}
552
566
553
- if (destElem.isF32 () &&
554
- ((sourceElem.isFloat8E5M2FNUZ () && chipset >= kGfx940 ) ||
555
- (sourceElem.isFloat8E5M2 () && chipset.hasOcpFp8 ()))) {
567
+ if (destElem.isF32 () && isNativeBf8 (chipset, sourceElem)) {
556
568
// Known to be correct because there are no scalar f8 instructions and
557
569
// because a length mismatch will have been caught by the verifier.
558
570
Type sourceBElem =
559
571
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
560
572
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
561
- if (sourceBElem. isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 ( ))
573
+ if (isNativeBf8 (chipset, sourceBElem))
562
574
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
563
- if (sourceBElem. isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN ( ))
575
+ if (isNativeFp8 (chipset, sourceBElem))
564
576
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
565
577
}
566
578
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
567
- if (sourceBElem. isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 ( ))
579
+ if (isNativeBf8 (chipset, sourceBElem))
568
580
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
569
- if (sourceBElem. isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN ( ))
581
+ if (isNativeFp8 (chipset, sourceBElem))
570
582
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
571
583
}
572
584
}
573
585
574
- if (destElem.isF32 () &&
575
- ((sourceElem.isFloat8E4M3FNUZ () && chipset >= kGfx940 ) ||
576
- (sourceElem.isFloat8E4M3FN () && chipset.hasOcpFp8 ()))) {
586
+ if (destElem.isF32 () && isNativeFp8 (chipset, sourceElem)) {
577
587
Type sourceBElem =
578
588
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
579
589
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
580
- if (sourceBElem. isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 ( ))
590
+ if (isNativeBf8 (chipset, sourceBElem))
581
591
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
582
- if (sourceBElem. isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN ( ))
592
+ if (isNativeFp8 (chipset, sourceBElem))
583
593
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
584
594
}
585
595
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
586
- if (sourceBElem. isFloat8E5M2FNUZ () || sourceBElem. isFloat8E5M2 ( ))
596
+ if (isNativeBf8 (chipset, sourceBElem))
587
597
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
588
- if (sourceBElem. isFloat8E4M3FNUZ () || sourceBElem. isFloat8E4M3FN ( ))
598
+ if (isNativeFp8 (chipset, sourceBElem))
589
599
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
590
600
}
591
601
}
@@ -791,11 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
791
801
}
792
802
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
793
803
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
794
- if (sourceElemType. isFloat8E5M2FNUZ () || sourceElemType. isFloat8E5M2 ( )) {
804
+ if (isNativeBf8 (chipset, sourceElemType)) {
795
805
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
796
806
wordSel);
797
- } else if (sourceElemType.isFloat8E4M3FNUZ () ||
798
- sourceElemType.isFloat8E4M3FN ()) {
807
+ } else if (isNativeFp8 (chipset, sourceElemType)) {
799
808
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
800
809
wordSel);
801
810
}
@@ -827,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
827
836
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
828
837
829
838
Value result;
830
- if (resultElemType. isFloat8E5M2FNUZ () || resultElemType. isFloat8E5M2 ( ))
839
+ if (isNativeBf8 (chipset, resultElemType))
831
840
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
832
841
existing, wordSel);
833
- else if (resultElemType. isFloat8E4M3FNUZ () || resultElemType. isFloat8E4M3FN ( ))
842
+ else if (isNativeFp8 (chipset, resultElemType))
834
843
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
835
844
existing, wordSel);
836
845
@@ -862,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
862
871
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
863
872
864
873
Value result;
865
- if (resultElemType. isFloat8E5M2FNUZ () || resultElemType. isFloat8E5M2 ( ))
874
+ if (isNativeBf8 (chipset, resultElemType))
866
875
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
867
876
existing, byteSel);
868
- else if (resultElemType. isFloat8E4M3FNUZ () || resultElemType. isFloat8E4M3FN ( ))
877
+ else if (isNativeFp8 (chipset, resultElemType))
869
878
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
870
879
existing, byteSel);
871
880
0 commit comments