Skip to content

Commit 947c61a

Browse files
committed
[MLIR][NVVM] Add support for f6x2 conversion
This patch adds the `cvt.to.f6x2` NVVM dialect Op for conversion into f6x2 types. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent 8435de0 commit 947c61a

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,59 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
10661066
}];
10671067
}
10681068

1069+
// Destination fp6 formats accepted by the f6x2 conversion op. The third
// template argument of each case is the keyword used in the textual IR.
def CVTFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
def CVTFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;

// C++ enum `::mlir::NVVM::CVTFP6Type` enumerating the cases above.
def CVTFP6Type
    : I32EnumAttr<"CVTFP6Type", "NVVM CVTFP6Type kind",
                  [CVTFP6E2M3, CVTFP6E3M2]> {
  let cppNamespace = "::mlir::NVVM";
  // Only the enum itself is generated here; the attribute wrapper is the
  // dialect-level EnumAttr defined right below.
  let genSpecializedAttr = 0;
}

// Dialect attribute wrapping CVTFP6Type; spelled `<e2m3>` / `<e3m2>` in
// the assembly.
def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
  let assemblyFormat = "`<` $value `>`";
}
1080+
1081+
// Converts a pair of f32 values to the selected fp6 format and lowers to the
// matching llvm.nvvm.ff.to.<type>.rn[.relu].satfinite intrinsic.
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
  let summary = "Convert a pair of float inputs to f6x2";
  let description = [{
    This Op converts each of the given float inputs to the specified fp6 type.
    The result `dst` is represented either as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    with 2 MSB bits padded with zeros and the value converted from `b` is
    stored in the lower 8 bits of `dst` with 2 MSB bits padded with zeros.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  // Operands and attributes: the target fp6 format, the two f32 inputs, and
  // the optional `relu` flag (defaults to false; carried in the attr-dict).
  let arguments = (ins
    CVTFP6TypeAttr:$type,
    F32:$a,
    F32:$b,
    DefaultValuedAttr<BoolAttr, "false">:$relu);

  // Either the packed i16 or the same bits viewed as vector<2xi8>.
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);

  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP6Type,
                                              bool hasRelu);
  }];

  string llvmBuilder = [{
    llvm::Intrinsic::ID cvtId =
        NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
    llvm::Value *converted = createIntrinsicCall(builder, cvtId, {$a, $b});
    // The intrinsic call produces the packed i16 value; it is bitcast to
    // <2 x i8> only when the op was declared with the vector result type.
    if (!op.getDst().getType().isInteger(16)) {
      auto *vecTy = llvm::FixedVectorType::get(
          llvm::Type::getInt8Ty(builder.getContext()), 2);
      converted = builder.CreateBitCast(converted, vecTy);
    }
    $dst = converted;
  }];
}
1121+
10691122
//===----------------------------------------------------------------------===//
10701123
// NVVM MMA Ops
10711124
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/AsmParser/Parser.h"
3434
#include "llvm/IR/Attributes.h"
3535
#include "llvm/IR/Function.h"
36+
#include "llvm/IR/IRBuilder.h"
3637
#include "llvm/IR/IntrinsicsNVPTX.h"
3738
#include "llvm/IR/Type.h"
3839
#include "llvm/Support/Casting.h"
@@ -1290,6 +1291,22 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12901291
}
12911292
}
12921293

1294+
#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \
1295+
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
1296+
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1297+
1298+
llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1299+
bool hasRelu) {
1300+
switch (type) {
1301+
case NVVM::CVTFP6Type::E2M3:
1302+
return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
1303+
case NVVM::CVTFP6Type::E3M2:
1304+
return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
1305+
default:
1306+
llvm_unreachable("Invalid CVTFP6Type for CvtToF6x2Op");
1307+
}
1308+
}
1309+
12931310
llvm::Intrinsic::ID
12941311
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
12951312
LLVM::ModuleTranslation &mt,
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// Verifies the translation of nvvm.cvt.to.f6x2 to the
// llvm.nvvm.ff.to.<type>.rn[.relu].satfinite intrinsics for both result
// types (packed i16 and vector<2xi8>), with and without the `relu` attribute.

// CHECK-LABEL: @convert_float_to_fp6x2_packed
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
  llvm.return
}

// CHECK-LABEL: @convert_float_to_fp6x2_vector
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
  //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
  //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
  llvm.return
}

// The `relu` attribute is carried in the attribute dictionary and selects the
// `.relu` flavour of the intrinsic.

// CHECK-LABEL: @convert_float_to_fp6x2_packed_relu
llvm.func @convert_float_to_fp6x2_packed_relu(%srcA : f32, %srcB : f32) {
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB {relu = true} : i16
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB {relu = true} : i16
  llvm.return
}

// CHECK-LABEL: @convert_float_to_fp6x2_vector_relu
llvm.func @convert_float_to_fp6x2_vector_relu(%srcA : f32, %srcB : f32) {
  //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB {relu = true} : vector<2xi8>
  //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB {relu = true} : vector<2xi8>
  llvm.return
}
22+

0 commit comments

Comments
 (0)