Skip to content

Commit b87a0a0

Browse files
krzysz00 authored and mirza-halilcevic committed
[MLIR][AMDGPU] After fp8 conversions were lowered to AMDGPU dialect ops,
those operations were not being converted to the LLVM intrinsics they correspond to because the rewrite patterns were still checking for gfx940+. As part of this, factor out tests for type-match into isNativeFp8() and isNativeBf8() functions in the AMDGPUToRocdl rewrites. Also, fix a typo in isGfx940() that caused it to be true for gfx950. Finally, test all these OCP format conversions by duplicating the gfx940 tests.
1 parent 32e052d commit b87a0a0

File tree

4 files changed

+373
-20
lines changed

4 files changed

+373
-20
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
474474
}
475475
}
476476

477+
/// Return true if `type` is the E5M2 or E5M2FNUZ variant of an 8-bit float that is
478+
/// supported by the `_bf8` instructions on the given `chipset`.
479+
static bool isNativeBf8(Chipset chipset, Type type) {
480+
return (chipset.isGfx940() && isa<Float8E5M2FNUZType>(type)) ||
481+
(chipset.hasOcpFp8() && isa<Float8E5M2Type>(type));
482+
}
483+
484+
/// Return true if `type` is the E4M3FN or E4M3FNUZ variant of an 8-bit float that is
485+
/// supported by the `_fp8` instructions on the given `chipset`.
486+
static bool isNativeFp8(Chipset chipset, Type type) {
487+
return (chipset.isGfx940() && isa<Float8E4M3FNUZType>(type)) ||
488+
(chipset.hasOcpFp8() && isa<Float8E4M3FNType>(type));
489+
}
490+
477491
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
478492
/// if one exists. This includes checking to ensure the intrinsic is supported
479493
/// on the architecture you are compiling for.
@@ -570,42 +584,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
570584
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
571585
}
572586

573-
if (destElem.isF32() &&
574-
((isa<Float8E5M2FNUZType>(sourceElem) && chipset >= kGfx942) ||
575-
(isa<Float8E5M2Type>(sourceElem) && chipset.hasOcpFp8()))) {
587+
if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
576588
// Known to be correct because there are no scalar f8 instructions and
577589
// because a length mismatch will have been caught by the verifier.
578590
Type sourceBElem =
579591
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
580592
if (m == 16 && n == 16 && k == 32 && b == 1) {
581-
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
593+
if (isNativeBf8(chipset, sourceBElem))
582594
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
583-
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
595+
if (isNativeFp8(chipset, sourceBElem))
584596
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
585597
}
586598
if (m == 32 && n == 32 && k == 16 && b == 1) {
587-
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
599+
if (isNativeBf8(chipset, sourceBElem))
588600
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
589-
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
601+
if (isNativeFp8(chipset, sourceBElem))
590602
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
591603
}
592604
}
593605

594-
if (destElem.isF32() &&
595-
((isa<Float8E4M3FNUZType>(sourceElem) && chipset >= kGfx942) ||
596-
(isa<Float8E4M3FNType>(sourceElem) && chipset.hasOcpFp8()))) {
606+
if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
597607
Type sourceBElem =
598608
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
599609
if (m == 16 && n == 16 && k == 32 && b == 1) {
600-
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
610+
if (isNativeBf8(chipset, sourceBElem))
601611
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
602-
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
612+
if (isNativeFp8(chipset, sourceBElem))
603613
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
604614
}
605615
if (m == 32 && n == 32 && k == 16 && b == 1) {
606-
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
616+
if (isNativeBf8(chipset, sourceBElem))
607617
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
608-
if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
618+
if (isNativeFp8(chipset, sourceBElem))
609619
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
610620
}
611621
}
@@ -813,10 +823,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
813823
}
814824
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
815825
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
816-
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceElemType)) {
826+
if (isNativeBf8(chipset, sourceElemType)) {
817827
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
818828
wordSel);
819-
} else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceElemType)) {
829+
} else if (isNativeFp8(chipset, sourceElemType)) {
820830
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
821831
wordSel);
822832
}
@@ -848,10 +858,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
848858
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
849859

850860
Value result;
851-
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
861+
if (isNativeBf8(chipset, resultElemType))
852862
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
853863
existing, wordSel);
854-
else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
864+
else if (isNativeFp8(chipset, resultElemType))
855865
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
856866
existing, wordSel);
857867

@@ -883,10 +893,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
883893
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
884894

885895
Value result;
886-
if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
896+
if (isNativeBf8(chipset, resultElemType))
887897
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
888898
existing, byteSel);
889-
else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
899+
else if (isNativeFp8(chipset, resultElemType))
890900
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
891901
existing, byteSel);
892902

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
2+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 | FileCheck %s
3+
4+
// CHECK-LABEL: func @ext_scalar
5+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2 to i8
6+
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
7+
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
8+
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
9+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
10+
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
11+
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
12+
// CHECK: return [[EXT]]
13+
func.func @ext_scalar(%v: f8E5M2) -> f32 {
14+
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
15+
func.return %ret : f32
16+
}
17+
18+
// CHECK-LABEL: func @ext_short_vec
19+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
20+
// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
21+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
22+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
23+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
24+
// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
25+
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
26+
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
27+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
28+
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
29+
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
30+
// CHECK: return [[EXT]]
31+
func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
32+
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
33+
func.return %ret : f32
34+
}
35+
36+
// CHECK-LABEL: func @ext_full_vec(
37+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
38+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
39+
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
40+
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
41+
// CHECK: return [[EXT]] : f32
42+
43+
func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
44+
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
45+
func.return %ret : f32
46+
}
47+
48+
// CHECK-LABEL: func @packed_trunc
49+
// CHECK-SAME: ([[V:%.+]]: f32)
50+
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
51+
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
52+
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
53+
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
54+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
55+
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
56+
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
57+
%ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FN>
58+
func.return %ret : vector<4xf8E4M3FN>
59+
}
60+
61+
// CHECK-LABEL: func @packed_truncx2
62+
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
63+
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
64+
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
65+
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
66+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
67+
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
68+
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
69+
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FN>
70+
func.return %ret : vector<4xf8E4M3FN>
71+
}
72+
73+
// CHECK-LABEL: func @packed_truncx2_into
74+
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
75+
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
76+
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
77+
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
78+
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
79+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
80+
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
81+
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
82+
%ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
83+
func.return %ret : vector<4xf8E5M2>
84+
}
85+
86+
// CHECK-LABEL: func @packed_stoch_round
87+
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
88+
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
89+
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
90+
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
91+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
92+
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
93+
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
94+
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FN>
95+
func.return %ret : vector<4xf8E4M3FN>
96+
}
97+
98+
// CHECK-LABEL: func @packed_stoch_round_into
99+
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
100+
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
101+
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
102+
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
103+
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
104+
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
105+
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
106+
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
107+
%ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
108+
func.return %ret : vector<4xf8E5M2>
109+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// RUN: mlir-opt --split-input-file %s \
2+
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx950 saturate-fp8-truncf=true}))' \
3+
// RUN: | FileCheck %s
4+
5+
// RUN: mlir-opt --split-input-file %s \
6+
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx1200 saturate-fp8-truncf=true}))' \
7+
// RUN: | FileCheck %s
8+
9+
// CHECK-LABEL: func.func @scalar_trunc
10+
// CHECK-SAME: ([[V:%.+]]: f16)
11+
// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
12+
// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
13+
// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
14+
// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
15+
// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
16+
// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
17+
// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
18+
// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
19+
// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
20+
// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
21+
// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
22+
// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
23+
// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
24+
// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
25+
// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2>
26+
// CHECK: return [[W]] : f8E5M2
27+
func.func @scalar_trunc(%v: f16) -> f8E5M2 {
28+
%w = arith.truncf %v : f16 to f8E5M2
29+
return %w : f8E5M2
30+
}
31+
32+
// No 0-D test because arith.truncf hasn't been extended to support it.
33+
34+
// -----
35+
36+
// CHECK-LABEL: func.func @vector_trunc
37+
// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FN> {
38+
// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-4.480000e+02> : vector<2xf32>
39+
// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<4.480000e+02> : vector<2xf32>
40+
// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
41+
// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
42+
// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
43+
// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
44+
// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
45+
// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
46+
// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
47+
// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
48+
// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
49+
// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
50+
// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0]
51+
// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1]
52+
// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FN>
53+
// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FN> to vector<2xf8E4M3FN>
54+
// CHECK: return [[W]] : vector<2xf8E4M3FN>
55+
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FN> {
56+
%w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FN>
57+
return %w : vector<2xf8E4M3FN>
58+
}

0 commit comments

Comments
 (0)