Skip to content

Commit 77cbc9b

Browse files
authored
[MLIR][LLVM] Add llvm.experimental.constrained.fptrunc operation (#86260)
Add operation mapping to the LLVM `llvm.experimental.constrained.fptrunc.*` intrinsic. The new operation implements the new `LLVM::FPExceptionBehaviorOpInterface` and `LLVM::RoundingModeOpInterface` interfaces. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 256343a commit 77cbc9b

File tree

12 files changed

+328
-0
lines changed

12 files changed

+328
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
705705
let cppNamespace = "::mlir::LLVM::framePointerKind";
706706
}
707707

708+
//===----------------------------------------------------------------------===//
709+
// RoundingMode
710+
//===----------------------------------------------------------------------===//
711+
712+
// These values must match llvm::RoundingMode ones.
713+
// See llvm/include/llvm/ADT/FloatingPointMode.h.
714+
def RoundTowardZero
715+
: LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
716+
def RoundNearestTiesToEven
717+
: LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
718+
def RoundTowardPositive
719+
: LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
720+
def RoundTowardNegative
721+
: LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
722+
def RoundNearestTiesToAway
723+
: LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
724+
def RoundDynamic
725+
: LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
726+
// Needed as llvm::RoundingMode defines this.
727+
def RoundInvalid
728+
: LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
729+
730+
// RoundingModeAttr should not be used in operations definitions.
731+
// Use ValidRoundingModeAttr instead.
732+
def RoundingModeAttr : LLVM_EnumAttr<
733+
"RoundingMode",
734+
"::llvm::RoundingMode",
735+
"LLVM Rounding Mode",
736+
[RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
737+
RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
738+
let cppNamespace = "::mlir::LLVM";
739+
}
740+
741+
def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
742+
743+
//===----------------------------------------------------------------------===//
744+
// FPExceptionBehavior
745+
//===----------------------------------------------------------------------===//
746+
747+
// These values must match llvm::fp::ExceptionBehavior ones.
748+
// See llvm/include/llvm/IR/FPEnv.h.
749+
def FPExceptionBehaviorIgnore
750+
: LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
751+
def FPExceptionBehaviorMayTrap
752+
: LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
753+
def FPExceptionBehaviorStrict
754+
: LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
755+
756+
def FPExceptionBehaviorAttr : LLVM_EnumAttr<
757+
"FPExceptionBehavior",
758+
"::llvm::fp::ExceptionBehavior",
759+
"LLVM Exception Behavior",
760+
[FPExceptionBehaviorIgnore, FPExceptionBehaviorMayTrap,
761+
FPExceptionBehaviorStrict]> {
762+
let cppNamespace = "::mlir::LLVM";
763+
}
764+
708765
#endif // LLVMIR_ENUMS

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
290290
];
291291
}
292292

293+
def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
294+
let description = [{
295+
An interface for operations receiving an exception behavior attribute
296+
controlling FP exception behavior.
297+
}];
298+
299+
let cppNamespace = "::mlir::LLVM";
300+
301+
let methods = [
302+
InterfaceMethod<
303+
/*desc=*/ "Returns a FPExceptionBehavior attribute for the operation",
304+
/*returnType=*/ "FPExceptionBehaviorAttr",
305+
/*methodName=*/ "getFPExceptionBehaviorAttr",
306+
/*args=*/ (ins),
307+
/*methodBody=*/ [{}],
308+
/*defaultImpl=*/ [{
309+
auto op = cast<ConcreteOp>(this->getOperation());
310+
return op.getFpExceptionBehaviorAttr();
311+
}]
312+
>,
313+
StaticInterfaceMethod<
314+
/*desc=*/ [{Returns the name of the FPExceptionBehaviorAttr
315+
attribute for the operation}],
316+
/*returnType=*/ "StringRef",
317+
/*methodName=*/ "getFPExceptionBehaviorAttrName",
318+
/*args=*/ (ins),
319+
/*methodBody=*/ [{}],
320+
/*defaultImpl=*/ [{
321+
return "fpExceptionBehavior";
322+
}]
323+
>
324+
];
325+
}
326+
327+
def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
328+
let description = [{
329+
An interface for operations receiving a rounding mode attribute
330+
controlling FP rounding mode.
331+
}];
332+
333+
let cppNamespace = "::mlir::LLVM";
334+
335+
let methods = [
336+
InterfaceMethod<
337+
/*desc=*/ "Returns a RoundingMode attribute for the operation",
338+
/*returnType=*/ "RoundingModeAttr",
339+
/*methodName=*/ "getRoundingModeAttr",
340+
/*args=*/ (ins),
341+
/*methodBody=*/ [{}],
342+
/*defaultImpl=*/ [{
343+
auto op = cast<ConcreteOp>(this->getOperation());
344+
return op.getRoundingmodeAttr();
345+
}]
346+
>,
347+
StaticInterfaceMethod<
348+
/*desc=*/ [{Returns the name of the RoundingModeAttr attribute
349+
for the operation}],
350+
/*returnType=*/ "StringRef",
351+
/*methodName=*/ "getRoundingModeAttrName",
352+
/*args=*/ (ins),
353+
/*methodBody=*/ [{}],
354+
/*defaultImpl=*/ [{
355+
return "roundingmode";
356+
}]
357+
>,
358+
];
359+
}
293360

294361
//===----------------------------------------------------------------------===//
295362
// LLVM dialect type interfaces.

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,91 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
311311
"qualified(type($ptr))";
312312
}
313313

314+
// Constrained Floating-Point Intrinsics.
315+
316+
class LLVM_ConstrainedIntr<string mnem, int numArgs,
317+
bit overloadedResult, list<int> overloadedOperands,
318+
bit hasRoundingMode>
319+
: LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
320+
/*overloadedResults=*/
321+
!cond(!gt(overloadedResult, 0) : [0],
322+
true : []),
323+
overloadedOperands,
324+
/*traits=*/[Pure, DeclareOpInterfaceMethods<FPExceptionBehaviorOpInterface>]
325+
# !cond(
326+
!gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
327+
true : []),
328+
/*requiresFastmath=*/0,
329+
/*immArgPositions=*/[],
330+
/*immArgAttrNames=*/[]> {
331+
dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
332+
dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
333+
true : (ins)),
334+
(ins FPExceptionBehaviorAttr:$fpExceptionBehavior));
335+
let arguments = !con(regularArgs, attrArgs);
336+
let llvmBuilder = [{
337+
SmallVector<llvm::Value *> args =
338+
moduleTranslation.lookupValues(opInst.getOperands());
339+
SmallVector<llvm::Type *> overloadedTypes; }] #
340+
!cond(!gt(overloadedResult, 0) : [{
341+
// Take into account overloaded result type.
342+
overloadedTypes.push_back($_resultType); }],
343+
// No overloaded result type.
344+
true : "") # [{
345+
llvm::transform(ArrayRef<unsigned>}] # overloadedOperandsCpp # [{,
346+
std::back_inserter(overloadedTypes),
347+
[&args](unsigned index) { return args[index]->getType(); });
348+
llvm::Module *module = builder.GetInsertBlock()->getModule();
349+
llvm::Function *callee =
350+
llvm::Intrinsic::getDeclaration(module,
351+
llvm::Intrinsic::experimental_constrained_}] #
352+
mnem # [{, overloadedTypes); }] #
353+
!cond(!gt(hasRoundingMode, 0) : [{
354+
// Get rounding mode using interface.
355+
llvm::RoundingMode rounding =
356+
moduleTranslation.translateRoundingMode($roundingmode); }],
357+
true : [{
358+
// No rounding mode.
359+
std::optional<llvm::RoundingMode> rounding; }]) # [{
360+
llvm::fp::ExceptionBehavior except =
361+
moduleTranslation.translateFPExceptionBehavior($fpExceptionBehavior);
362+
$res = builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
363+
}];
364+
let mlirBuilder = [{
365+
SmallVector<Value> mlirOperands;
366+
SmallVector<NamedAttribute> mlirAttrs;
367+
if (failed(moduleImport.convertIntrinsicArguments(
368+
llvmOperands.take_front( }] # numArgs # [{),
369+
{}, {}, mlirOperands, mlirAttrs))) {
370+
return failure();
371+
}
372+
373+
FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
374+
$_fpExceptionBehavior_attr($fpExceptionBehavior);
375+
mlirAttrs.push_back(
376+
$_builder.getNamedAttr(
377+
$_qualCppClassName::getFPExceptionBehaviorAttrName(),
378+
fpExceptionBehaviorAttr)); }] #
379+
!cond(!gt(hasRoundingMode, 0) : [{
380+
RoundingModeAttr roundingModeAttr = $_roundingMode_attr($roundingmode);
381+
mlirAttrs.push_back(
382+
$_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
383+
roundingModeAttr));
384+
}], true : "") # [{
385+
$res = $_builder.create<$_qualCppClassName>($_location,
386+
$_resultType, mlirOperands, mlirAttrs);
387+
}];
388+
}
389+
390+
def LLVM_ConstrainedFPTruncIntr
391+
: LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1,
392+
/*overloadedResult=*/1, /*overloadedOperands=*/[0],
393+
/*hasRoundingMode=*/1> {
394+
let assemblyFormat = [{
395+
$arg_0 $roundingmode $fpExceptionBehavior attr-dict `:` type($arg_0) `to` type(results)
396+
}];
397+
}
398+
314399
// Intrinsics with multiple returns.
315400

316401
class LLVM_ArithWithOverflowOp<string mnem>

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<Trait> traits = []> :
170170
// - $_float_attr - substituted by a call to a float attribute matcher;
171171
// - $_var_attr - substituted by a call to a variable attribute matcher;
172172
// - $_label_attr - substituted by a call to a label attribute matcher;
173+
// - $_roundingMode_attr - substituted by a call to a rounding mode
174+
// attribute matcher;
175+
// - $_fpExceptionBehavior_attr - substituted by a call to a FP exception
176+
// behavior attribute matcher;
173177
// - $_resultType - substituted with the MLIR result type;
174178
// - $_location - substituted with the MLIR location;
175179
// - $_builder - substituted with the MLIR builder;

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ class ModuleImport {
152152
/// Converts `value` to a label attribute. Asserts if the matching fails.
153153
DILabelAttr matchLabelAttr(llvm::Value *value);
154154

155+
/// Converts `value` to a FP exception behavior attribute. Asserts if the
156+
/// matching fails.
157+
FPExceptionBehaviorAttr matchFPExceptionBehaviorAttr(llvm::Value *value);
158+
159+
/// Converts `value` to a rounding mode attribute. Asserts if the matching
160+
/// fails.
161+
RoundingModeAttr matchRoundingModeAttr(llvm::Value *value);
162+
155163
/// Converts `value` to an array of alias scopes or returns failure if the
156164
/// conversion fails.
157165
FailureOr<SmallVector<AliasScopeAttr>>

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,13 @@ class ModuleTranslation {
201201
/// Translates the given LLVM debug info metadata.
202202
llvm::Metadata *translateDebugInfo(LLVM::DINodeAttr attr);
203203

204+
/// Translates the given LLVM rounding mode metadata.
205+
llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding);
206+
207+
/// Translates the given LLVM FP exception behavior metadata.
208+
llvm::fp::ExceptionBehavior
209+
translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior);
210+
204211
/// Translates the contents of the given block to LLVM IR using this
205212
/// translator. The LLVM IR basic block corresponding to the given block is
206213
/// expected to exist in the mapping of this translator. Uses `builder` to

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,27 @@ DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
12901290
return debugImporter->translate(node);
12911291
}
12921292

1293+
FPExceptionBehaviorAttr
1294+
ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
1295+
auto *metadata = cast<llvm::MetadataAsValue>(value);
1296+
auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
1297+
std::optional<llvm::fp::ExceptionBehavior> optLLVM =
1298+
llvm::convertStrToExceptionBehavior(mdstr->getString());
1299+
assert(optLLVM && "Expecting FP exception behavior");
1300+
return builder.getAttr<FPExceptionBehaviorAttr>(
1301+
convertFPExceptionBehaviorFromLLVM(*optLLVM));
1302+
}
1303+
1304+
RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
1305+
auto *metadata = cast<llvm::MetadataAsValue>(value);
1306+
auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
1307+
std::optional<llvm::RoundingMode> optLLVM =
1308+
llvm::convertStrToRoundingMode(mdstr->getString());
1309+
assert(optLLVM && "Expecting rounding mode");
1310+
return builder.getAttr<RoundingModeAttr>(
1311+
convertRoundingModeFromLLVM(*optLLVM));
1312+
}
1313+
12931314
FailureOr<SmallVector<AliasScopeAttr>>
12941315
ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
12951316
auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,6 +1721,16 @@ llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
17211721
return debugTranslation->translate(attr);
17221722
}
17231723

1724+
llvm::RoundingMode
1725+
ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
1726+
return convertRoundingModeToLLVM(rounding);
1727+
}
1728+
1729+
llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
1730+
LLVM::FPExceptionBehavior exceptionBehavior) {
1731+
return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
1732+
}
1733+
17241734
llvm::NamedMDNode *
17251735
ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
17261736
return llvmModule->getOrInsertNamedMetadata(name);

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,3 +647,18 @@ llvm.func @experimental_noalias_scope_decl() {
647647
llvm.intr.experimental.noalias.scope.decl #alias_scope
648648
llvm.return
649649
}
650+
651+
// CHECK-LABEL: @experimental_constrained_fptrunc
652+
llvm.func @experimental_constrained_fptrunc(%in: f64) {
653+
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
654+
%0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
655+
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
656+
%1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
657+
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
658+
%2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
659+
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
660+
%3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
661+
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
662+
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
663+
llvm.return
664+
}

mlir/test/Target/LLVMIR/Import/intrinsic.ll

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
894894
ret float %2
895895
}
896896

897+
; CHECK-LABEL: experimental_constrained_fptrunc
898+
define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
899+
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
900+
%1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
901+
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
902+
%2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
903+
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
904+
%3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
905+
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
906+
%4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
907+
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
908+
%5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
909+
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
910+
%6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
911+
ret void
912+
}
913+
897914
declare float @llvm.fmuladd.f32(float, float, float)
898915
declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
899916
declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
11201137
declare float @llvm.ssa.copy.f32(float returned)
11211138
declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
11221139
declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
1140+
declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
1141+
declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)

0 commit comments

Comments
 (0)