@@ -456,14 +456,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
456
456
457
457
// / Return true if `type` is the E5M2 variant of an 8-bit float that is
458
458
// / supported by the `_bf8` instructions on the given `chipset`.
459
- static bool isNativeBf8 (Chipset chipset, Type type) {
459
+ static bool typeIsExpectedBf8ForChipset (Chipset chipset, Type type) {
460
460
return (isGfx940Series (chipset) && type.isFloat8E5M2FNUZ ()) ||
461
461
(hasOcpFp8 (chipset) && type.isFloat8E5M2 ());
462
462
}
463
463
464
464
// / Return true if `type` is the E4M3FN variant of an 8-bit float that is
465
465
// / supported by the `_fp8` instructions on the given `chipset`.
466
- static bool isNativeFp8 (Chipset chipset, Type type) {
466
+ static bool typeIsExpectedFp8ForChipset (Chipset chipset, Type type) {
467
467
return (isGfx940Series (chipset) && type.isFloat8E4M3FNUZ ()) ||
468
468
(hasOcpFp8 (chipset) && type.isFloat8E4M3FN ());
469
469
}
@@ -564,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
564
564
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
565
565
}
566
566
567
- if (destElem.isF32 () && isNativeBf8 (chipset, sourceElem)) {
567
+ if (destElem.isF32 () && typeIsExpectedBf8ForChipset (chipset, sourceElem)) {
568
568
// Known to be correct because there are no scalar f8 instructions and
569
569
// because a length mismatch will have been caught by the verifier.
570
570
Type sourceBElem =
571
571
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
572
572
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
573
- if (isNativeBf8 (chipset, sourceBElem))
573
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
574
574
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
575
- if (isNativeFp8 (chipset, sourceBElem))
575
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
576
576
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
577
577
}
578
578
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
579
- if (isNativeBf8 (chipset, sourceBElem))
579
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
580
580
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
581
- if (isNativeFp8 (chipset, sourceBElem))
581
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
582
582
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
583
583
}
584
584
}
585
585
586
- if (destElem.isF32 () && isNativeFp8 (chipset, sourceElem)) {
586
+ if (destElem.isF32 () && typeIsExpectedFp8ForChipset (chipset, sourceElem)) {
587
587
Type sourceBElem =
588
588
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
589
589
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
590
- if (isNativeBf8 (chipset, sourceBElem))
590
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
591
591
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
592
- if (isNativeFp8 (chipset, sourceBElem))
592
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
593
593
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
594
594
}
595
595
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
596
- if (isNativeBf8 (chipset, sourceBElem))
596
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
597
597
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
598
- if (isNativeFp8 (chipset, sourceBElem))
598
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
599
599
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
600
600
}
601
601
}
@@ -801,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
801
801
}
802
802
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
803
803
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
804
- if (isNativeBf8 (chipset, sourceElemType)) {
804
+ if (typeIsExpectedBf8ForChipset (chipset, sourceElemType)) {
805
805
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
806
806
wordSel);
807
- } else if (isNativeFp8 (chipset, sourceElemType)) {
807
+ } else if (typeIsExpectedFp8ForChipset (chipset, sourceElemType)) {
808
808
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
809
809
wordSel);
810
810
}
@@ -836,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
836
836
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
837
837
838
838
Value result;
839
- if (isNativeBf8 (chipset, resultElemType))
839
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
840
840
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
841
841
existing, wordSel);
842
- else if (isNativeFp8 (chipset, resultElemType))
842
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
843
843
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
844
844
existing, wordSel);
845
845
@@ -871,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
871
871
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
872
872
873
873
Value result;
874
- if (isNativeBf8 (chipset, resultElemType))
874
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
875
875
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
876
876
existing, byteSel);
877
- else if (isNativeFp8 (chipset, resultElemType))
877
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
878
878
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
879
879
existing, byteSel);
880
880
0 commit comments