Skip to content

Commit 0d38937

Browse files
pcf000mirza-halilcevic
authored andcommitted
[MLIR][AMDGPU] Renaming using suggestions from review.
1 parent feb3886 commit 0d38937

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -470,14 +470,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
470470

471471
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
472472
/// supported by the `_bf8` instructions on the given `chipset`.
473-
static bool isNativeBf8(Chipset chipset, Type type) {
473+
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
474474
return (isGfx940Series(chipset) && isa<Float8E5M2FNUZType>(type)) ||
475475
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
476476
}
477477

478478
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
479479
/// supported by the `_fp8` instructions on the given `chipset`.
480-
static bool isNativeFp8(Chipset chipset, Type type) {
480+
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
481481
return (isGfx940Series(chipset) && isa<Float8E4M3FNUZType>(type)) ||
482482
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
483483
}
@@ -578,38 +578,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
578578
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
579579
}
580580

581-
if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
581+
if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
582582
// Known to be correct because there are no scalar f8 instructions and
583583
// because a length mismatch will have been caught by the verifier.
584584
Type sourceBElem =
585585
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
586586
if (m == 16 && n == 16 && k == 32 && b == 1) {
587-
if (isNativeBf8(chipset, sourceBElem))
587+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
588588
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
589-
if (isNativeFp8(chipset, sourceBElem))
589+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
590590
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
591591
}
592592
if (m == 32 && n == 32 && k == 16 && b == 1) {
593-
if (isNativeBf8(chipset, sourceBElem))
593+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
594594
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
595-
if (isNativeFp8(chipset, sourceBElem))
595+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
596596
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
597597
}
598598
}
599599

600-
if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
600+
if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
601601
Type sourceBElem =
602602
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
603603
if (m == 16 && n == 16 && k == 32 && b == 1) {
604-
if (isNativeBf8(chipset, sourceBElem))
604+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
605605
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
606-
if (isNativeFp8(chipset, sourceBElem))
606+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
607607
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
608608
}
609609
if (m == 32 && n == 32 && k == 16 && b == 1) {
610-
if (isNativeBf8(chipset, sourceBElem))
610+
if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
611611
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
612-
if (isNativeFp8(chipset, sourceBElem))
612+
if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
613613
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
614614
}
615615
}
@@ -817,10 +817,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
817817
}
818818
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
819819
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
820-
if (isNativeBf8(chipset, sourceElemType)) {
820+
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
821821
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
822822
wordSel);
823-
} else if (isNativeFp8(chipset, sourceElemType)) {
823+
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
824824
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
825825
wordSel);
826826
}
@@ -852,10 +852,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
852852
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
853853

854854
Value result;
855-
if (isNativeBf8(chipset, resultElemType))
855+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
856856
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
857857
existing, wordSel);
858-
else if (isNativeFp8(chipset, resultElemType))
858+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
859859
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
860860
existing, wordSel);
861861

@@ -887,10 +887,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
887887
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
888888

889889
Value result;
890-
if (isNativeBf8(chipset, resultElemType))
890+
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
891891
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
892892
existing, byteSel);
893-
else if (isNativeFp8(chipset, resultElemType))
893+
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
894894
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
895895
existing, byteSel);
896896

0 commit comments

Comments
 (0)