[MLIR][AMDGPU] Introduce fp16 packed arithmetic #105688
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu

Author: Giuseppe Rossini (giuseros)

Changes: This PR introduces rocdl.cvt.pkrtz in the ROCDL dialect and uses that instruction when lowering arith::TruncFOp.

Full diff: https://github.com/llvm/llvm-project/pull/105688.diff

12 Files Affected:
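For context, a minimal sketch of the lowering this patch enables on a scalar value, mirroring the scalar case in the 16-bit-floats.mlir test added below (the pass must be run with allow-packed-f16-round-to-zero=true; the exact IR is what the test checks, and the SSA names here are illustrative):

```mlir
// Input IR:
func.func @scalar_trunc(%v: f32) -> f16 {
  %w = arith.truncf %v : f32 to f16
  return %w : f16
}

// Expected IR after -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true":
// the scalar is packed together with a poison value, converted with
// round-to-zero, and the first lane of the vector<2xf16> result is extracted.
func.func @scalar_trunc(%v: f32) -> f16 {
  %c0 = arith.constant 0 : index
  %poison = llvm.mlir.poison : f32
  %packed = rocdl.cvt.pkrtz %v, %poison : vector<2xf16>
  %w = vector.extractelement %packed[%c0 : index] : vector<2xf16>
  return %w : f16
}
```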
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 78c79c915e0607..28fdc234e5ef07 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -9,7 +9,9 @@
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include <memory>
+#include <string>
namespace mlir {
@@ -26,7 +28,10 @@ namespace arith {
/// to the largest value of that type instead of being rewritten to Inf (aka
/// NaN).
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
- bool saturateFP8TruncF);
+ bool convertFP8Arithmetic,
+ bool saturateFP8Truncf,
+ bool allowPackedF16Rtz,
+ amdgpu::Chipset chipset);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961c..24dc3b67db5a56 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -150,9 +150,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
let options = [
+ Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"gfx000\"",
+ "Chipset that these operations will run on">,
Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
/*default=*/"false",
"Use saturating truncation for 8-bit float types">,
+ Option<"allowPackedF16Rtz", "allow-packed-f16-round-to-zero", "bool",
+ /*default=*/"false",
+ "Whether we should allow f32->f16 packed round-to-zero conversion">,
];
}
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..d6fcf7329b6099 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -25,6 +25,7 @@ def AMDGPU_Dialect : Dialect {
let dependentDialects = [
+ "ROCDL::ROCDLDialect",
"arith::ArithDialect",
"gpu::GPUDialect"
];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..082148ddb13d6f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -554,6 +554,21 @@ def ROCDL_RawBufferAtomicUMinOp :
let hasCustomAssemblyFormat = 1;
}
+//===---------------------------------------------------------------------===//
+// 16-bit float intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtPkRtz:
+ ROCDL_IntrOp<"cvt.pkrtz", [], [], [Pure], 1>,
+ Arguments<(ins F32:$srcA, F32:$srcB)> {
+ let summary = "Convert two f32 input into a vector<2xf16>";
+ let description = [{
+ Convert two f32 values into a packed vector<2xf16>.
+ }];
+ let assemblyFormat = [{
+ attr-dict $srcA `,` $srcB `:` type($res)
+ }];
+}
+
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b3798a3f7624b0..5c37ec536d8963 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -9,8 +9,11 @@
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -24,6 +27,7 @@ namespace mlir {
} // namespace mlir
using namespace mlir;
+using namespace mlir::amdgpu;
namespace {
struct ArithToAMDGPUConversionPass final
@@ -43,12 +47,25 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
- TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
- : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
+ TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
+ Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+ chipset(chipset) {}
+ Chipset chipset;
LogicalResult match(arith::TruncFOp op) const override;
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};
+
+struct TruncfToFloat16RewritePattern final
+ : public OpRewritePattern<arith::TruncFOp> {
+
+ using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+
+ LogicalResult match(arith::TruncFOp op) const override;
+ void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+};
+
} // end namespace
static Value castF32To(Type elementType, Value f32, Location loc,
@@ -272,17 +289,96 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
rewriter.replaceOp(op, result);
}
+LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
+ Type outType = op.getOut().getType();
+ Type inputType = getElementTypeOrSelf(op.getIn());
+ if (auto outVecType = dyn_cast<VectorType>(outType)) {
+ if (outVecType.isScalable())
+ return failure();
+ if (outVecType.getShape().size() > 1)
+ // Multi-dimensional vectors are currently unsupported.
+ return failure();
+ outType = outVecType.getElementType();
+ }
+ return success(outType.isF16() && inputType.isF32());
+}
+
+void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ Value in = op.getIn();
+ Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+ VectorType truncResType = VectorType::get(2, outElemType);
+
+ // Handle the case where input type is not a vector type
+ if (!isa<VectorType>(in.getType())) {
+ auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ Value asF16s =
+ rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
+ Value result = rewriter.create<vector::ExtractElementOp>(
+ loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+ return rewriter.replaceOp(op, result);
+ }
+ VectorType outType = cast<VectorType>(op.getOut().getType());
+ int64_t numElements = outType.getNumElements();
+ Value zero = rewriter.createOrFold<arith::ConstantOp>(
+ loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+
+ // Handle the vector case. We also handle the (uncommon) case where the vector
+ // length is odd
+ for (int64_t i = 0; i < numElements; i += 2) {
+ int64_t elemsThisOp = std::min(numElements, i + 2) - i;
+ Value thisResult = nullptr;
+ Value elemA = rewriter.create<vector::ExtractElementOp>(
+ loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+ Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+
+ if (elemsThisOp == 2) {
+ elemB = rewriter.create<vector::ExtractElementOp>(
+ loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+ }
+
+ thisResult =
+ rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+ // Place back the truncated result into the possibly larger vector. If we
+ // are operating on a size 2 vector, these operations should be folded away
+ thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, thisResult, 0, elemsThisOp, 1);
+ result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
+ result, i, 1);
+ }
+ rewriter.replaceOp(op, result);
+}
+
void mlir::arith::populateArithToAMDGPUConversionPatterns(
- RewritePatternSet &patterns, bool saturateFP8TruncF) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
- patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
- saturateFP8TruncF);
+ RewritePatternSet &patterns, bool convertFP8Arithmetic,
+ bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+
+ if (convertFP8Arithmetic) {
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+ patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+ saturateFP8Truncf, chipset);
+ }
+ if (allowPackedF16Rtz)
+ patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
}
void ArithToAMDGPUConversionPass::runOnOperation() {
Operation *op = getOperation();
+ MLIRContext *ctx = &getContext();
RewritePatternSet patterns(op->getContext());
- arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
+ FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset)) {
+ emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+ return signalPassFailure();
+ }
+
+ bool convertFP8Arithmetic =
+ (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
+ arith::populateArithToAMDGPUConversionPatterns(
+ patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
+ *maybeChipset);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
index e2c951b0b34d8b..50be09ab5a7c5b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
+ MLIRAMDGPUUtils
MLIRArithDialect
MLIRArithUtils
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..0b1dd79ded3a71 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 0551d13b5a0cf0..78d78cf48a747c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
MLIRIR
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
new file mode 100644
index 00000000000000..121cae26748a82
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true" | FileCheck %s
+
+// CHECK-LABEL: @scalar_trunc
+// CHECK-SAME: (%[[value:.*]]: f32)
+func.func @scalar_trunc(%v: f32) -> f16{
+ // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
+ // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
+ // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+ // CHECK: return %[[extract]] : f16
+ %w = arith.truncf %v : f32 to f16
+ return %w : f16
+}
+
+// CHECK-LABEL: @vector_trunc
+// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
+ // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+ // CHECK: return %[[ret]]
+ %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
+ return %w : vector<2xf16>
+}
+
+// CHECK-LABEL: @vector_trunc_long
+// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
+func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
+ // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
+ // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+ // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+ // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
+ // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+ // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
+ // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
+ // CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
+ // CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
+ // CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
+ // CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
+ // CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
+ // CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+ // CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
+ // CHECK: return %[[out4]]
+ %w = arith.truncf %v : vector<9xf32> to vector<9xf16>
+ return %w : vector<9xf16>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
index c7f39440a349b0..cd921da2294e13 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt --split-input-file %s \
-// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx940 saturate-fp8-truncf=true}))' \
// RUN: | FileCheck %s
// CHECK-LABEL: func.func @scalar_trunc
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 26a222a4a788e5..bd90facb615440 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s
// CHECK-LABEL: func.func @scalar_ext
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..d04978ff6deeb7 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -516,6 +516,12 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
llvm.return %source5 : i32
}
+llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
+ // CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
+ %source = rocdl.cvt.pkrtz %sourceA, %sourceB : vector<2xf16>
+ llvm.return %source : vector<2xf16>
+}
+
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"
cc @pcf000 (not sure why I cannot add you as reviewer)
LGTM mod one nit
This reverts commit 1387ba4.
This PR introduces rocdl.cvt.pkrtz in the ROCDL dialect and uses that instruction when lowering arith::TruncFOp.
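For reference, a short sketch of the new op in isolation, following the rocdl.mlir translation test in the diff above (function and operand names are illustrative):

```mlir
llvm.func @pack_two_f32(%a: f32, %b: f32) -> vector<2xf16> {
  // Packs two f32 values into a vector<2xf16> using round-toward-zero;
  // translates to: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float %a, float %b)
  %r = rocdl.cvt.pkrtz %a, %b : vector<2xf16>
  llvm.return %r : vector<2xf16>
}
```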