Skip to content

Commit a28f657

Browse files
authored
[MLIR][NVVM] Add support for f6x2 conversion (llvm#136537)
This patch adds the `cvt.to.f6x2` NVVM dialect Op for conversions into the f6x2 types, `e2m3x2` and `e3m2x2`. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent 268f0d4 commit a28f657

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,59 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
10671067
}];
10681068
}
10691069

1070+
// Enum cases naming the fp6 destination formats. The third argument is the
// keyword used in the textual assembly (`e2m3` / `e3m2`).
def CVTFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
1071+
def CVTFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
1072+
1073+
// 32-bit integer enum grouping the two cases above; the C++ enum is emitted
// into the ::mlir::NVVM namespace.
def CVTFP6Type : I32EnumAttr<"CVTFP6Type", "NVVM CVTFP6Type kind",
1074+
[CVTFP6E2M3, CVTFP6E3M2]> {
1075+
// NOTE(review): presumably disabled so that the EnumAttr defined below is the
// only attribute class generated for this enum -- confirm.
let genSpecializedAttr = 0;
1076+
let cppNamespace = "::mlir::NVVM";
1077+
}
1078+
// Dialect attribute wrapping CVTFP6Type. With this format the value is written
// between angle brackets, e.g. `<e2m3>`.
def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
1079+
let assemblyFormat = "`<` $value `>`";
1080+
}
1081+
1082+
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
  let summary = "Convert a pair of float inputs to f6x2";
  let description = [{
    This Op converts each of the given float inputs to the specified fp6 type.
    The conversion always lowers to the `.rn.satfinite` form of the cvt
    instruction; no other rounding or saturation modes are exposed.
    The result `dst` is represented either as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    with 2 MSB bits padded with zeros and the value converted from `b` is
    stored in the lower 8 bits of `dst` with 2 MSB bits padded with zeros.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector. The vector is produced by bitcasting the packed
    i16 described above, so with the little-endian layout used by NVPTX
    element 0 holds the value converted from `b` and element 1 holds the value
    converted from `a`.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    Example:
    ```mlir
    %packed = nvvm.cvt.to.f6x2 <e2m3> %a, %b : i16
    %vec = nvvm.cvt.to.f6x2 <e3m2> %a, %b {relu = true} : vector<2xi8>
    ```

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP6TypeAttr:$type,
    F32:$a,
    F32:$b,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

  let extraClassDeclaration = [{
    /// Returns the conversion intrinsic for the given destination fp6 type,
    /// choosing the relu variant when `hasRelu` is true.
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP6Type,
                                              bool hasRelu);
  }];

  string llvmBuilder = [{
    auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
    // The intrinsic always yields the packed i16 form.
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
    if (op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      // vector<2xi8> result: reinterpret the packed i16 as <2 x i8>.
      $dst = builder.CreateBitCast(packedI16,
                      llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
1122+
10701123
//===----------------------------------------------------------------------===//
10711124
// NVVM MMA Ops
10721125
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,22 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12901290
}
12911291
}
12921292

1293+
#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \
1294+
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
1295+
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1296+
1297+
llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1298+
bool hasRelu) {
1299+
switch (type) {
1300+
case NVVM::CVTFP6Type::E2M3:
1301+
return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
1302+
case NVVM::CVTFP6Type::E3M2:
1303+
return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
1304+
default:
1305+
llvm_unreachable("Invalid CVTFP6Type for CvtToF6x2Op");
1306+
}
1307+
}
1308+
12931309
llvm::Intrinsic::ID
12941310
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
12951311
LLVM::ModuleTranslation &mt,
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// Packed form: the i16 returned by the intrinsic is used directly.
// CHECK-LABEL: @convert_float_to_fp6x2_packed
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
  llvm.return
}

// Vector form: the packed i16 is bitcast to <2 x i8>.
// CHECK-LABEL: @convert_float_to_fp6x2_vector
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
  //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
  //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
  llvm.return
}

// `relu = true` must select the `.relu` intrinsics (packed result).
// CHECK-LABEL: @convert_float_to_fp6x2_relu_packed
llvm.func @convert_float_to_fp6x2_relu_packed(%srcA : f32, %srcB : f32) {
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB {relu = true} : i16
  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB {relu = true} : i16
  llvm.return
}

// `relu = true` must select the `.relu` intrinsics (vector result).
// CHECK-LABEL: @convert_float_to_fp6x2_relu_vector
llvm.func @convert_float_to_fp6x2_relu_vector(%srcA : f32, %srcB : f32) {
  //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB {relu = true} : vector<2xi8>
  //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
  //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB {relu = true} : vector<2xi8>
  llvm.return
}
22+

0 commit comments

Comments
 (0)