Skip to content

Commit dd2cad3

Browse files
pcf000mirza-halilcevic
authored andcommitted
[MLIR][AMDGPU] Renaming using suggestions from review.
1 parent 0a761c0 commit dd2cad3

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -476,14 +476,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
476476

477477
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
478478
/// supported by the `_bf8` instructions on the given `chipset`.
479-
static bool isNativeBf8(Chipset chipset, Type type) {
479+
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
480480
return (isGfx940Series(chipset) && isa<Float8E5M2FNUZType>(type)) ||
481481
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
482482
}
483483

484484
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
485485
/// supported by the `_fp8` instructions on the given `chipset`.
486-
static bool isNativeFp8(Chipset chipset, Type type) {
486+
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
487487
return (isGfx940Series(chipset) && isa<Float8E4M3FNUZType>(type)) ||
488488
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
489489
}
@@ -584,38 +584,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
584584
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
585585
}
586586

587-
if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
587+
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
588588
// Known to be correct because there are no scalar f8 instructions and
589589
// because a length mismatch will have been caught by the verifier.
590590
Type sourceBElem =
591591
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
592592
if (m == 16 && n == 16 && k == 32 && b == 1) {
593-
if (isNativeBf8(chipset, sourceBElem))
593+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
594594
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
595-
if (isNativeFp8(chipset, sourceBElem))
595+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
596596
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
597597
}
598598
if (m == 32 && n == 32 && k == 16 && b == 1) {
599-
if (isNativeBf8(chipset, sourceBElem))
599+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
600600
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
601-
if (isNativeFp8(chipset, sourceBElem))
601+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
602602
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
603603
}
604604
}
605605

606-
if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
606+
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
607607
Type sourceBElem =
608608
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
609609
if (m == 16 && n == 16 && k == 32 && b == 1) {
610-
if (isNativeBf8(chipset, sourceBElem))
610+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
611611
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
612-
if (isNativeFp8(chipset, sourceBElem))
612+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
613613
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
614614
}
615615
if (m == 32 && n == 32 && k == 16 && b == 1) {
616-
if (isNativeBf8(chipset, sourceBElem))
616+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
617617
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
618-
if (isNativeFp8(chipset, sourceBElem))
618+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
619619
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
620620
}
621621
}
@@ -823,10 +823,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
823823
}
824824
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
825825
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
826-
if (isNativeBf8(chipset, sourceElemType)) {
826+
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
827827
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
828828
wordSel);
829-
} else if (isNativeFp8(chipset, sourceElemType)) {
829+
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
830830
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
831831
wordSel);
832832
}
@@ -858,10 +858,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
858858
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
859859

860860
Value result;
861-
if (isNativeBf8(chipset, resultElemType))
861+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
862862
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
863863
existing, wordSel);
864-
else if (isNativeFp8(chipset, resultElemType))
864+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
865865
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
866866
existing, wordSel);
867867

@@ -893,10 +893,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
893893
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
894894

895895
Value result;
896-
if (isNativeBf8(chipset, resultElemType))
896+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
897897
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
898898
existing, byteSel);
899-
else if (isNativeFp8(chipset, resultElemType))
899+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
900900
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
901901
existing, byteSel);
902902

0 commit comments

Comments
 (0)