@@ -1067,6 +1067,59 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
1067
1067
}];
1068
1068
}
1069
1069
1070
// Cases of the fp6 destination format selector used by the f6x2 conversion op.
// NOTE(review): the names presumably encode exponent/mantissa bit counts
// (E2M3 = 2 exponent + 3 mantissa bits) -- confirm against the PTX ISA.
def CVTFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
def CVTFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;

// 32-bit integer enum listing the supported fp6 formats. The C++ enum is
// emitted into ::mlir::NVVM. The specialized attribute class is not generated
// here (genSpecializedAttr = 0); the attribute is declared separately.
def CVTFP6Type : I32EnumAttr<
    "CVTFP6Type",
    "NVVM CVTFP6Type kind",
    [
      CVTFP6E2M3,
      CVTFP6E3M2
    ]> {
  let cppNamespace = "::mlir::NVVM";
  let genSpecializedAttr = 0;
}
1078
// Dialect attribute wrapping the CVTFP6Type enum under the mnemonic
// `cvt_fp6_type`. Its custom assembly is the enum keyword enclosed in
// angle brackets.
def CVTFP6TypeAttr
    : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
  let assemblyFormat = "`<` $value `>`";
}
1081
+
1082
// Converts two f32 operands to the fp6 format selected by `type` and returns
// both converted values together, either as a single i16 or as a vector of
// two i8 elements.
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
  let summary = "Convert a pair of float inputs to f6x2";
  let description = [{
    This Op converts each of the given float inputs to the specified fp6 type.
    The result `dst` is represented either as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    with 2 MSB bits padded with zeros and the value converted from `b` is
    stored in the lower 8 bits of `dst` with 2 MSB bits padded with zeros.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  // Inputs: the destination fp6 format, the two f32 values to convert, and an
  // optional `relu` flag that defaults to false.
  let arguments = (ins
    CVTFP6TypeAttr:$type,
    F32:$a,
    F32:$b,
    DefaultValuedAttr<BoolAttr, "false">:$relu);

  // Output: either an i16 or a 2-element vector of i8.
  let results = (outs
    AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);

  // `relu` has no dedicated syntax; it travels through the attribute
  // dictionary.
  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

  let extraClassDeclaration = [{
    // Returns the LLVM intrinsic to emit for the given fp6 format and relu
    // setting. Only declared here; the definition lives outside this file.
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP6Type type,
                                              bool hasRelu);
  }];

  // Translation to LLVM IR: emit the selected intrinsic, then reinterpret its
  // result as a 2 x i8 vector unless the op result is declared as i16.
  string llvmBuilder = [{
    llvm::Intrinsic::ID cvtIntrinsic =
        NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
    llvm::Value *converted =
        createIntrinsicCall(builder, cvtIntrinsic, {$a, $b});
    // An i16 result uses the intrinsic value as-is; any other result type is
    // produced by bitcasting that value to a fixed vector of two i8.
    if (!op.getDst().getType().isInteger(16)) {
      llvm::Type *i8Ty = llvm::Type::getInt8Ty(builder.getContext());
      converted = builder.CreateBitCast(
          converted, llvm::FixedVectorType::get(i8Ty, 2));
    }
    $dst = converted;
  }];
}
1122
+
1070
1123
//===----------------------------------------------------------------------===//
1071
1124
// NVVM MMA Ops
1072
1125
//===----------------------------------------------------------------------===//
0 commit comments