@@ -476,14 +476,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
476
476
477
477
// / Return true if `type` is the E5M2 variant of an 8-bit float that is
478
478
// / supported by the `_bf8` instructions on the given `chipset`.
479
- static bool isNativeBf8 (Chipset chipset, Type type) {
479
+ static bool typeIsExpectedBf8ForChipset (Chipset chipset, Type type) {
480
480
return (isGfx940Series (chipset) && isa<Float8E5M2FNUZType>(type)) ||
481
481
(hasOcpFp8 (chipset) && isa<Float8E5M2Type>(type));
482
482
}
483
483
484
484
// / Return true if `type` is the E4M3FN variant of an 8-bit float that is
485
485
// / supported by the `_fp8` instructions on the given `chipset`.
486
- static bool isNativeFp8 (Chipset chipset, Type type) {
486
+ static bool typeIsExpectedFp8ForChipset (Chipset chipset, Type type) {
487
487
return (isGfx940Series (chipset) && isa<Float8E4M3FNUZType>(type)) ||
488
488
(hasOcpFp8 (chipset) && isa<Float8E4M3FNType>(type));
489
489
}
@@ -584,38 +584,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
584
584
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
585
585
}
586
586
587
- if (destElem.isF32 () && isNativeBf8 (chipset, sourceElem)) {
587
+ if (destElem.isF32 () && typeIsExpectedBf8ForChipset (chipset, sourceElem)) {
588
588
// Known to be correct because there are no scalar f8 instructions and
589
589
// because a length mismatch will have been caught by the verifier.
590
590
Type sourceBElem =
591
591
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
592
592
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
593
- if (isNativeBf8 (chipset, sourceBElem))
593
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
594
594
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
595
- if (isNativeFp8 (chipset, sourceBElem))
595
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
596
596
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
597
597
}
598
598
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
599
- if (isNativeBf8 (chipset, sourceBElem))
599
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
600
600
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
601
- if (isNativeFp8 (chipset, sourceBElem))
601
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
602
602
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
603
603
}
604
604
}
605
605
606
- if (destElem.isF32 () && isNativeFp8 (chipset, sourceElem)) {
606
+ if (destElem.isF32 () && typeIsExpectedFp8ForChipset (chipset, sourceElem)) {
607
607
Type sourceBElem =
608
608
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
609
609
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
610
- if (isNativeBf8 (chipset, sourceBElem))
610
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
611
611
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
612
- if (isNativeFp8 (chipset, sourceBElem))
612
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
613
613
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
614
614
}
615
615
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
616
- if (isNativeBf8 (chipset, sourceBElem))
616
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
617
617
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
618
- if (isNativeFp8 (chipset, sourceBElem))
618
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
619
619
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
620
620
}
621
621
}
@@ -823,10 +823,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
823
823
}
824
824
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
825
825
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
826
- if (isNativeBf8 (chipset, sourceElemType)) {
826
+ if (typeIsExpectedBf8ForChipset (chipset, sourceElemType)) {
827
827
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
828
828
wordSel);
829
- } else if (isNativeFp8 (chipset, sourceElemType)) {
829
+ } else if (typeIsExpectedFp8ForChipset (chipset, sourceElemType)) {
830
830
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
831
831
wordSel);
832
832
}
@@ -858,10 +858,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
858
858
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
859
859
860
860
Value result;
861
- if (isNativeBf8 (chipset, resultElemType))
861
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
862
862
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
863
863
existing, wordSel);
864
- else if (isNativeFp8 (chipset, resultElemType))
864
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
865
865
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
866
866
existing, wordSel);
867
867
@@ -893,10 +893,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
893
893
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
894
894
895
895
Value result;
896
- if (isNativeBf8 (chipset, resultElemType))
896
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
897
897
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
898
898
existing, byteSel);
899
- else if (isNativeFp8 (chipset, resultElemType))
899
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
900
900
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
901
901
existing, byteSel);
902
902
0 commit comments