Skip to content

Commit 1387ba4

Browse files
authored
[MLIR][AMDGPU] Introduce fp16 packed arithmetic (#105688)
This PR introduces the `rocdl.cvt.pkrtz` op in the ROCDL dialect and uses it to lower f32-to-f16 `arith::TruncFOp` when the `allow-packed-f16-round-to-zero` pass option is enabled.
1 parent 643bf6c commit 1387ba4

File tree

12 files changed

+203
-11
lines changed

12 files changed

+203
-11
lines changed

mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
1010
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
1111

12+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1213
#include <memory>
14+
#include <string>
1315

1416
namespace mlir {
1517

@@ -26,7 +28,10 @@ namespace arith {
2628
/// to the largest value of that type instead of being rewritten to Inf (aka
2729
/// NaN).
2830
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
29-
bool saturateFP8TruncF);
31+
bool convertFP8Arithmetic,
32+
bool saturateFP8Truncf,
33+
bool allowPackedF16Rtz,
34+
amdgpu::Chipset chipset);
3035
} // namespace arith
3136
} // namespace mlir
3237

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
150150
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
151151

152152
let options = [
153+
Option<"chipset", "chipset", "std::string",
154+
/*default=*/"\"gfx000\"",
155+
"Chipset that these operations will run on">,
153156
Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
154157
/*default=*/"false",
155158
"Use saturating truncation for 8-bit float types">,
159+
Option<"allowPackedF16Rtz", "allow-packed-f16-round-to-zero", "bool",
160+
/*default=*/"false",
161+
"Whether we should allow f32->f16 packed round-to-zero conversion">,
156162
];
157163
}
158164

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def AMDGPU_Dialect : Dialect {
2525

2626

2727
let dependentDialects = [
28+
"ROCDL::ROCDLDialect",
2829
"arith::ArithDialect",
2930
"gpu::GPUDialect"
3031
];

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def ROCDL_BallotOp :
166166
let summary = "Vote across thread group";
167167

168168
let description = [{
169-
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
169+
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
170170
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
171171
}];
172172

@@ -579,6 +579,21 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
579579
}];
580580
}
581581

582+
//===---------------------------------------------------------------------===//
583+
// 16-bit float intrinsics
584+
//===---------------------------------------------------------------------===//
585+
def ROCDL_CvtPkRtz:
586+
ROCDL_IntrOp<"cvt.pkrtz", [], [], [Pure], 1>,
587+
Arguments<(ins F32:$srcA, F32:$srcB)> {
588+
let summary = "Convert two f32 input into a vector<2xf16>";
589+
let description = [{
590+
Convert two f32 values into a packed vector<2xf16>.
591+
}];
592+
let assemblyFormat = [{
593+
attr-dict $srcA `,` $srcB `:` type($res)
594+
}];
595+
}
596+
582597
//===---------------------------------------------------------------------===//
583598
// 8-bit float intrinsics
584599
//===---------------------------------------------------------------------===//

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 112 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
1010

1111
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1213
#include "mlir/Dialect/Arith/IR/Arith.h"
1314
#include "mlir/Dialect/Arith/Utils/Utils.h"
15+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1417
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1518
#include "mlir/IR/BuiltinTypes.h"
1619
#include "mlir/IR/PatternMatch.h"
@@ -24,6 +27,7 @@ namespace mlir {
2427
} // namespace mlir
2528

2629
using namespace mlir;
30+
using namespace mlir::amdgpu;
2731

2832
namespace {
2933
struct ArithToAMDGPUConversionPass final
@@ -43,12 +47,25 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4347

4448
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
4549
bool saturateFP8 = false;
46-
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
47-
: OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
50+
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
51+
Chipset chipset)
52+
: OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
53+
chipset(chipset) {}
54+
Chipset chipset;
4855

4956
LogicalResult match(arith::TruncFOp op) const override;
5057
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
5158
};
59+
60+
struct TruncfToFloat16RewritePattern final
61+
: public OpRewritePattern<arith::TruncFOp> {
62+
63+
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
64+
65+
LogicalResult match(arith::TruncFOp op) const override;
66+
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
67+
};
68+
5269
} // end namespace
5370

5471
static Value castF32To(Type elementType, Value f32, Location loc,
@@ -272,17 +289,105 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
272289
rewriter.replaceOp(op, result);
273290
}
274291

292+
LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
293+
Type outType = op.getOut().getType();
294+
Type inputType = getElementTypeOrSelf(op.getIn());
295+
if (auto outVecType = dyn_cast<VectorType>(outType)) {
296+
if (outVecType.isScalable())
297+
return failure();
298+
outType = outVecType.getElementType();
299+
}
300+
return success(outType.isF16() && inputType.isF32());
301+
}
302+
303+
void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
304+
PatternRewriter &rewriter) const {
305+
Location loc = op.getLoc();
306+
Value in = op.getIn();
307+
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
308+
VectorType truncResType = VectorType::get(2, outElemType);
309+
auto inVectorTy = dyn_cast<VectorType>(in.getType());
310+
311+
// Handle the case where input type is not a vector type
312+
if (!inVectorTy) {
313+
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
314+
Value asF16s =
315+
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316+
Value result = rewriter.create<vector::ExtractElementOp>(
317+
loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
318+
return rewriter.replaceOp(op, result);
319+
}
320+
VectorType outType = cast<VectorType>(op.getOut().getType());
321+
int64_t numElements = outType.getNumElements();
322+
Value zero = rewriter.createOrFold<arith::ConstantOp>(
323+
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
324+
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
325+
326+
if (inVectorTy.getRank() > 1) {
327+
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
328+
inVectorTy.getElementType());
329+
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
330+
}
331+
332+
// Handle the vector case. We also handle the (uncommon) case where the vector
333+
// length is odd
334+
for (int64_t i = 0; i < numElements; i += 2) {
335+
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
336+
Value thisResult = nullptr;
337+
Value elemA = rewriter.create<vector::ExtractElementOp>(
338+
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
339+
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
340+
341+
if (elemsThisOp == 2) {
342+
elemB = rewriter.create<vector::ExtractElementOp>(
343+
loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
344+
}
345+
346+
thisResult =
347+
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
348+
// Place back the truncated result into the possibly larger vector. If we
349+
// are operating on a size 2 vector, these operations should be folded away
350+
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
351+
loc, thisResult, 0, elemsThisOp, 1);
352+
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
353+
result, i, 1);
354+
}
355+
356+
if (inVectorTy.getRank() != outType.getRank()) {
357+
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
358+
}
359+
360+
rewriter.replaceOp(op, result);
361+
}
362+
275363
void mlir::arith::populateArithToAMDGPUConversionPatterns(
276-
RewritePatternSet &patterns, bool saturateFP8TruncF) {
277-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
278-
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
279-
saturateFP8TruncF);
364+
RewritePatternSet &patterns, bool convertFP8Arithmetic,
365+
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
366+
367+
if (convertFP8Arithmetic) {
368+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
369+
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
370+
saturateFP8Truncf, chipset);
371+
}
372+
if (allowPackedF16Rtz)
373+
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
280374
}
281375

282376
void ArithToAMDGPUConversionPass::runOnOperation() {
283377
Operation *op = getOperation();
378+
MLIRContext *ctx = &getContext();
284379
RewritePatternSet patterns(op->getContext());
285-
arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
380+
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
381+
if (failed(maybeChipset)) {
382+
emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
383+
return signalPassFailure();
384+
}
385+
386+
bool convertFP8Arithmetic =
387+
(*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
388+
arith::populateArithToAMDGPUConversionPatterns(
389+
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
390+
*maybeChipset);
286391
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
287392
return signalPassFailure();
288393
}

mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
1212

1313
LINK_LIBS PUBLIC
1414
MLIRAMDGPUDialect
15+
MLIRAMDGPUUtils
1516
MLIRArithDialect
1617
MLIRArithUtils
1718
MLIRVectorDialect

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1718
#include "mlir/IR/Builders.h"
1819
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/Diagnostics.h"

mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
1111

1212
LINK_LIBS PUBLIC
1313
MLIRArithDialect
14+
MLIRROCDLDialect
1415
# Needed for GPU address space enum definition
1516
MLIRGPUDialect
1617
MLIRIR
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true" | FileCheck %s
2+
3+
// CHECK-LABEL: @scalar_trunc
4+
// CHECK-SAME: (%[[value:.*]]: f32)
5+
func.func @scalar_trunc(%v: f32) -> f16{
6+
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
7+
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
8+
// CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
9+
// CHECK: return %[[extract]] : f16
10+
%w = arith.truncf %v : f32 to f16
11+
return %w : f16
12+
}
13+
14+
// CHECK-LABEL: @vector_trunc
15+
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
16+
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
17+
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
18+
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
19+
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
20+
// CHECK: return %[[ret]]
21+
%w = arith.truncf %v : vector<2xf32> to vector<2xf16>
22+
return %w : vector<2xf16>
23+
}
24+
25+
// CHECK-LABEL: @vector_trunc_long
26+
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
27+
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
28+
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
29+
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
30+
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
31+
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
32+
// CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
33+
// CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
34+
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
35+
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
36+
// CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
37+
// CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
38+
// CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
39+
// CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
40+
// CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
41+
// CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
42+
// CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
43+
// CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
44+
// CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
45+
// CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
46+
// CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
47+
// CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
48+
// CHECK: return %[[out4]]
49+
%w = arith.truncf %v : vector<9xf32> to vector<9xf16>
50+
return %w : vector<9xf16>
51+
}

mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt --split-input-file %s \
2-
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
2+
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx940 saturate-fp8-truncf=true}))' \
33
// RUN: | FileCheck %s
44

55
// CHECK-LABEL: func.func @scalar_trunc

mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
1+
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s
22

33
// CHECK-LABEL: func.func @scalar_ext
44
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,12 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
530530
llvm.return %source5 : i32
531531
}
532532

533+
llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
534+
// CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
535+
%source = rocdl.cvt.pkrtz %sourceA, %sourceB : vector<2xf16>
536+
llvm.return %source : vector<2xf16>
537+
}
538+
533539
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
534540
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
535541
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"

0 commit comments

Comments
 (0)