-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][LLVM] Add llvm.experimental.constrained.fptrunc
operation
#86260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add operation mapping to the LLVM `llvm.experimental.constrained.fptrunc.*` intrinsic. The new operation implements the new `LLVM::ExceptionBehaviorOpInterface` and `LLVM::RoundingModeOpInterface` interfaces. Signed-off-by: Victor Perez <[email protected]>
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-llvm Author: Victor Perez (victor-eds) ChangesAdd operation mapping to the LLVM The new operation implements the new Patch is 20.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86260.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index a7b269eb41ee2e..19fc69dda16696 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
let cppNamespace = "::mlir::LLVM::framePointerKind";
}
+//===----------------------------------------------------------------------===//
+// RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::RoundingMode ones.
+// See llvm/include/llvm/ADT/FloatingPointMode.h.
+def RoundTowardZero
+ : LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
+def RoundNearestTiesToEven
+ : LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
+def RoundTowardPositive
+ : LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
+def RoundTowardNegative
+ : LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
+def RoundNearestTiesToAway
+ : LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
+def RoundDynamic
+ : LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
+// Needed as llvm::RoundingMode defines this.
+def RoundInvalid
+ : LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
+
+// RoundingModeAttr should not be used in operations definitions.
+// Use ValidRoundingModeAttr instead.
+def RoundingModeAttr : LLVM_EnumAttr<
+ "RoundingMode",
+ "::llvm::RoundingMode",
+ "LLVM Rounding Mode",
+ [RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
+ RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
+def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
+
+//===----------------------------------------------------------------------===//
+// ExceptionBehavior
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::fp::ExceptionBehavior ones.
+// See llvm/include/llvm/IR/FPEnv.h.
+def ExceptionBehaviorIgnore
+ : LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
+def ExceptionBehaviorMayTrap
+ : LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
+def ExceptionBehaviorStrict
+ : LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
+
+def ExceptionBehaviorAttr : LLVM_EnumAttr<
+ "ExceptionBehavior",
+ "::llvm::fp::ExceptionBehavior",
+ "LLVM Exception Behavior",
+ [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
+ ExceptionBehaviorStrict]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
#endif // LLVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index e7a1da8ee560ef..ce91fbe1e2b24a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
];
}
+def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+ let description = [{
+ An interface for operations receiving an exception behavior attribute
+ controlling FP exception behavior.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a ExceptionBehavior attribute for the operation",
+ /*returnType=*/ "ExceptionBehaviorAttr",
+ /*methodName=*/ "getExceptionBehaviorAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getExceptionbehaviorAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the ExceptionBehaviorAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getExceptionBehaviorAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "exceptionbehavior";
+ }]
+ >
+ ];
+}
+
+def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
+ let description = [{
+ An interface for operations receiving a rounding mode attribute
+ controlling FP rounding mode.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a RoundingMode attribute for the operation",
+ /*returnType=*/ "RoundingModeAttr",
+ /*methodName=*/ "getRoundingModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getRoundingmodeAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the RoundingModeAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getRoundingModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "roundingmode";
+ }]
+ >,
+ ];
+}
//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b88f1186a44b49..6a2b9b90350e1a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -311,6 +311,47 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
"qualified(type($ptr))";
}
+// Constrained Floating-Point Intrinsics
+
+class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+ : LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
+ /*overloadedResults=*/[0],
+ /*overloadedOperands=*/[0],
+ /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+ # !cond(
+ !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
+ true : []),
+ /*requiresFastmath=*/0,
+ /*immArgPositions=*/[],
+ /*immArgAttrNames=*/[]> {
+ dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
+ dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
+ true : (ins)),
+ (ins ExceptionBehaviorAttr:$exceptionbehavior));
+ let arguments = !con(regularArgs, attrArgs);
+ let llvmBuilder = [{
+ $res = LLVM::detail::createConstrainedIntrinsicCall(
+ builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
+ # mnem
+ # [{);
+ }];
+ let mlirBuilder = [{
+ auto op = moduleImport.translateConstrainedIntrinsic(
+ $_location, $_resultType, llvmOperands,
+ $_qualCppClassName::getOperationName());
+ if (!op)
+ return failure();
+ $res = op;
+ }];
+}
+
+def LLVM_ConstrainedFPTruncIntr
+ : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+ let assemblyFormat = [{
+ $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+ }];
+}
+
// Intrinsics with multiple returns.
class LLVM_ArithWithOverflowOp<string mnem>
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b49d2f539453e6..16f9994c126e06 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -232,6 +232,11 @@ class ModuleImport {
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);
+ /// Import constrained intrinsic.
+ Value translateConstrainedIntrinsic(Location loc, Type type,
+ ArrayRef<llvm::Value *> llvmOperands,
+ StringRef opName);
+
private:
/// Clears the accumulated state before processing a new region.
void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index fb4392eb223c7f..458ed585167bc0 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -423,6 +423,10 @@ llvm::CallInst *createIntrinsicCall(
ArrayRef<unsigned> immArgPositions,
ArrayRef<StringLiteral> immArgAttrNames);
+llvm::CallInst *createConstrainedIntrinsicCall(
+ llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+ Operation *intrOp, llvm::Intrinsic::ID intrinsic);
+
} // namespace detail
} // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d63ea12ecd49b1..85a543c174f51d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,6 +1258,92 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
return success();
}
+static RoundingModeAttr metadataToRoundingMode(Builder &builder,
+ llvm::Metadata *metadata) {
+ auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+ if (!mdstr)
+ return {};
+ std::optional<llvm::RoundingMode> optLLVM =
+ llvm::convertStrToRoundingMode(mdstr->getString());
+ if (!optLLVM)
+ return {};
+ return builder.getAttr<RoundingModeAttr>(
+ convertRoundingModeFromLLVM(*optLLVM));
+}
+
+static ExceptionBehaviorAttr
+metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+ auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+ if (!mdstr)
+ return {};
+ std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+ llvm::convertStrToExceptionBehavior(mdstr->getString());
+ if (!optLLVM)
+ return {};
+ return builder.getAttr<ExceptionBehaviorAttr>(
+ convertExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+static void
+splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
+ SmallVectorImpl<llvm::Value *> &values,
+ SmallVectorImpl<llvm::Metadata *> &metadata) {
+ for (llvm::Value *in : inputs) {
+ if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
+ metadata.push_back(mdval->getMetadata());
+ } else {
+ values.push_back(in);
+ }
+ }
+}
+
+Value ModuleImport::translateConstrainedIntrinsic(
+ Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
+ StringRef opName) {
+ // Split metadata values from regular ones.
+ SmallVector<llvm::Value *> values;
+ SmallVector<llvm::Metadata *> metadata;
+ splitMetadataAndValues(llvmOperands, values, metadata);
+
+ // Expect 1 or 2 metadata values.
+ assert((metadata.size() == 1 || metadata.size() == 2) &&
+ "Unexpected number of arguments");
+
+ SmallVector<Value> mlirOperands;
+ SmallVector<NamedAttribute> mlirAttrs;
+ if (failed(
+ convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
+ return {};
+ }
+
+ // Create operation as usual.
+ StringAttr opNameAttr = builder.getStringAttr(opName);
+ Operation *op =
+ builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
+
+ // Set exception behavior attribute.
+ auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
+ ExceptionBehaviorAttr attr =
+ metadataToExceptionBehavior(builder, metadata.back());
+ if (!attr)
+ return {};
+ op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
+
+ // If avaialbe, set rounding mode attribute.
+ if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
+ assert(metadata.size() > 1 && "Unexpected number of arguments");
+ // rounding_mode present
+ RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
+ if (!attr)
+ return {};
+ roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
+ } else {
+ assert(metadata.size() == 1 && "Unexpected number of arguments");
+ }
+
+ return op->getResult(0);
+}
+
IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
IntegerAttr integerAttr;
FailureOr<Value> converted = convertValue(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f90495d407fdfe..9100b1a18c71c6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,6 +862,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(llvmIntr, args);
}
+llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
+ llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+ Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+ SmallVector<llvm::Type *> overloadedTypes{
+ moduleTranslation.convertType(intrOp->getResult(0).getType()),
+ moduleTranslation.convertType(intrOp->getOperand(0).getType())};
+ llvm::Function *callee =
+ llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
+ SmallVector<llvm::Value *> args =
+ moduleTranslation.lookupValues(intrOp->getOperands());
+ std::optional<llvm::RoundingMode> rounding;
+ if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
+ rounding = convertRoundingModeToLLVM(
+ roundingModeOp.getRoundingModeAttr().getValue());
+ }
+ llvm::fp::ExceptionBehavior except =
+ convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
+ .getExceptionBehaviorAttr()
+ .getValue());
+ return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
+}
+
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
LogicalResult ModuleTranslation::convertOperation(Operation &op,
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b157cf00141842..cc415e7b3662be 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -647,3 +647,22 @@ llvm.func @experimental_noalias_scope_decl() {
llvm.intr.experimental.noalias.scope.decl #alias_scope
llvm.return
}
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+ %0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+ %1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+ %2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+ %3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+ %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
+ %tmp0 = llvm.fadd %0, %1 : f32
+ %tmp1 = llvm.fadd %2, %3 : f32
+ %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
+ %res = llvm.fadd %tmp2, %4 : f32
+ llvm.return %res : f32
+}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 1ec9005458c50b..85561839f31a70 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
ret float %2
}
+; CHECK-LABEL: experimental_constrained_fptrunc
+define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+ %1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+ %2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+ %3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+ %4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+ %5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
+ %6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+ ret void
+}
+
declare float @llvm.fmuladd.f32(float, float, float)
declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
declare float @llvm.ssa.copy.f32(float returned)
declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
+declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
+declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index fc2e0fd201a738..0013522582a727 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -964,6 +964,35 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
llvm.return %0 : f32
}
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.towardzero"
+ // CHECK: metadata !"fpexcept.ignore"
+ %0 = llvm.intr.experimental.constrained.fptrunc %s towardzero ignore : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.tonearest"
+ // CHECK: metadata !"fpexcept.maytrap"
+ %1 = llvm.intr.experimental.constrained.fptrunc %s tonearest maytrap : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.upward"
+ // CHECK: metadata !"fpexcept.strict"
+ %2 = llvm.intr.experimental.constrained.fptrunc %s upward strict : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.downward"
+ // CHECK: metadata !"fpexcept.ignore"
+ %3 = llvm.intr.experimental.constrained.fptrunc %s downward ignore : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.tonearestaway"
+ // CHECK: metadata !"fpexcept.ignore"
+ %4 = llvm.intr.experimental.constrained.fptrunc %s tonearestaway ignore : f64 to f32
+ // CHECK: call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(
+ // CHECK: metadata !"round.upward"
+ // CHECK: metadata !"fpexcept.strict"
+ %5 = llvm.intr.experimental.constrained.fptrunc %v upward strict : vector<4xf32> to vector<4xf16>
+ llvm.ret...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Victor Perez (victor-eds) ChangesAdd operation mapping to the LLVM The new operation implements the new Patch is 20.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86260.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index a7b269eb41ee2e..19fc69dda16696 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
let cppNamespace = "::mlir::LLVM::framePointerKind";
}
+//===----------------------------------------------------------------------===//
+// RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::RoundingMode ones.
+// See llvm/include/llvm/ADT/FloatingPointMode.h.
+def RoundTowardZero
+ : LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
+def RoundNearestTiesToEven
+ : LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
+def RoundTowardPositive
+ : LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
+def RoundTowardNegative
+ : LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
+def RoundNearestTiesToAway
+ : LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
+def RoundDynamic
+ : LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
+// Needed as llvm::RoundingMode defines this.
+def RoundInvalid
+ : LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
+
+// RoundingModeAttr should not be used in operations definitions.
+// Use ValidRoundingModeAttr instead.
+def RoundingModeAttr : LLVM_EnumAttr<
+ "RoundingMode",
+ "::llvm::RoundingMode",
+ "LLVM Rounding Mode",
+ [RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
+ RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
+def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
+
+//===----------------------------------------------------------------------===//
+// ExceptionBehavior
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::fp::ExceptionBehavior ones.
+// See llvm/include/llvm/IR/FPEnv.h.
+def ExceptionBehaviorIgnore
+ : LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
+def ExceptionBehaviorMayTrap
+ : LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
+def ExceptionBehaviorStrict
+ : LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
+
+def ExceptionBehaviorAttr : LLVM_EnumAttr<
+ "ExceptionBehavior",
+ "::llvm::fp::ExceptionBehavior",
+ "LLVM Exception Behavior",
+ [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
+ ExceptionBehaviorStrict]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
#endif // LLVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index e7a1da8ee560ef..ce91fbe1e2b24a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
];
}
+def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+ let description = [{
+ An interface for operations receiving an exception behavior attribute
+ controlling FP exception behavior.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a ExceptionBehavior attribute for the operation",
+ /*returnType=*/ "ExceptionBehaviorAttr",
+ /*methodName=*/ "getExceptionBehaviorAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getExceptionbehaviorAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the ExceptionBehaviorAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getExceptionBehaviorAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "exceptionbehavior";
+ }]
+ >
+ ];
+}
+
+def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
+ let description = [{
+ An interface for operations receiving a rounding mode attribute
+ controlling FP rounding mode.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a RoundingMode attribute for the operation",
+ /*returnType=*/ "RoundingModeAttr",
+ /*methodName=*/ "getRoundingModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getRoundingmodeAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the RoundingModeAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getRoundingModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "roundingmode";
+ }]
+ >,
+ ];
+}
//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b88f1186a44b49..6a2b9b90350e1a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -311,6 +311,47 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
"qualified(type($ptr))";
}
+// Constrained Floating-Point Intrinsics
+
+class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+ : LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
+ /*overloadedResults=*/[0],
+ /*overloadedOperands=*/[0],
+ /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+ # !cond(
+ !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
+ true : []),
+ /*requiresFastmath=*/0,
+ /*immArgPositions=*/[],
+ /*immArgAttrNames=*/[]> {
+ dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
+ dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
+ true : (ins)),
+ (ins ExceptionBehaviorAttr:$exceptionbehavior));
+ let arguments = !con(regularArgs, attrArgs);
+ let llvmBuilder = [{
+ $res = LLVM::detail::createConstrainedIntrinsicCall(
+ builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
+ # mnem
+ # [{);
+ }];
+ let mlirBuilder = [{
+ auto op = moduleImport.translateConstrainedIntrinsic(
+ $_location, $_resultType, llvmOperands,
+ $_qualCppClassName::getOperationName());
+ if (!op)
+ return failure();
+ $res = op;
+ }];
+}
+
+def LLVM_ConstrainedFPTruncIntr
+ : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+ let assemblyFormat = [{
+ $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+ }];
+}
+
// Intrinsics with multiple returns.
class LLVM_ArithWithOverflowOp<string mnem>
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b49d2f539453e6..16f9994c126e06 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -232,6 +232,11 @@ class ModuleImport {
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);
+ /// Import constrained intrinsic.
+ Value translateConstrainedIntrinsic(Location loc, Type type,
+ ArrayRef<llvm::Value *> llvmOperands,
+ StringRef opName);
+
private:
/// Clears the accumulated state before processing a new region.
void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index fb4392eb223c7f..458ed585167bc0 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -423,6 +423,10 @@ llvm::CallInst *createIntrinsicCall(
ArrayRef<unsigned> immArgPositions,
ArrayRef<StringLiteral> immArgAttrNames);
+llvm::CallInst *createConstrainedIntrinsicCall(
+ llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+ Operation *intrOp, llvm::Intrinsic::ID intrinsic);
+
} // namespace detail
} // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d63ea12ecd49b1..85a543c174f51d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,6 +1258,92 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
return success();
}
+static RoundingModeAttr metadataToRoundingMode(Builder &builder,
+ llvm::Metadata *metadata) {
+ auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+ if (!mdstr)
+ return {};
+ std::optional<llvm::RoundingMode> optLLVM =
+ llvm::convertStrToRoundingMode(mdstr->getString());
+ if (!optLLVM)
+ return {};
+ return builder.getAttr<RoundingModeAttr>(
+ convertRoundingModeFromLLVM(*optLLVM));
+}
+
+static ExceptionBehaviorAttr
+metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+ auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+ if (!mdstr)
+ return {};
+ std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+ llvm::convertStrToExceptionBehavior(mdstr->getString());
+ if (!optLLVM)
+ return {};
+ return builder.getAttr<ExceptionBehaviorAttr>(
+ convertExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+static void
+splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
+ SmallVectorImpl<llvm::Value *> &values,
+ SmallVectorImpl<llvm::Metadata *> &metadata) {
+ for (llvm::Value *in : inputs) {
+ if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
+ metadata.push_back(mdval->getMetadata());
+ } else {
+ values.push_back(in);
+ }
+ }
+}
+
+Value ModuleImport::translateConstrainedIntrinsic(
+ Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
+ StringRef opName) {
+ // Split metadata values from regular ones.
+ SmallVector<llvm::Value *> values;
+ SmallVector<llvm::Metadata *> metadata;
+ splitMetadataAndValues(llvmOperands, values, metadata);
+
+ // Expect 1 or 2 metadata values.
+ assert((metadata.size() == 1 || metadata.size() == 2) &&
+ "Unexpected number of arguments");
+
+ SmallVector<Value> mlirOperands;
+ SmallVector<NamedAttribute> mlirAttrs;
+ if (failed(
+ convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
+ return {};
+ }
+
+ // Create operation as usual.
+ StringAttr opNameAttr = builder.getStringAttr(opName);
+ Operation *op =
+ builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
+
+ // Set exception behavior attribute.
+ auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
+ ExceptionBehaviorAttr attr =
+ metadataToExceptionBehavior(builder, metadata.back());
+ if (!attr)
+ return {};
+ op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
+
+ // If avaialbe, set rounding mode attribute.
+ if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
+ assert(metadata.size() > 1 && "Unexpected number of arguments");
+ // rounding_mode present
+ RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
+ if (!attr)
+ return {};
+ roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
+ } else {
+ assert(metadata.size() == 1 && "Unexpected number of arguments");
+ }
+
+ return op->getResult(0);
+}
+
IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
IntegerAttr integerAttr;
FailureOr<Value> converted = convertValue(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f90495d407fdfe..9100b1a18c71c6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,6 +862,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(llvmIntr, args);
}
+llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
+ llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+ Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+ SmallVector<llvm::Type *> overloadedTypes{
+ moduleTranslation.convertType(intrOp->getResult(0).getType()),
+ moduleTranslation.convertType(intrOp->getOperand(0).getType())};
+ llvm::Function *callee =
+ llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
+ SmallVector<llvm::Value *> args =
+ moduleTranslation.lookupValues(intrOp->getOperands());
+ std::optional<llvm::RoundingMode> rounding;
+ if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
+ rounding = convertRoundingModeToLLVM(
+ roundingModeOp.getRoundingModeAttr().getValue());
+ }
+ llvm::fp::ExceptionBehavior except =
+ convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
+ .getExceptionBehaviorAttr()
+ .getValue());
+ return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
+}
+
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
LogicalResult ModuleTranslation::convertOperation(Operation &op,
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b157cf00141842..cc415e7b3662be 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -647,3 +647,22 @@ llvm.func @experimental_noalias_scope_decl() {
llvm.intr.experimental.noalias.scope.decl #alias_scope
llvm.return
}
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+ %0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+ %1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+ %2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+ %3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+ %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
+ %tmp0 = llvm.fadd %0, %1 : f32
+ %tmp1 = llvm.fadd %2, %3 : f32
+ %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
+ %res = llvm.fadd %tmp2, %4 : f32
+ llvm.return %res : f32
+}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 1ec9005458c50b..85561839f31a70 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
ret float %2
}
+; CHECK-LABEL: experimental_constrained_fptrunc
+define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+ %1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+ %2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+ %3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+ %4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+ %5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
+ %6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+ ret void
+}
+
declare float @llvm.fmuladd.f32(float, float, float)
declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
declare float @llvm.ssa.copy.f32(float returned)
declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
+declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
+declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index fc2e0fd201a738..0013522582a727 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -964,6 +964,35 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
llvm.return %0 : f32
}
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.towardzero"
+ // CHECK: metadata !"fpexcept.ignore"
+ %0 = llvm.intr.experimental.constrained.fptrunc %s towardzero ignore : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.tonearest"
+ // CHECK: metadata !"fpexcept.maytrap"
+ %1 = llvm.intr.experimental.constrained.fptrunc %s tonearest maytrap : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.upward"
+ // CHECK: metadata !"fpexcept.strict"
+ %2 = llvm.intr.experimental.constrained.fptrunc %s upward strict : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.downward"
+ // CHECK: metadata !"fpexcept.ignore"
+ %3 = llvm.intr.experimental.constrained.fptrunc %s downward ignore : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.tonearestaway"
+ // CHECK: metadata !"fpexcept.ignore"
+ %4 = llvm.intr.experimental.constrained.fptrunc %s tonearestaway ignore : f64 to f32
+ // CHECK: call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(
+ // CHECK: metadata !"round.upward"
+ // CHECK: metadata !"fpexcept.strict"
+ %5 = llvm.intr.experimental.constrained.fptrunc %v upward strict : vector<4xf32> to vector<4xf16>
+ llvm.ret...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this!
I added some suggestions that are a bit closer to the solutions we took in other places. Specifically when importing intrinsic with metadata arguments. If the mechanism also works in your case, it would be great to follow the same idea.
Signed-off-by: Victor Perez <[email protected]>
✅ With the latest revision this PR passed the Python code formatter. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, so if @gysit is happy then so am I 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes!
LGTM modulo some minor comments.
Its a shame these intrinsics seem to exists with all kinds of argument number and rounding mode configurations. Otherwise, it may have been nice to factor out base classes unary / binary etc to make things a bit less complex...
I guess we could have some kind of argument for the base intrinsic class in a similar way to immediate values, but that'd require further modifications. I don't think it'd be 100 % automatic anyway... |
A solution where the arguments are defined in a derived class being the actual intrinsic or some unary / binary etc base class could be nice. Especially since that may allow us to properly type the intrinsic operands etc. However, with tablegen it is really hard to guess how complex things are in the end. If you have a good idea you are very welcome to give it a shot. Otherwise it is also ok to land and follow up if needed (e.g. after adding more intrinsics). |
Yes, that's more or less what I was thinking of. I'm depending on this PR for further work in other dialects, so I'd rather get this merged and give this a shot in the future. I'll wait till CI passes. Thanks a lot! |
…ration (#129054) Ref: llvm/llvm-project#86260
Add operation mapping to the LLVM
llvm.experimental.constrained.fptrunc.*
intrinsic.The new operation implements the new
LLVM::ExceptionBehaviorOpInterface
andLLVM::RoundingModeOpInterface
interfaces.