-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][NVVM] Add support for f6x2 conversion #136537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add support for f6x2 conversion #136537
Conversation
@llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) ChangesThis patch adds the For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync Full diff: https://github.com/llvm/llvm-project/pull/136537.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..b8581a7504c67 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1032,6 +1032,55 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
}];
}
+def FP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
+def FP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
+
+def FP6Type : I32EnumAttr<"FP6Type", "NVVM FP6Type kind",
+ [FP6E2M3, FP6E3M2]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def FP6TypeAttr : EnumAttr<NVVM_Dialect, FP6Type, "fp6_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
+ let summary = "Convert the given float input to f6x2";
+ let description = [{
+ This Op converts the given float input to f6x2.
+ The result `res` is represented as an i16 type.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction. The `rnd` and `sat` attributes specify the
+ the rounding and saturation modes respectively.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+ let arguments = (ins
+ FP6TypeAttr:$type,
+ F32:$a,
+ F32:$b,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RN">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::SATFINITE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::FP6Type,
+ bool hasRelu);
+ bool isPacked();
+ llvm::Value* getCastedResult(llvm::Value* packedI16, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+ $dst = op.getCastedResult(packedI16, builder);
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..8540a88653973 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
@@ -160,6 +161,33 @@ LogicalResult CvtFloatToTF32Op::verify() {
return success();
}
+bool CvtToF6x2Op::isPacked() {
+ if (getDst().getType().isInteger(16)) {
+ return true;
+ }
+ return false;
+}
+
+llvm::Value *CvtToF6x2Op::getCastedResult(llvm::Value *packedI16,
+ llvm::IRBuilderBase &builder) {
+ if (isPacked()) {
+ return packedI16;
+ }
+ return builder.CreateBitCast(
+ packedI16, llvm::FixedVectorType::get(
+ llvm::Type::getInt8Ty(builder.getContext()), 2));
+}
+
+LogicalResult CvtToF6x2Op::verify() {
+ if (getRnd() != NVVM::FPRoundingMode::RN) {
+ return emitOpError("RN rounding mode required for CvtToF6x2Op.");
+ }
+ if (getSat() != NVVM::SaturationMode::SATFINITE) {
+ return emitOpError("SATFINITE saturation mode required for CvtToF6x2Op.");
+ }
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1300,6 +1328,23 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+#define CVT_TO_F6X2_ID_IMPL(type, relu) \
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn##relu##_satfinite \
+ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
+
+llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::FP6Type type,
+ bool hasRelu) {
+ switch (type) {
+ case NVVM::FP6Type::E2M3:
+ return CVT_TO_F6X2_ID_IMPL(e2m3x2, _relu);
+ case NVVM::FP6Type::E3M2:
+ return CVT_TO_F6X2_ID_IMPL(e3m2x2, _relu);
+ default:
+ break;
+ }
+ llvm_unreachable("Invalid FP6Type for CvtToF6x2Op");
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
new file mode 100644
index 0000000000000..2237e6faad52d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_fp6x2_packed
+llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
+ //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
+ //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp6x2_vector
+llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
+ //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+ %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+ //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+ %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+ llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index f87f11daeef54..5fcef1aa67139 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -176,3 +176,19 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
%0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
llvm.return
}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f6x2(%a : f32, %b : f32) {
+ // expected-error @below {{RN rounding mode required for CvtToF6x2Op.}}
+ %res = nvvm.cvt.to.f6x2 <e2m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rna>} : i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f6x2_packed(%a : f32, %b : f32) {
+ // expected-error @below {{SATFINITE saturation mode required for CvtToF6x2Op.}}
+ %res = nvvm.cvt.to.f6x2 <e3m2> %a, %b {sat = #nvvm.sat_mode<none>} : i16
+ llvm.return
+}
|
@llvm/pr-subscribers-mlir Author: Srinivasa Ravi (Wolfram70) ChangesThis patch adds the For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync Full diff: https://github.com/llvm/llvm-project/pull/136537.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..b8581a7504c67 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1032,6 +1032,55 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
}];
}
+def FP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
+def FP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
+
+def FP6Type : I32EnumAttr<"FP6Type", "NVVM FP6Type kind",
+ [FP6E2M3, FP6E3M2]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def FP6TypeAttr : EnumAttr<NVVM_Dialect, FP6Type, "fp6_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
+ let summary = "Convert the given float input to f6x2";
+ let description = [{
+ This Op converts the given float input to f6x2.
+ The result `res` is represented as an i16 type.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction. The `rnd` and `sat` attributes specify the
+ the rounding and saturation modes respectively.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+ let arguments = (ins
+ FP6TypeAttr:$type,
+ F32:$a,
+ F32:$b,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RN">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::SATFINITE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::FP6Type,
+ bool hasRelu);
+ bool isPacked();
+ llvm::Value* getCastedResult(llvm::Value* packedI16, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+ $dst = op.getCastedResult(packedI16, builder);
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..8540a88653973 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
@@ -160,6 +161,33 @@ LogicalResult CvtFloatToTF32Op::verify() {
return success();
}
+bool CvtToF6x2Op::isPacked() {
+ if (getDst().getType().isInteger(16)) {
+ return true;
+ }
+ return false;
+}
+
+llvm::Value *CvtToF6x2Op::getCastedResult(llvm::Value *packedI16,
+ llvm::IRBuilderBase &builder) {
+ if (isPacked()) {
+ return packedI16;
+ }
+ return builder.CreateBitCast(
+ packedI16, llvm::FixedVectorType::get(
+ llvm::Type::getInt8Ty(builder.getContext()), 2));
+}
+
+LogicalResult CvtToF6x2Op::verify() {
+ if (getRnd() != NVVM::FPRoundingMode::RN) {
+ return emitOpError("RN rounding mode required for CvtToF6x2Op.");
+ }
+ if (getSat() != NVVM::SaturationMode::SATFINITE) {
+ return emitOpError("SATFINITE saturation mode required for CvtToF6x2Op.");
+ }
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1300,6 +1328,23 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+#define CVT_TO_F6X2_ID_IMPL(type, relu) \
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn##relu##_satfinite \
+ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
+
+llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::FP6Type type,
+ bool hasRelu) {
+ switch (type) {
+ case NVVM::FP6Type::E2M3:
+ return CVT_TO_F6X2_ID_IMPL(e2m3x2, _relu);
+ case NVVM::FP6Type::E3M2:
+ return CVT_TO_F6X2_ID_IMPL(e3m2x2, _relu);
+ default:
+ break;
+ }
+ llvm_unreachable("Invalid FP6Type for CvtToF6x2Op");
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
new file mode 100644
index 0000000000000..2237e6faad52d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_fp6x2_packed
+llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
+ //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
+ //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp6x2_vector
+llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
+ //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+ %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+ //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+ %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+ llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index f87f11daeef54..5fcef1aa67139 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -176,3 +176,19 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
%0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
llvm.return
}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f6x2(%a : f32, %b : f32) {
+ // expected-error @below {{RN rounding mode required for CvtToF6x2Op.}}
+ %res = nvvm.cvt.to.f6x2 <e2m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rna>} : i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f6x2_packed(%a : f32, %b : f32) {
+ // expected-error @below {{SATFINITE saturation mode required for CvtToF6x2Op.}}
+ %res = nvvm.cvt.to.f6x2 <e3m2> %a, %b {sat = #nvvm.sat_mode<none>} : i16
+ llvm.return
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
f040873
to
f0c4fb5
Compare
The PTX ISA link in the commit message is pointing to shfl_sync. Please update it to refer to the |
f0c4fb5
to
329c76e
Compare
329c76e
to
5960c65
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latest revision LGTM.
5960c65
to
947c61a
Compare
This patch adds the `cvt.to.fp6x2` NVVM dialect Op for conversion into f6x2 types. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync
947c61a
to
9815a64
Compare
This patch adds the `cvt.to.fp6x2` NVVM dialect Op for conversions into the f6x2 types, `e2m3x2` and `e3m2x2`. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
This patch adds the `cvt.to.fp6x2` NVVM dialect Op for conversions into the f6x2 types, `e2m3x2` and `e3m2x2`. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
This patch adds the `cvt.to.fp6x2` NVVM dialect Op for conversions into the f6x2 types, `e2m3x2` and `e3m2x2`. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
This patch adds the `cvt.to.fp6x2` NVVM dialect Op for conversions into the f6x2 types, `e2m3x2` and `e3m2x2`. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
This patch adds the `cvt.to.fp6x2` NVVM dialect Op for conversions into the f6x2 types, `e2m3x2` and `e3m2x2`. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
This patch adds the
cvt.to.fp6x2
NVVM dialect Op for conversions into the f6x2 types,e2m3x2
ande3m2x2
.For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt