@@ -470,14 +470,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
470
470
471
471
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
472
472
/// supported by the `_bf8` instructions on the given `chipset`.
/// Two chipset/type pairings are accepted: a chipset for which
/// `isGfx940Series` holds together with `Float8E5M2FNUZType`, or a chipset
/// for which `hasOcpFp8` holds together with `Float8E5M2Type`. Any other
/// combination yields false.
473
- static bool isNativeBf8 (Chipset chipset, Type type) {
473
+ static bool typeIsExpectedBf8ForChipset (Chipset chipset, Type type) {
474
474
return (isGfx940Series (chipset) && isa<Float8E5M2FNUZType>(type)) ||
475
475
(hasOcpFp8 (chipset) && isa<Float8E5M2Type>(type));
476
476
}
477
477
478
478
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
479
479
/// supported by the `_fp8` instructions on the given `chipset`.
/// Two chipset/type pairings are accepted: a chipset for which
/// `isGfx940Series` holds together with `Float8E4M3FNUZType`, or a chipset
/// for which `hasOcpFp8` holds together with `Float8E4M3FNType`. Any other
/// combination yields false.
480
- static bool isNativeFp8 (Chipset chipset, Type type) {
480
+ static bool typeIsExpectedFp8ForChipset (Chipset chipset, Type type) {
481
481
return (isGfx940Series (chipset) && isa<Float8E4M3FNUZType>(type)) ||
482
482
(hasOcpFp8 (chipset) && isa<Float8E4M3FNType>(type));
483
483
}
@@ -578,38 +578,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
578
578
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
579
579
}
580
580
581
- if (destElem.isF32 () && isNativeBf8 (chipset, sourceElem)) {
581
+ if (destElem.isF32 () && typeIsExpectedBf8ForChipset (chipset, sourceElem)) {
582
582
// Known to be correct because there are no scalar f8 instructions and
583
583
// because a length mismatch will have been caught by the verifier.
584
584
Type sourceBElem =
585
585
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
586
586
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
587
- if (isNativeBf8 (chipset, sourceBElem))
587
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
588
588
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName ();
589
- if (isNativeFp8 (chipset, sourceBElem))
589
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
590
590
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName ();
591
591
}
592
592
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
593
- if (isNativeBf8 (chipset, sourceBElem))
593
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
594
594
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName ();
595
- if (isNativeFp8 (chipset, sourceBElem))
595
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
596
596
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName ();
597
597
}
598
598
}
599
599
600
- if (destElem.isF32 () && isNativeFp8 (chipset, sourceElem)) {
600
+ if (destElem.isF32 () && typeIsExpectedFp8ForChipset (chipset, sourceElem)) {
601
601
Type sourceBElem =
602
602
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
603
603
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
604
- if (isNativeBf8 (chipset, sourceBElem))
604
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
605
605
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName ();
606
- if (isNativeFp8 (chipset, sourceBElem))
606
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
607
607
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName ();
608
608
}
609
609
if (m == 32 && n == 32 && k == 16 && b == 1 ) {
610
- if (isNativeBf8 (chipset, sourceBElem))
610
+ if (typeIsExpectedBf8ForChipset (chipset, sourceBElem))
611
611
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName ();
612
- if (isNativeFp8 (chipset, sourceBElem))
612
+ if (typeIsExpectedFp8ForChipset (chipset, sourceBElem))
613
613
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName ();
614
614
}
615
615
}
@@ -817,10 +817,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
817
817
}
818
818
Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
819
819
Value wordSel = createI32Constant (rewriter, loc, op.getIndex ());
820
- if (isNativeBf8 (chipset, sourceElemType)) {
820
+ if (typeIsExpectedBf8ForChipset (chipset, sourceElemType)) {
821
821
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Bf8Op>(op, f32 , i32Source,
822
822
wordSel);
823
- } else if (isNativeFp8 (chipset, sourceElemType)) {
823
+ } else if (typeIsExpectedFp8ForChipset (chipset, sourceElemType)) {
824
824
rewriter.replaceOpWithNewOp <ROCDL::CvtF32Fp8Op>(op, f32 , i32Source,
825
825
wordSel);
826
826
}
@@ -852,10 +852,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
852
852
Value wordSel = createI1Constant (rewriter, loc, op.getWordIndex ());
853
853
854
854
Value result;
855
- if (isNativeBf8 (chipset, resultElemType))
855
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
856
856
result = rewriter.create <ROCDL::CvtPkBf8F32Op>(loc, i32 , sourceA, sourceB,
857
857
existing, wordSel);
858
- else if (isNativeFp8 (chipset, resultElemType))
858
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
859
859
result = rewriter.create <ROCDL::CvtPkFp8F32Op>(loc, i32 , sourceA, sourceB,
860
860
existing, wordSel);
861
861
@@ -887,10 +887,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
887
887
Value byteSel = createI32Constant (rewriter, loc, op.getStoreIndex ());
888
888
889
889
Value result;
890
- if (isNativeBf8 (chipset, resultElemType))
890
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
891
891
result = rewriter.create <ROCDL::CvtSrBf8F32Op>(loc, i32 , source, stoch,
892
892
existing, byteSel);
893
- else if (isNativeFp8 (chipset, resultElemType))
893
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
894
894
result = rewriter.create <ROCDL::CvtSrFp8F32Op>(loc, i32 , source, stoch,
895
895
existing, byteSel);
896
896
0 commit comments