Skip to content

[MLIR][LLVM] Add llvm.experimental.constrained.fptrunc operation #86260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 26, 2024

Conversation

victor-eds
Copy link
Contributor

Add operation mapping to the LLVM
llvm.experimental.constrained.fptrunc.* intrinsic.

The new operation implements the new
LLVM::ExceptionBehaviorOpInterface and
LLVM::RoundingModeOpInterface interfaces.

Add operation mapping to the LLVM
`llvm.experimental.constrained.fptrunc.*` intrinsic.

The new operation implements the new
`LLVM::ExceptionBehaviorOpInterface` and
`LLVM::RoundingModeOpInterface` interfaces.

Signed-off-by: Victor Perez <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-llvm

Author: Victor Perez (victor-eds)

Changes

Add operation mapping to the LLVM
llvm.experimental.constrained.fptrunc.* intrinsic.

The new operation implements the new
LLVM::ExceptionBehaviorOpInterface and
LLVM::RoundingModeOpInterface interfaces.


Patch is 20.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86260.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td (+57)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+67)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td (+41)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+5)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+4)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+86)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+23)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+19)
  • (modified) mlir/test/Target/LLVMIR/Import/intrinsic.ll (+19)
  • (modified) mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir (+31)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index a7b269eb41ee2e..19fc69dda16696 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
   let cppNamespace = "::mlir::LLVM::framePointerKind";
 }
 
+//===----------------------------------------------------------------------===//
+// RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::RoundingMode ones.
+// See llvm/include/llvm/ADT/FloatingPointMode.h.
+def RoundTowardZero
+    : LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
+def RoundNearestTiesToEven
+    : LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
+def RoundTowardPositive
+    : LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
+def RoundTowardNegative
+    : LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
+def RoundNearestTiesToAway
+    : LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
+def RoundDynamic
+    : LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
+// Needed as llvm::RoundingMode defines this.
+def RoundInvalid
+    : LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
+
+// RoundingModeAttr should not be used in operations definitions.
+// Use ValidRoundingModeAttr instead.
+def RoundingModeAttr : LLVM_EnumAttr<
+    "RoundingMode",
+    "::llvm::RoundingMode",
+    "LLVM Rounding Mode",
+    [RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
+     RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
+
+//===----------------------------------------------------------------------===//
+// ExceptionBehavior
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::fp::ExceptionBehavior ones.
+// See llvm/include/llvm/IR/FPEnv.h.
+def ExceptionBehaviorIgnore
+    : LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
+def ExceptionBehaviorMayTrap
+    : LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
+def ExceptionBehaviorStrict
+    : LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
+
+def ExceptionBehaviorAttr : LLVM_EnumAttr<
+    "ExceptionBehavior",
+    "::llvm::fp::ExceptionBehavior",
+    "LLVM Exception Behavior",
+    [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
+     ExceptionBehaviorStrict]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
 #endif // LLVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index e7a1da8ee560ef..ce91fbe1e2b24a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
   ];
 }
 
+def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+  let description = [{
+    An interface for operations receiving an exception behavior attribute
+    controlling FP exception behavior.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a ExceptionBehavior attribute for the operation",
+      /*returnType=*/  "ExceptionBehaviorAttr",
+      /*methodName=*/  "getExceptionBehaviorAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getExceptionbehaviorAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the ExceptionBehaviorAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getExceptionBehaviorAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "exceptionbehavior";
+      }]
+    >
+  ];
+}
+
+def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
+  let description = [{
+    An interface for operations receiving a rounding mode attribute
+    controlling FP rounding mode.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a RoundingMode attribute for the operation",
+      /*returnType=*/  "RoundingModeAttr",
+      /*methodName=*/  "getRoundingModeAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRoundingmodeAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the RoundingModeAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getRoundingModeAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "roundingmode";
+      }]
+    >,
+  ];
+}
 
 //===----------------------------------------------------------------------===//
 // LLVM dialect type interfaces.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b88f1186a44b49..6a2b9b90350e1a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -311,6 +311,47 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
       "qualified(type($ptr))";
 }
 
+// Constrained Floating-Point Intrinsics
+
+class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+    : LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
+                           /*overloadedResults=*/[0],
+                           /*overloadedOperands=*/[0],
+                           /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+                           # !cond(
+                               !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
+                               true : []),
+                           /*requiresFastmath=*/0,
+                           /*immArgPositions=*/[],
+                           /*immArgAttrNames=*/[]> {
+  dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
+  dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
+                            true : (ins)),
+                      (ins ExceptionBehaviorAttr:$exceptionbehavior));
+  let arguments = !con(regularArgs, attrArgs);
+  let llvmBuilder = [{
+    $res = LLVM::detail::createConstrainedIntrinsicCall(
+      builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
+       # mnem
+       # [{);
+  }];
+  let mlirBuilder = [{
+    auto op = moduleImport.translateConstrainedIntrinsic(
+      $_location, $_resultType, llvmOperands,
+      $_qualCppClassName::getOperationName());
+    if (!op)
+      return failure();
+    $res = op;
+  }];
+}
+
+def LLVM_ConstrainedFPTruncIntr
+    : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+  let assemblyFormat = [{
+    $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+  }];
+}
+
 // Intrinsics with multiple returns.
 
 class LLVM_ArithWithOverflowOp<string mnem>
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b49d2f539453e6..16f9994c126e06 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -232,6 +232,11 @@ class ModuleImport {
                             SmallVectorImpl<Value> &valuesOut,
                             SmallVectorImpl<NamedAttribute> &attrsOut);
 
+  /// Import constrained intrinsic.
+  Value translateConstrainedIntrinsic(Location loc, Type type,
+                                      ArrayRef<llvm::Value *> llvmOperands,
+                                      StringRef opName);
+
 private:
   /// Clears the accumulated state before processing a new region.
   void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index fb4392eb223c7f..458ed585167bc0 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -423,6 +423,10 @@ llvm::CallInst *createIntrinsicCall(
     ArrayRef<unsigned> immArgPositions,
     ArrayRef<StringLiteral> immArgAttrNames);
 
+llvm::CallInst *createConstrainedIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic);
+
 } // namespace detail
 
 } // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d63ea12ecd49b1..85a543c174f51d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,6 +1258,92 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
   return success();
 }
 
+static RoundingModeAttr metadataToRoundingMode(Builder &builder,
+                                               llvm::Metadata *metadata) {
+  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+  if (!mdstr)
+    return {};
+  std::optional<llvm::RoundingMode> optLLVM =
+      llvm::convertStrToRoundingMode(mdstr->getString());
+  if (!optLLVM)
+    return {};
+  return builder.getAttr<RoundingModeAttr>(
+      convertRoundingModeFromLLVM(*optLLVM));
+}
+
+static ExceptionBehaviorAttr
+metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+  if (!mdstr)
+    return {};
+  std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+      llvm::convertStrToExceptionBehavior(mdstr->getString());
+  if (!optLLVM)
+    return {};
+  return builder.getAttr<ExceptionBehaviorAttr>(
+      convertExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+static void
+splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
+                       SmallVectorImpl<llvm::Value *> &values,
+                       SmallVectorImpl<llvm::Metadata *> &metadata) {
+  for (llvm::Value *in : inputs) {
+    if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
+      metadata.push_back(mdval->getMetadata());
+    } else {
+      values.push_back(in);
+    }
+  }
+}
+
+Value ModuleImport::translateConstrainedIntrinsic(
+    Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
+    StringRef opName) {
+  // Split metadata values from regular ones.
+  SmallVector<llvm::Value *> values;
+  SmallVector<llvm::Metadata *> metadata;
+  splitMetadataAndValues(llvmOperands, values, metadata);
+
+  // Expect 1 or 2 metadata values.
+  assert((metadata.size() == 1 || metadata.size() == 2) &&
+         "Unexpected number of arguments");
+
+  SmallVector<Value> mlirOperands;
+  SmallVector<NamedAttribute> mlirAttrs;
+  if (failed(
+          convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
+    return {};
+  }
+
+  // Create operation as usual.
+  StringAttr opNameAttr = builder.getStringAttr(opName);
+  Operation *op =
+      builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
+
+  // Set exception behavior attribute.
+  auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
+  ExceptionBehaviorAttr attr =
+      metadataToExceptionBehavior(builder, metadata.back());
+  if (!attr)
+    return {};
+  op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
+
+  // If avaialbe, set rounding mode attribute.
+  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
+    assert(metadata.size() > 1 && "Unexpected number of arguments");
+    // rounding_mode present
+    RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
+    if (!attr)
+      return {};
+    roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
+  } else {
+    assert(metadata.size() == 1 && "Unexpected number of arguments");
+  }
+
+  return op->getResult(0);
+}
+
 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
   IntegerAttr integerAttr;
   FailureOr<Value> converted = convertValue(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f90495d407fdfe..9100b1a18c71c6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,6 +862,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   return builder.CreateCall(llvmIntr, args);
 }
 
+llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  SmallVector<llvm::Type *> overloadedTypes{
+      moduleTranslation.convertType(intrOp->getResult(0).getType()),
+      moduleTranslation.convertType(intrOp->getOperand(0).getType())};
+  llvm::Function *callee =
+      llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
+  SmallVector<llvm::Value *> args =
+      moduleTranslation.lookupValues(intrOp->getOperands());
+  std::optional<llvm::RoundingMode> rounding;
+  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
+    rounding = convertRoundingModeToLLVM(
+        roundingModeOp.getRoundingModeAttr().getValue());
+  }
+  llvm::fp::ExceptionBehavior except =
+      convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
+                                         .getExceptionBehaviorAttr()
+                                         .getValue());
+  return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
 LogicalResult ModuleTranslation::convertOperation(Operation &op,
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b157cf00141842..cc415e7b3662be 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -647,3 +647,22 @@ llvm.func @experimental_noalias_scope_decl() {
   llvm.intr.experimental.noalias.scope.decl #alias_scope
   llvm.return
 }
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+  %0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+  %1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+  %2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+  %3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+  %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
+  %tmp0 = llvm.fadd %0, %1 : f32
+  %tmp1 = llvm.fadd %2, %3 : f32
+  %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
+  %res = llvm.fadd %tmp2, %4 : f32
+  llvm.return %res : f32
+}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 1ec9005458c50b..85561839f31a70 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
   ret float %2
 }
 
+; CHECK-LABEL: experimental_constrained_fptrunc
+define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+  %1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+  %2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+  %3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+  %4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+  %5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
+  %6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+  ret void
+}
+
 declare float @llvm.fmuladd.f32(float, float, float)
 declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
 declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
 declare float @llvm.ssa.copy.f32(float returned)
 declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
 declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
+declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
+declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index fc2e0fd201a738..0013522582a727 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -964,6 +964,35 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
   llvm.return %0 : f32
 }
 
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.towardzero"
+  // CHECK: metadata !"fpexcept.ignore"
+  %0 = llvm.intr.experimental.constrained.fptrunc %s towardzero ignore : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.tonearest"
+  // CHECK: metadata !"fpexcept.maytrap"
+  %1 = llvm.intr.experimental.constrained.fptrunc %s tonearest maytrap : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.upward"
+  // CHECK: metadata !"fpexcept.strict"
+  %2 = llvm.intr.experimental.constrained.fptrunc %s upward strict : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.downward"
+  // CHECK: metadata !"fpexcept.ignore"
+  %3 = llvm.intr.experimental.constrained.fptrunc %s downward ignore : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.tonearestaway"
+  // CHECK: metadata !"fpexcept.ignore"
+  %4 = llvm.intr.experimental.constrained.fptrunc %s tonearestaway ignore : f64 to f32
+  // CHECK: call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(
+  // CHECK: metadata !"round.upward"
+  // CHECK: metadata !"fpexcept.strict"
+  %5 = llvm.intr.experimental.constrained.fptrunc %v upward strict : vector<4xf32> to vector<4xf16>
+  llvm.ret...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-mlir

Author: Victor Perez (victor-eds)

Changes

Add operation mapping to the LLVM
llvm.experimental.constrained.fptrunc.* intrinsic.

The new operation implements the new
LLVM::ExceptionBehaviorOpInterface and
LLVM::RoundingModeOpInterface interfaces.


Patch is 20.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86260.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td (+57)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+67)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td (+41)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+5)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+4)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+86)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+23)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+19)
  • (modified) mlir/test/Target/LLVMIR/Import/intrinsic.ll (+19)
  • (modified) mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir (+31)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index a7b269eb41ee2e..19fc69dda16696 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
   let cppNamespace = "::mlir::LLVM::framePointerKind";
 }
 
+//===----------------------------------------------------------------------===//
+// RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::RoundingMode ones.
+// See llvm/include/llvm/ADT/FloatingPointMode.h.
+def RoundTowardZero
+    : LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
+def RoundNearestTiesToEven
+    : LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
+def RoundTowardPositive
+    : LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
+def RoundTowardNegative
+    : LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
+def RoundNearestTiesToAway
+    : LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
+def RoundDynamic
+    : LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
+// Needed as llvm::RoundingMode defines this.
+def RoundInvalid
+    : LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
+
+// RoundingModeAttr should not be used in operations definitions.
+// Use ValidRoundingModeAttr instead.
+def RoundingModeAttr : LLVM_EnumAttr<
+    "RoundingMode",
+    "::llvm::RoundingMode",
+    "LLVM Rounding Mode",
+    [RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
+     RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
+
+//===----------------------------------------------------------------------===//
+// ExceptionBehavior
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::fp::ExceptionBehavior ones.
+// See llvm/include/llvm/IR/FPEnv.h.
+def ExceptionBehaviorIgnore
+    : LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
+def ExceptionBehaviorMayTrap
+    : LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
+def ExceptionBehaviorStrict
+    : LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
+
+def ExceptionBehaviorAttr : LLVM_EnumAttr<
+    "ExceptionBehavior",
+    "::llvm::fp::ExceptionBehavior",
+    "LLVM Exception Behavior",
+    [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
+     ExceptionBehaviorStrict]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
 #endif // LLVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index e7a1da8ee560ef..ce91fbe1e2b24a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
   ];
 }
 
+def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+  let description = [{
+    An interface for operations receiving an exception behavior attribute
+    controlling FP exception behavior.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a ExceptionBehavior attribute for the operation",
+      /*returnType=*/  "ExceptionBehaviorAttr",
+      /*methodName=*/  "getExceptionBehaviorAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getExceptionbehaviorAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the ExceptionBehaviorAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getExceptionBehaviorAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "exceptionbehavior";
+      }]
+    >
+  ];
+}
+
+def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
+  let description = [{
+    An interface for operations receiving a rounding mode attribute
+    controlling FP rounding mode.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a RoundingMode attribute for the operation",
+      /*returnType=*/  "RoundingModeAttr",
+      /*methodName=*/  "getRoundingModeAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRoundingmodeAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the RoundingModeAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getRoundingModeAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "roundingmode";
+      }]
+    >,
+  ];
+}
 
 //===----------------------------------------------------------------------===//
 // LLVM dialect type interfaces.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b88f1186a44b49..6a2b9b90350e1a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -311,6 +311,47 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
       "qualified(type($ptr))";
 }
 
+// Constrained Floating-Point Intrinsics
+
+class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+    : LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
+                           /*overloadedResults=*/[0],
+                           /*overloadedOperands=*/[0],
+                           /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+                           # !cond(
+                               !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
+                               true : []),
+                           /*requiresFastmath=*/0,
+                           /*immArgPositions=*/[],
+                           /*immArgAttrNames=*/[]> {
+  dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
+  dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
+                            true : (ins)),
+                      (ins ExceptionBehaviorAttr:$exceptionbehavior));
+  let arguments = !con(regularArgs, attrArgs);
+  let llvmBuilder = [{
+    $res = LLVM::detail::createConstrainedIntrinsicCall(
+      builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
+       # mnem
+       # [{);
+  }];
+  let mlirBuilder = [{
+    auto op = moduleImport.translateConstrainedIntrinsic(
+      $_location, $_resultType, llvmOperands,
+      $_qualCppClassName::getOperationName());
+    if (!op)
+      return failure();
+    $res = op;
+  }];
+}
+
+def LLVM_ConstrainedFPTruncIntr
+    : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+  let assemblyFormat = [{
+    $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+  }];
+}
+
 // Intrinsics with multiple returns.
 
 class LLVM_ArithWithOverflowOp<string mnem>
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b49d2f539453e6..16f9994c126e06 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -232,6 +232,11 @@ class ModuleImport {
                             SmallVectorImpl<Value> &valuesOut,
                             SmallVectorImpl<NamedAttribute> &attrsOut);
 
+  /// Import constrained intrinsic.
+  Value translateConstrainedIntrinsic(Location loc, Type type,
+                                      ArrayRef<llvm::Value *> llvmOperands,
+                                      StringRef opName);
+
 private:
   /// Clears the accumulated state before processing a new region.
   void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index fb4392eb223c7f..458ed585167bc0 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -423,6 +423,10 @@ llvm::CallInst *createIntrinsicCall(
     ArrayRef<unsigned> immArgPositions,
     ArrayRef<StringLiteral> immArgAttrNames);
 
+llvm::CallInst *createConstrainedIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic);
+
 } // namespace detail
 
 } // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d63ea12ecd49b1..85a543c174f51d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,6 +1258,92 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
   return success();
 }
 
+static RoundingModeAttr metadataToRoundingMode(Builder &builder,
+                                               llvm::Metadata *metadata) {
+  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+  if (!mdstr)
+    return {};
+  std::optional<llvm::RoundingMode> optLLVM =
+      llvm::convertStrToRoundingMode(mdstr->getString());
+  if (!optLLVM)
+    return {};
+  return builder.getAttr<RoundingModeAttr>(
+      convertRoundingModeFromLLVM(*optLLVM));
+}
+
+static ExceptionBehaviorAttr
+metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+  if (!mdstr)
+    return {};
+  std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+      llvm::convertStrToExceptionBehavior(mdstr->getString());
+  if (!optLLVM)
+    return {};
+  return builder.getAttr<ExceptionBehaviorAttr>(
+      convertExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+static void
+splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
+                       SmallVectorImpl<llvm::Value *> &values,
+                       SmallVectorImpl<llvm::Metadata *> &metadata) {
+  for (llvm::Value *in : inputs) {
+    if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
+      metadata.push_back(mdval->getMetadata());
+    } else {
+      values.push_back(in);
+    }
+  }
+}
+
+Value ModuleImport::translateConstrainedIntrinsic(
+    Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
+    StringRef opName) {
+  // Split metadata values from regular ones.
+  SmallVector<llvm::Value *> values;
+  SmallVector<llvm::Metadata *> metadata;
+  splitMetadataAndValues(llvmOperands, values, metadata);
+
+  // Expect 1 or 2 metadata values.
+  assert((metadata.size() == 1 || metadata.size() == 2) &&
+         "Unexpected number of arguments");
+
+  SmallVector<Value> mlirOperands;
+  SmallVector<NamedAttribute> mlirAttrs;
+  if (failed(
+          convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
+    return {};
+  }
+
+  // Create operation as usual.
+  StringAttr opNameAttr = builder.getStringAttr(opName);
+  Operation *op =
+      builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
+
+  // Set exception behavior attribute.
+  auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
+  ExceptionBehaviorAttr attr =
+      metadataToExceptionBehavior(builder, metadata.back());
+  if (!attr)
+    return {};
+  op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
+
+  // If avaialbe, set rounding mode attribute.
+  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
+    assert(metadata.size() > 1 && "Unexpected number of arguments");
+    // rounding_mode present
+    RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
+    if (!attr)
+      return {};
+    roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
+  } else {
+    assert(metadata.size() == 1 && "Unexpected number of arguments");
+  }
+
+  return op->getResult(0);
+}
+
 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
   IntegerAttr integerAttr;
   FailureOr<Value> converted = convertValue(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f90495d407fdfe..9100b1a18c71c6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,6 +862,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   return builder.CreateCall(llvmIntr, args);
 }
 
+llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  SmallVector<llvm::Type *> overloadedTypes{
+      moduleTranslation.convertType(intrOp->getResult(0).getType()),
+      moduleTranslation.convertType(intrOp->getOperand(0).getType())};
+  llvm::Function *callee =
+      llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
+  SmallVector<llvm::Value *> args =
+      moduleTranslation.lookupValues(intrOp->getOperands());
+  std::optional<llvm::RoundingMode> rounding;
+  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
+    rounding = convertRoundingModeToLLVM(
+        roundingModeOp.getRoundingModeAttr().getValue());
+  }
+  llvm::fp::ExceptionBehavior except =
+      convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
+                                         .getExceptionBehaviorAttr()
+                                         .getValue());
+  return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
 LogicalResult ModuleTranslation::convertOperation(Operation &op,
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b157cf00141842..cc415e7b3662be 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -647,3 +647,22 @@ llvm.func @experimental_noalias_scope_decl() {
   llvm.intr.experimental.noalias.scope.decl #alias_scope
   llvm.return
 }
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+  %0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+  %1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+  %2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+  %3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+  %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
+  %tmp0 = llvm.fadd %0, %1 : f32
+  %tmp1 = llvm.fadd %2, %3 : f32
+  %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
+  %res = llvm.fadd %tmp2, %4 : f32
+  llvm.return %res : f32
+}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 1ec9005458c50b..85561839f31a70 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
   ret float %2
 }
 
+; CHECK-LABEL: experimental_constrained_fptrunc
+define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+  %1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+  %2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+  %3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+  %4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+  %5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
+  %6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+  ret void
+}
+
 declare float @llvm.fmuladd.f32(float, float, float)
 declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
 declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
 declare float @llvm.ssa.copy.f32(float returned)
 declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
 declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
+declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
+declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index fc2e0fd201a738..0013522582a727 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -964,6 +964,35 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
   llvm.return %0 : f32
 }
 
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.towardzero"
+  // CHECK: metadata !"fpexcept.ignore"
+  %0 = llvm.intr.experimental.constrained.fptrunc %s towardzero ignore : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.tonearest"
+  // CHECK: metadata !"fpexcept.maytrap"
+  %1 = llvm.intr.experimental.constrained.fptrunc %s tonearest maytrap : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.upward"
+  // CHECK: metadata !"fpexcept.strict"
+  %2 = llvm.intr.experimental.constrained.fptrunc %s upward strict : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.downward"
+  // CHECK: metadata !"fpexcept.ignore"
+  %3 = llvm.intr.experimental.constrained.fptrunc %s downward ignore : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.tonearestaway"
+  // CHECK: metadata !"fpexcept.ignore"
+  %4 = llvm.intr.experimental.constrained.fptrunc %s tonearestaway ignore : f64 to f32
+  // CHECK: call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(
+  // CHECK: metadata !"round.upward"
+  // CHECK: metadata !"fpexcept.strict"
+  %5 = llvm.intr.experimental.constrained.fptrunc %v upward strict : vector<4xf32> to vector<4xf16>
+  llvm.ret...
[truncated]

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this!

I added some suggestions that are a bit closer to the solutions we took in other places. Specifically when importing intrinsic with metadata arguments. If the mechanism also works in your case, it would be great to follow the same idea.

Signed-off-by: Victor Perez <[email protected]>
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Mar 25, 2024
Copy link

✅ With the latest revision this PR passed the Python code formatter.

Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

@victor-eds victor-eds requested review from zero9178 and gysit March 25, 2024 15:14
@victor-eds
Copy link
Contributor Author

@zero9178 @gysit thanks for your comments! I hope I've addressed them all

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, so if @gysit is happy then so am I 🙂

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes!

LGTM modulo some minor comments.

Its a shame these intrinsics seem to exists with all kinds of argument number and rounding mode configurations. Otherwise, it may have been nice to factor out base classes unary / binary etc to make things a bit less complex...

@victor-eds
Copy link
Contributor Author

victor-eds commented Mar 26, 2024

Its a shame these intrinsics seem to exists with all kinds of argument number and rounding mode configurations. Otherwise, it may have been nice to factor out base classes unary / binary etc to make things a bit less complex...

I guess we could have some kind of argument for the base intrinsic class in a similar way to immediate values, but that'd require further modifications. I don't think it'd be 100 % automatic anyway...

@gysit
Copy link
Contributor

gysit commented Mar 26, 2024

I guess we could have some kind of argument for the base intrinsic class in a similar way to immediate values, but that'd require further modifications. I don't think it'd be 100 % automatic anyway...

A solution where the arguments are defined in a derived class being the actual intrinsic or some unary / binary etc base class could be nice. Especially since that may allow us to properly type the intrinsic operands etc. However, with tablegen it is really hard to guess how complex things are in the end. If you have a good idea you are very welcome to give it a shot.

Otherwise it is also ok to land and follow up if needed (e.g. after adding more intrinsics).

@victor-eds
Copy link
Contributor Author

I guess we could have some kind of argument for the base intrinsic class in a similar way to immediate values, but that'd require further modifications. I don't think it'd be 100 % automatic anyway...

A solution where the arguments are defined in a derived class being the actual intrinsic or some unary / binary etc base class could be nice. Especially since that may allow us to properly type the intrinsic operands etc. However, with tablegen it is really hard to guess how complex things are in the end. If you have a good idea you are very welcome to give it a shot.

Otherwise it is also ok to land and follow up if needed (e.g. after adding more intrinsics).

Yes, that's more or less what I was thinking of. I'm depending on this PR for further work in other dialects, so I'd rather get this merged and give this a shot in the future. I'll wait till CI passes. Thanks a lot!

llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Mar 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants