Skip to content

[MLIR][LLVM] Add llvm.experimental.constrained.fptrunc operation #86260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
let cppNamespace = "::mlir::LLVM::framePointerKind";
}

//===----------------------------------------------------------------------===//
// RoundingMode
//===----------------------------------------------------------------------===//

// These values must match llvm::RoundingMode ones.
// See llvm/include/llvm/ADT/FloatingPointMode.h.
def RoundTowardZero
: LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
def RoundNearestTiesToEven
: LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
def RoundTowardPositive
: LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
def RoundTowardNegative
: LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
def RoundNearestTiesToAway
: LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
def RoundDynamic
: LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
// Needed as llvm::RoundingMode defines this.
def RoundInvalid
: LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;

// RoundingModeAttr should not be used in operations definitions.
// Use ValidRoundingModeAttr instead.
def RoundingModeAttr : LLVM_EnumAttr<
"RoundingMode",
"::llvm::RoundingMode",
"LLVM Rounding Mode",
[RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
let cppNamespace = "::mlir::LLVM";
}

def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;

//===----------------------------------------------------------------------===//
// FPExceptionBehavior
//===----------------------------------------------------------------------===//

// These values must match llvm::fp::ExceptionBehavior ones.
// See llvm/include/llvm/IR/FPEnv.h.
def FPExceptionBehaviorIgnore
: LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
def FPExceptionBehaviorMayTrap
: LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
def FPExceptionBehaviorStrict
: LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;

def FPExceptionBehaviorAttr : LLVM_EnumAttr<
"FPExceptionBehavior",
"::llvm::fp::ExceptionBehavior",
"LLVM Exception Behavior",
[FPExceptionBehaviorIgnore, FPExceptionBehaviorMayTrap,
FPExceptionBehaviorStrict]> {
let cppNamespace = "::mlir::LLVM";
}

#endif // LLVMIR_ENUMS
67 changes: 67 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
];
}

def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
let description = [{
An interface for operations receiving an exception behavior attribute
controlling FP exception behavior.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns a FPExceptionBehavior attribute for the operation",
/*returnType=*/ "FPExceptionBehaviorAttr",
/*methodName=*/ "getFPExceptionBehaviorAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getFpExceptionBehaviorAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the FPExceptionBehaviorAttr
attribute for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getFPExceptionBehaviorAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "fpExceptionBehavior";
}]
>
];
}

def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
let description = [{
An interface for operations receiving a rounding mode attribute
controlling FP rounding mode.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns a RoundingMode attribute for the operation",
/*returnType=*/ "RoundingModeAttr",
/*methodName=*/ "getRoundingModeAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getRoundingmodeAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the RoundingModeAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getRoundingModeAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "roundingmode";
}]
>,
];
}

//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.
Expand Down
85 changes: 85 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,91 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
"qualified(type($ptr))";
}

// Constrained Floating-Point Intrinsics.

class LLVM_ConstrainedIntr<string mnem, int numArgs,
bit overloadedResult, list<int> overloadedOperands,
bit hasRoundingMode>
: LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
/*overloadedResults=*/
!cond(!gt(overloadedResult, 0) : [0],
true : []),
overloadedOperands,
/*traits=*/[Pure, DeclareOpInterfaceMethods<FPExceptionBehaviorOpInterface>]
# !cond(
!gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
true : []),
/*requiresFastmath=*/0,
/*immArgPositions=*/[],
/*immArgAttrNames=*/[]> {
dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
true : (ins)),
(ins FPExceptionBehaviorAttr:$fpExceptionBehavior));
let arguments = !con(regularArgs, attrArgs);
let llvmBuilder = [{
SmallVector<llvm::Value *> args =
moduleTranslation.lookupValues(opInst.getOperands());
SmallVector<llvm::Type *> overloadedTypes; }] #
!cond(!gt(overloadedResult, 0) : [{
// Take into account overloaded result type.
overloadedTypes.push_back($_resultType); }],
// No overloaded result type.
true : "") # [{
llvm::transform(ArrayRef<unsigned>}] # overloadedOperandsCpp # [{,
std::back_inserter(overloadedTypes),
[&args](unsigned index) { return args[index]->getType(); });
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *callee =
llvm::Intrinsic::getDeclaration(module,
llvm::Intrinsic::experimental_constrained_}] #
mnem # [{, overloadedTypes); }] #
!cond(!gt(hasRoundingMode, 0) : [{
// Get rounding mode using interface.
llvm::RoundingMode rounding =
moduleTranslation.translateRoundingMode($roundingmode); }],
true : [{
// No rounding mode.
std::optional<llvm::RoundingMode> rounding; }]) # [{
llvm::fp::ExceptionBehavior except =
moduleTranslation.translateFPExceptionBehavior($fpExceptionBehavior);
$res = builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
}];
let mlirBuilder = [{
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
llvmOperands.take_front( }] # numArgs # [{),
{}, {}, mlirOperands, mlirAttrs))) {
return failure();
}

FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
$_fpExceptionBehavior_attr($fpExceptionBehavior);
mlirAttrs.push_back(
$_builder.getNamedAttr(
$_qualCppClassName::getFPExceptionBehaviorAttrName(),
fpExceptionBehaviorAttr)); }] #
!cond(!gt(hasRoundingMode, 0) : [{
RoundingModeAttr roundingModeAttr = $_roundingMode_attr($roundingmode);
mlirAttrs.push_back(
$_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
roundingModeAttr));
}], true : "") # [{
$res = $_builder.create<$_qualCppClassName>($_location,
$_resultType, mlirOperands, mlirAttrs);
}];
}

def LLVM_ConstrainedFPTruncIntr
: LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1,
/*overloadedResult=*/1, /*overloadedOperands=*/[0],
/*hasRoundingMode=*/1> {
let assemblyFormat = [{
$arg_0 $roundingmode $fpExceptionBehavior attr-dict `:` type($arg_0) `to` type(results)
}];
}

// Intrinsics with multiple returns.

class LLVM_ArithWithOverflowOp<string mnem>
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<Trait> traits = []> :
// - $_float_attr - substituted by a call to a float attribute matcher;
// - $_var_attr - substituted by a call to a variable attribute matcher;
// - $_label_attr - substituted by a call to a label attribute matcher;
// - $_roundingMode_attr - substituted by a call to a rounding mode
// attribute matcher;
// - $_fpExceptionBehavior_attr - substituted by a call to a FP exception
// behavior attribute matcher;
// - $_resultType - substituted with the MLIR result type;
// - $_location - substituted with the MLIR location;
// - $_builder - substituted with the MLIR builder;
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ class ModuleImport {
/// Converts `value` to a label attribute. Asserts if the matching fails.
DILabelAttr matchLabelAttr(llvm::Value *value);

/// Converts `value` to a FP exception behavior attribute. Asserts if the
/// matching fails.
FPExceptionBehaviorAttr matchFPExceptionBehaviorAttr(llvm::Value *value);

/// Converts `value` to a rounding mode attribute. Asserts if the matching
/// fails.
RoundingModeAttr matchRoundingModeAttr(llvm::Value *value);

/// Converts `value` to an array of alias scopes or returns failure if the
/// conversion fails.
FailureOr<SmallVector<AliasScopeAttr>>
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ class ModuleTranslation {
/// Translates the given LLVM debug info metadata.
llvm::Metadata *translateDebugInfo(LLVM::DINodeAttr attr);

/// Translates the given LLVM rounding mode metadata.
llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding);

/// Translates the given LLVM FP exception behavior metadata.
llvm::fp::ExceptionBehavior
translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior);

/// Translates the contents of the given block to LLVM IR using this
/// translator. The LLVM IR basic block corresponding to the given block is
/// expected to exist in the mapping of this translator. Uses `builder` to
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,27 @@ DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
return debugImporter->translate(node);
}

FPExceptionBehaviorAttr
ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
auto *metadata = cast<llvm::MetadataAsValue>(value);
auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
std::optional<llvm::fp::ExceptionBehavior> optLLVM =
llvm::convertStrToExceptionBehavior(mdstr->getString());
assert(optLLVM && "Expecting FP exception behavior");
return builder.getAttr<FPExceptionBehaviorAttr>(
convertFPExceptionBehaviorFromLLVM(*optLLVM));
}

RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
auto *metadata = cast<llvm::MetadataAsValue>(value);
auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
std::optional<llvm::RoundingMode> optLLVM =
llvm::convertStrToRoundingMode(mdstr->getString());
assert(optLLVM && "Expecting rounding mode");
return builder.getAttr<RoundingModeAttr>(
convertRoundingModeFromLLVM(*optLLVM));
}

FailureOr<SmallVector<AliasScopeAttr>>
ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,6 +1721,16 @@ llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
return debugTranslation->translate(attr);
}

llvm::RoundingMode
ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
return convertRoundingModeToLLVM(rounding);
}

llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
LLVM::FPExceptionBehavior exceptionBehavior) {
return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
}

llvm::NamedMDNode *
ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
return llvmModule->getOrInsertNamedMetadata(name);
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -647,3 +647,18 @@ llvm.func @experimental_noalias_scope_decl() {
llvm.intr.experimental.noalias.scope.decl #alias_scope
llvm.return
}

// CHECK-LABEL: @experimental_constrained_fptrunc
llvm.func @experimental_constrained_fptrunc(%in: f64) {
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
%0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
%1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
%2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
%3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
llvm.return
}
19 changes: 19 additions & 0 deletions mlir/test/Target/LLVMIR/Import/intrinsic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
ret float %2
}

; CHECK-LABEL: experimental_constrained_fptrunc
define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
%1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
%2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
%3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
%4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
%5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
%6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
ret void
}

declare float @llvm.fmuladd.f32(float, float, float)
declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare float @llvm.fma.f32(float, float, float)
Expand Down Expand Up @@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
declare float @llvm.ssa.copy.f32(float returned)
declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
Loading