Skip to content

[MLIR][AMDGPU] Introduce fp16 packed arithmetic #105688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H

#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include <memory>
#include <string>

namespace mlir {

Expand All @@ -26,7 +28,10 @@ namespace arith {
/// to the largest value of that type instead of being rewritten to Inf (aka
/// NaN).
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool saturateFP8TruncF);
bool convertFP8Arithmetic,
bool saturateFP8Truncf,
bool allowPackedF16Rtz,
amdgpu::Chipset chipset);
} // namespace arith
} // namespace mlir

Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];

let options = [
Option<"chipset", "chipset", "std::string",
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">,
Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
/*default=*/"false",
"Use saturating truncation for 8-bit float types">,
Option<"allowPackedF16Rtz", "allow-packed-f16-round-to-zero", "bool",
/*default=*/"false",
"Whether we should allow f32->f16 packed round-to-zero conversion">,
];
}

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def AMDGPU_Dialect : Dialect {


let dependentDialects = [
"ROCDL::ROCDLDialect",
"arith::ArithDialect",
"gpu::GPUDialect"
];
Expand Down
17 changes: 16 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";

let description = [{
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];

Expand Down Expand Up @@ -554,6 +554,21 @@ def ROCDL_RawBufferAtomicUMinOp :
let hasCustomAssemblyFormat = 1;
}

//===---------------------------------------------------------------------===//
// 16-bit float intrinsics
//===---------------------------------------------------------------------===//
// Packed f32 -> f16 conversion with round-toward-zero semantics; translated
// to the llvm.amdgcn.cvt.pkrtz intrinsic (see mlir/test/Target/LLVMIR/rocdl.mlir).
def ROCDL_CvtPkRtz:
    ROCDL_IntrOp<"cvt.pkrtz", [], [], [Pure], 1>,
    Arguments<(ins F32:$srcA, F32:$srcB)> {
  let summary = "Convert two f32 input into a vector<2xf16>";
  let description = [{
    Convert two f32 values into a packed vector<2xf16>.
  }];
  let assemblyFormat = [{
    attr-dict $srcA `,` $srcB `:` type($res)
  }];
}

//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
Expand Down
119 changes: 112 additions & 7 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
Expand All @@ -24,6 +27,7 @@ namespace mlir {
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
struct ArithToAMDGPUConversionPass final
Expand All @@ -43,12 +47,25 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {

// Rewrites arith.truncf into AMDGPU 8-bit float truncation ops.
// NOTE(review): this raw diff view shows *both* the removed two-argument
// constructor and its chipset-taking replacement with no +/- markers; only
// the three-argument form should exist after the patch is applied — confirm
// against the merged source.
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  // When true, out-of-range inputs saturate to the destination type's largest
  // value instead of becoming Inf/NaN (see the header comment in
  // ArithToAMDGPU.h).
  bool saturateFP8 = false;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset)
      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
        chipset(chipset) {}
  // Target chipset the lowering is specialized for (parsed from the pass's
  // `chipset` option).
  Chipset chipset;

  LogicalResult match(arith::TruncFOp op) const override;
  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};

// Rewrites f32 -> f16 arith.truncf (scalar or fixed-size vector) into
// rocdl.cvt.pkrtz, the packed round-to-zero conversion.
struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {

  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;

  // Succeeds only for f32 inputs truncated to (non-scalable) f16 results.
  LogicalResult match(arith::TruncFOp op) const override;
  // Emits rocdl.cvt.pkrtz ops, converting two lanes at a time.
  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};

} // end namespace

static Value castF32To(Type elementType, Value f32, Location loc,
Expand Down Expand Up @@ -272,17 +289,105 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
rewriter.replaceOp(op, result);
}

/// Match only f32 -> f16 truncations whose result is a scalar or a
/// fixed-size vector; scalable vectors cannot be unrolled into pkrtz pairs.
LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
  Type dstType = op.getOut().getType();
  if (auto dstVecType = dyn_cast<VectorType>(dstType)) {
    if (dstVecType.isScalable())
      return failure();
    dstType = dstVecType.getElementType();
  }
  if (!dstType.isF16())
    return failure();
  return success(getElementTypeOrSelf(op.getIn()).isF32());
}

/// Lower an f32 -> f16 arith.truncf to rocdl.cvt.pkrtz, which converts two
/// f32 operands into a packed vector<2xf16> (round-toward-zero).
void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
                                            PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  // The intrinsic always yields two f16 lanes, even when only one is wanted.
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Handle the case where input type is not a vector type: pack the scalar
  // with a poison second operand and extract lane 0 of the result.
  if (!inVectorTy) {
    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
    Value asF16s =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
    Value result = rewriter.create<vector::ExtractElementOp>(
        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
    return rewriter.replaceOp(op, result);
  }
  VectorType outType = cast<VectorType>(op.getOut().getType());
  int64_t numElements = outType.getNumElements();
  // Seed the result with a zero splat; converted pairs are inserted into it.
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);

  // Flatten multi-dimensional inputs so lanes can be walked linearly; the
  // original rank is restored by the shape_cast at the end.
  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
  }

  // Handle the vector case. We also handle the (uncommon) case where the
  // vector length is odd
  for (int64_t i = 0; i < numElements; i += 2) {
    // 2 for a full pair; 1 when a trailing odd element remains.
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = rewriter.create<vector::ExtractElementOp>(
        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
    // Second source defaults to poison; overwritten when a real lane exists.
    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = rewriter.create<vector::ExtractElementOp>(
          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
    }

    thisResult =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
    // Place back the truncated result into the possibly larger vector. If we
    // are operating on a size 2 vector, these operations should be folded
    // away
    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }

  if (inVectorTy.getRank() != outType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
  }

  rewriter.replaceOp(op, result);
}

/// Populates `patterns` with the arith-to-AMDGPU rewrites. FP8 ext/trunc
/// patterns are added only when `convertFP8Arithmetic` is set; the packed
/// f32->f16 round-to-zero pattern is gated on `allowPackedF16Rtz`.
/// NOTE(review): this raw diff view fuses two versions — the first
/// signature/body below (taking only `saturateFP8TruncF`) is the *removed*
/// pre-patch form; only the second exists in the applied source. Confirm
/// against the merged file.
void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool saturateFP8TruncF) {
  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
                                             saturateFP8TruncF);
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {

  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
                                               saturateFP8Truncf, chipset);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
}

/// Pass entry point: parses the `chipset` option, derives whether FP8
/// arithmetic conversion applies, populates the rewrite patterns, and applies
/// them greedily. Fails the pass on an unparseable chipset string.
void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(op->getContext());
  // NOTE(review): the next call is the *removed* pre-patch line retained by
  // this raw diff view; the applied source keeps only the later call that
  // passes the chipset-derived flags. Confirm against the merged file.
  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  // FP8 patterns are gated to chips reporting major version 9 and minor
  // version >= 0x40 (i.e. gfx940 and newer, per the test RUN lines).
  bool convertFP8Arithmetic =
      (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      *maybeChipset);
  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
    return signalPassFailure();
}
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU

LINK_LIBS PUBLIC
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRArithDialect
MLIRArithUtils
MLIRVectorDialect
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
MLIRIR
Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true" | FileCheck %s

// Scalar f32 -> f16: the input is packed with an llvm.mlir.poison second
// operand and lane 0 of the resulting vector<2xf16> is extracted.
// CHECK-LABEL: @scalar_trunc
// CHECK-SAME: (%[[value:.*]]: f32)
func.func @scalar_trunc(%v: f32) -> f16{
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
// CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
// CHECK: return %[[extract]] : f16
  %w = arith.truncf %v : f32 to f16
  return %w : f16
}

// Even-length vector: both lanes feed a single rocdl.cvt.pkrtz.
// (Fixed: the CHECK-LABEL previously said @vector_trunc, which only matched
// the @vector_trunc_short symbol via substring matching.)
// CHECK-LABEL: @vector_trunc_short
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: return %[[ret]]
  %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
  return %w : vector<2xf16>
}

// Odd-length vector: four full pkrtz pairs, then a final pkrtz whose second
// operand is poison and whose low lane is sliced out.
// CHECK-LABEL: @vector_trunc_long
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
// CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
// CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
// CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
// CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
// CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
// CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
// Fixed: use the already-bound [[elem8]] instead of rebinding it with a
// greedy [[elem8:.*]] capture that swallowed the second (poison) operand.
// CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8]], %{{.*}} : vector<2xf16>
// CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
// CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
// CHECK: return %[[out4]]
  %w = arith.truncf %v : vector<9xf32> to vector<9xf16>
  return %w : vector<9xf16>
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt --split-input-file %s \
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx940 saturate-fp8-truncf=true}))' \
// RUN: | FileCheck %s

// CHECK-LABEL: func.func @scalar_trunc
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s

// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Target/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,12 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
llvm.return %source5 : i32
}

// Verifies that rocdl.cvt.pkrtz translates to the llvm.amdgcn.cvt.pkrtz
// intrinsic when exporting to LLVM IR.
llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
  // CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
  %source = rocdl.cvt.pkrtz %sourceA, %sourceB : vector<2xf16>
  llvm.return %source : vector<2xf16>
}

// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"
Expand Down
Loading