Skip to content

Commit 2ebd633

Browse files
committed
[mlir][AMDGPU] Add packed 8-bit float conversion ops and lowering
Define operations that wrap the gfx940's new operations for converting between f32 and registers containing packed sets of four 8-bit floats. Define rocdl operations for the intrinsics and an AMDGPU dialect wrapper around them (to account for the fact that MLIR distinguishes the two float formats at the type level but that the LLVM IR does not). Define an ArithToAMDGPU pass, meant to run before conversion to LLVM, that replaces relevant calls to arith.extf and arith.truncf with the packed operations in the AMDGPU dialect. Note that the conversion currently only handles scalars and vectors of rank <= 1, as we do not have a usecase for multi-dimensional vector support right now. Reviewed By: jsjodin Differential Revision: https://reviews.llvm.org/D152457
1 parent 0eed8ae commit 2ebd633

File tree

16 files changed

+914
-4
lines changed

16 files changed

+914
-4
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- ArithToAMDGPU.h - Arith to AMDGPU dialect conversion ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
10+
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
16+
class RewritePatternSet;
17+
class Pass;
18+
19+
#define GEN_PASS_DECL_ARITHTOAMDGPUCONVERSIONPASS
20+
#include "mlir/Conversion/Passes.h.inc"
21+
22+
namespace arith {
23+
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
24+
} // namespace arith
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
14+
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
1415
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1516
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1617
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,21 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> {
112112
"Chipset that these operations will run on">];
113113
}
114114

115+
//===----------------------------------------------------------------------===//
116+
// ArithToAMDGPU
117+
//===----------------------------------------------------------------------===//
118+
def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
119+
let summary = "Convert Arith operations to AMDGPU-specific implementations";
120+
let description = [{
121+
Convert `arith` operations (currently extf and truncf on 8-bit floats)
122+
to operations in the `amdgpu` dialect. This pass is done in two steps
123+
in order to avoid running a notional arith-to-rocdl and arith-to-llvm
124+
simultaniously.
125+
}];
126+
127+
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
128+
}
129+
115130
//===----------------------------------------------------------------------===//
116131
// ArithToLLVM
117132
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,85 @@ def AMDGPU_Dialect : Dialect {
3838
class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
3939
Op<AMDGPU_Dialect, mnemonic, traits> {}
4040

41+
def AMDGPU_ExtPackedFp8Op :
42+
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
43+
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
44+
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
45+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
46+
Results<(outs F32:$res)> {
47+
let summary = "Extend one of a vector of packed fp8 values to a float";
48+
let description = [{
49+
Extend the value `source[index]` to a 32-bit float and return it.
50+
51+
This rather unusual signature arises from the fact that AMD GPUs cannot
52+
easily work with sub 32-bit quantities, so the compiler intrinsics for
53+
extending 8-bit floats (which are, currently, the only way to work with
54+
this operation) take packed vectors of 4 such floats.
55+
56+
If the passed-in vector has fewer than four elements, or the input is scalar,
57+
the remaining values in the <4 x i8> will be filled with with
58+
undefined values as needed.
59+
}];
60+
let assemblyFormat = [{
61+
attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
62+
}];
63+
}
64+
65+
def AMDGPU_PackedTrunc2xFp8Op :
66+
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
67+
Arguments<(ins F32:$sourceA,
68+
Optional<F32>:$sourceB,
69+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
70+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
71+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
72+
let summary = "Round two floats into a packed vector of 8-bit floats";
73+
let description = [{
74+
Round the inputs `sourceA` and `sourceB` (which is undefined if not
75+
specified) into the low or high word (bottom two or top two) elements
76+
of the returned vector, keeping the other two elements of `existing`
77+
unchanged if present (or undefined if it was not passed in).
78+
79+
The reason for this odd signature is that AMD GPUs cannot easily work with
80+
sub-registers, and so the conversion intrinsics (which are currently the
81+
only way to work with 8-bit float types) take packed vectors of 4 8-bit
82+
values.
83+
}];
84+
let assemblyFormat = [{
85+
attr-dict $sourceA `,` ($sourceB^):(`undef`)?
86+
`into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
87+
`:` type($sourceA) `to` type($res) (`into` type($existing)^)?
88+
}];
89+
let hasVerifier = 1;
90+
}
91+
92+
def AMDGPU_PackedStochRoundFp8Op :
93+
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
94+
Arguments<(ins F32:$source,
95+
I32:$stochiasticParam,
96+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
97+
Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
98+
Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
99+
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
100+
let description = [{
101+
Round the input `source`, adding in `stochiasticParam`, and place it into
102+
the `storeIndex`th element of `res`.
103+
104+
If `existing` is passed in, elements of `res` other than the one at `storeIndex`
105+
are copied from `existing`.
106+
107+
The reason for this odd signature is that AMD GPUs cannot easily work with
108+
sub-registers, and so the conversion intrinsics (which are currently the
109+
only way to work with 8-bit float types) take packed vectors of 4 8-bit
110+
values.
111+
}];
112+
let assemblyFormat = [{
113+
attr-dict $source `+` $stochiasticParam
114+
`into` ($existing^):(`undef`)? `[` $storeIndex `]`
115+
`:` type($source) `to` type($res) (`into` type($existing)^)?
116+
}];
117+
let hasVerifier = 1;
118+
}
119+
41120
/// Raw buffer load
42121
def AMDGPU_RawBufferLoadOp :
43122
AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class ROCDL_MbcntOp<string mnemonic> :
116116
def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">;
117117
def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">;
118118

119-
def ROCDL_DsSwizzleOp :
119+
def ROCDL_DsSwizzleOp :
120120
ROCDL_Op<"ds_swizzle">,
121121
Results<(outs I32:$res)>,
122122
Arguments<(ins I32:$src,
@@ -130,7 +130,7 @@ Arguments<(ins I32:$src,
130130
}];
131131
}
132132

133-
def ROCDL_DsBpermuteOp :
133+
def ROCDL_DsBpermuteOp :
134134
ROCDL_Op<"ds_bpermute">,
135135
Results<(outs I32:$res)>,
136136
Arguments<(ins I32:$index,
@@ -525,6 +525,85 @@ def ROCDL_RawBufferAtomicUMinOp :
525525
let hasCustomAssemblyFormat = 1;
526526
}
527527

528+
//===---------------------------------------------------------------------===//
529+
// 8-bit float intrinsics
530+
//===---------------------------------------------------------------------===//
531+
def ROCDL_CvtF32Bf8Op :
532+
ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
533+
Arguments<(ins I32:$srcA, I32:$byteSel)> {
534+
let summary = "Convert bf8 to f32";
535+
let description = [{
536+
Convert 8-bit bf8 value from the `byteSel`th bit of `srcA` to fp32.
537+
}];
538+
let assemblyFormat = [{
539+
attr-dict $srcA `[` $byteSel `]` `:` type($res)
540+
}];
541+
}
542+
543+
def ROCDL_CvtF32Fp8Op :
544+
ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
545+
Arguments<(ins I32:$srcA, I32:$byteSel)> {
546+
let summary = "Convert fp8 to f32";
547+
let description = [{
548+
Convert 8-bit fp8 value from the `byteSel`th bit of `srcA` to fp32.
549+
}];
550+
let assemblyFormat = [{
551+
attr-dict $srcA `[` $byteSel `]` `:` type($res)
552+
}];
553+
}
554+
555+
def ROCDL_CvtPkBf8F32Op :
556+
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
557+
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
558+
let summary = "Convert two f32's to bf8";
559+
let description = [{
560+
Convert `srcA` and `srcB` to bf8 and store into the low/high word of
561+
`old`, preserving the other word.
562+
}];
563+
let assemblyFormat = [{
564+
attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
565+
}];
566+
}
567+
568+
def ROCDL_CvtPkFp8F32Op :
569+
ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
570+
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
571+
let summary = "Convert two f32's to fp8";
572+
let description = [{
573+
Convert `srcA` and `srcB` to fp8 and store into the low/high word of
574+
`old`, preserving the other word.
575+
}];
576+
let assemblyFormat = [{
577+
attr-dict $srcA `,` $srcB `->` $old `[` $wordSel `]` `:` type($res)
578+
}];
579+
}
580+
581+
def ROCDL_CvtSrBf8F32Op :
582+
ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
583+
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
584+
let summary = "Convert f32 to bf8, stochiastic rounding";
585+
let description = [{
586+
Convert `srcA` to bf8, adding the rounding factor from `srcB`,
587+
and store into the `byteSel`th byte of `old`, preserving the others.
588+
}];
589+
let assemblyFormat = [{
590+
attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
591+
}];
592+
}
593+
594+
def ROCDL_CvtSrFp8F32Op :
595+
ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>,
596+
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
597+
let summary = "Convert f32 to fp8, stochiastic rounding";
598+
let description = [{
599+
Convert `srcA` to fp8, adding the rounding factor from `srcB`,
600+
and store into the `byteSel`th byte of `old`, preserving the others.
601+
}];
602+
let assemblyFormat = [{
603+
attr-dict $srcA `,` $srcB `->` $old `[` $byteSel `]` `:` type($res)
604+
}];
605+
}
606+
528607
//===----------------------------------------------------------------------===//
529608
// ROCDL target attribute.
530609
//===----------------------------------------------------------------------===//
@@ -612,5 +691,4 @@ def ROCDL_TargettAttr :
612691
}
613692
}];
614693
}
615-
616694
#endif // ROCDLIR_OPS

0 commit comments

Comments
 (0)