Skip to content

Commit a62c37c

Browse files
committed
[MLIR][AMDGPU] Renaming using suggestions from review.
1 parent de5a263 commit a62c37c

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -456,14 +456,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
456456

457457
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
458458
/// supported by the `_bf8` instructions on the given `chipset`.
459-
static bool isNativeBf8(Chipset chipset, Type type) {
459+
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
460460
return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) ||
461461
(hasOcpFp8(chipset) && type.isFloat8E5M2());
462462
}
463463

464464
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
465465
/// supported by the `_fp8` instructions on the given `chipset`.
466-
static bool isNativeFp8(Chipset chipset, Type type) {
466+
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
467467
return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) ||
468468
(hasOcpFp8(chipset) && type.isFloat8E4M3FN());
469469
}
@@ -564,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
564564
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
565565
}
566566

567-
if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
567+
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
568568
// Known to be correct because there are no scalar f8 instructions and
569569
// because a length mismatch will have been caught by the verifier.
570570
Type sourceBElem =
571571
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
572572
if (m == 16 && n == 16 && k == 32 && b == 1) {
573-
if (isNativeBf8(chipset, sourceBElem))
573+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
574574
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
575-
if (isNativeFp8(chipset, sourceBElem))
575+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
576576
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
577577
}
578578
if (m == 32 && n == 32 && k == 16 && b == 1) {
579-
if (isNativeBf8(chipset, sourceBElem))
579+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
580580
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
581-
if (isNativeFp8(chipset, sourceBElem))
581+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
582582
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
583583
}
584584
}
585585

586-
if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
586+
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
587587
Type sourceBElem =
588588
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
589589
if (m == 16 && n == 16 && k == 32 && b == 1) {
590-
if (isNativeBf8(chipset, sourceBElem))
590+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
591591
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
592-
if (isNativeFp8(chipset, sourceBElem))
592+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
593593
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
594594
}
595595
if (m == 32 && n == 32 && k == 16 && b == 1) {
596-
if (isNativeBf8(chipset, sourceBElem))
596+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
597597
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
598-
if (isNativeFp8(chipset, sourceBElem))
598+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
599599
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
600600
}
601601
}
@@ -801,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
801801
}
802802
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
803803
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
804-
if (isNativeBf8(chipset, sourceElemType)) {
804+
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
805805
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
806806
wordSel);
807-
} else if (isNativeFp8(chipset, sourceElemType)) {
807+
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
808808
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
809809
wordSel);
810810
}
@@ -836,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
836836
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
837837

838838
Value result;
839-
if (isNativeBf8(chipset, resultElemType))
839+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
840840
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
841841
existing, wordSel);
842-
else if (isNativeFp8(chipset, resultElemType))
842+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
843843
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
844844
existing, wordSel);
845845

@@ -871,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
871871
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
872872

873873
Value result;
874-
if (isNativeBf8(chipset, resultElemType))
874+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
875875
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
876876
existing, byteSel);
877-
else if (isNativeFp8(chipset, resultElemType))
877+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
878878
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
879879
existing, byteSel);
880880

0 commit comments

Comments
 (0)