Skip to content

Commit 589764a

Browse files
committed
[mlir][math] Initial support for fastmath flag attributes for Math dialect.
Added arith::FastMathAttr and ArithFastMathInterface support for Math dialect floating point operations. This change-set creates ArithCommon conversion utils that currently provide classes and methods to aid with arith::FastMathAttr conversion into LLVM::FastmathFlags. These utils are used in ArithToLLVM and MathToLLVM convertors, but may eventually be used by other converters that need to convert fast math attributes. Since Math dialect operations use arith::FastMathAttr, MathOps.td now has to include enum and attributes definitions from Arith dialect. To minimize the amount of TD code included from Arith dialect, I moved FastMathAttr definition into ArithBase.td. Differential Revision: https://reviews.llvm.org/D136312
1 parent d0d4b63 commit 589764a

File tree

14 files changed

+359
-138
lines changed

14 files changed

+359
-138
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
10+
#define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
11+
12+
#include "mlir/Dialect/Arith/IR/Arith.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// Support for converting Arith FastMathFlags to LLVM FastmathFlags
17+
//===----------------------------------------------------------------------===//
18+
19+
namespace mlir {
20+
namespace arith {
21+
// Map arithmetic fastmath enum values to LLVMIR enum values.
22+
LLVM::FastmathFlags
23+
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
24+
25+
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
26+
LLVM::FastmathFlagsAttr
27+
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
28+
29+
// Attribute converter that populates a NamedAttrList by removing the fastmath
30+
// attribute from the source operation attributes, and replacing it with an
31+
// equivalent LLVM fastmath attribute.
32+
template <typename SourceOp, typename TargetOp>
33+
class AttrConvertFastMathToLLVM {
34+
public:
35+
AttrConvertFastMathToLLVM(SourceOp srcOp) {
36+
// Copy the source attributes.
37+
convertedAttr = NamedAttrList{srcOp->getAttrs()};
38+
// Get the name of the arith fastmath attribute.
39+
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
40+
// Remove the source fastmath attribute.
41+
auto arithFMFAttr =
42+
convertedAttr.erase(arithFMFAttrName)
43+
.template dyn_cast_or_null<arith::FastMathFlagsAttr>();
44+
if (arithFMFAttr) {
45+
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
46+
convertedAttr.set(targetAttrName,
47+
convertArithFastMathAttrToLLVM(arithFMFAttr));
48+
}
49+
}
50+
51+
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
52+
53+
private:
54+
NamedAttrList convertedAttr;
55+
};
56+
57+
// Attribute converter that populates a NamedAttrList by removing the fastmath
58+
// attribute from the source operation attributes. This may be useful for
59+
// target operations that do not require the fastmath attribute, or for targets
60+
// that do not yet support the LLVM fastmath attribute.
61+
template <typename SourceOp, typename TargetOp>
62+
class AttrDropFastMath {
63+
public:
64+
AttrDropFastMath(SourceOp srcOp) {
65+
// Copy the source attributes.
66+
convertedAttr = NamedAttrList{srcOp->getAttrs()};
67+
// Get the name of the arith fastmath attribute.
68+
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
69+
// Remove the source fastmath attribute.
70+
convertedAttr.erase(arithFMFAttrName);
71+
}
72+
73+
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
74+
75+
private:
76+
NamedAttrList convertedAttr;
77+
};
78+
} // namespace arith
79+
} // namespace mlir
80+
81+
#endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,9 @@ def FastMathFlags : I32BitEnumAttr<
121121
let printBitEnumPrimaryGroups = 1;
122122
}
123123

124+
def Arith_FastMathAttr :
125+
EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
126+
let assemblyFormat = "`<` $value `>`";
127+
}
128+
124129
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
2020
include "mlir/IR/OpAsmInterface.td"
2121
include "mlir/IR/EnumAttr.td"
2222

23-
def Arith_FastMathAttr :
24-
EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
25-
let assemblyFormat = "`<` $value `>`";
26-
}
27-
2823
// Base class for Arith dialect ops. Ops in this dialect have no side
2924
// effects and can be applied element-wise to vectors and tensors.
3025
class Arith_Op<string mnemonic, list<Trait> traits = []> :

mlir/include/mlir/Dialect/Math/IR/Math.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_MATH_IR_MATH_H_
1010
#define MLIR_DIALECT_MATH_IR_MATH_H_
1111

12+
#include "mlir/Dialect/Arith/IR/Arith.h"
1213
#include "mlir/IR/BuiltinTypes.h"
1314
#include "mlir/IR/Dialect.h"
1415
#include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#ifndef MATH_OPS
1010
#define MATH_OPS
1111

12+
include "mlir/Dialect/Arith/IR/ArithBase.td"
13+
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
1214
include "mlir/Dialect/Math/IR/MathBase.td"
1315
include "mlir/Interfaces/InferTypeOpInterface.td"
1416
include "mlir/Interfaces/VectorInterfaces.td"
@@ -36,11 +38,16 @@ class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
3638
// operand and result of the same type. This type can be a floating point type,
3739
// vector or tensor thereof.
3840
class Math_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
39-
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
40-
let arguments = (ins FloatLike:$operand);
41+
Math_Op<mnemonic,
42+
traits # [SameOperandsAndResultType,
43+
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
44+
let arguments = (ins FloatLike:$operand,
45+
DefaultValuedAttr<Arith_FastMathAttr,
46+
"::mlir::arith::FastMathFlags::none">:$fastmath);
4147
let results = (outs FloatLike:$result);
4248

43-
let assemblyFormat = "$operand attr-dict `:` type($result)";
49+
let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
50+
attr-dict `:` type($result) }];
4451
}
4552

4653
// Base class for binary math operations on integer types. Require two
@@ -58,22 +65,32 @@ class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
5865
// operands and one result of the same type. This type can be a floating point
5966
// type, vector or tensor thereof.
6067
class Math_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
61-
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
62-
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
68+
Math_Op<mnemonic,
69+
traits # [SameOperandsAndResultType,
70+
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
71+
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
72+
DefaultValuedAttr<Arith_FastMathAttr,
73+
"::mlir::arith::FastMathFlags::none">:$fastmath);
6374
let results = (outs FloatLike:$result);
6475

65-
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
76+
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
77+
attr-dict `:` type($result) }];
6678
}
6779

6880
// Base class for floating point ternary operations. Require three operands and
6981
// one result of the same type. This type can be a floating point type, vector
7082
// or tensor thereof.
7183
class Math_FloatTernaryOp<string mnemonic, list<Trait> traits = []> :
72-
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
73-
let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c);
84+
Math_Op<mnemonic,
85+
traits # [SameOperandsAndResultType,
86+
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
87+
let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c,
88+
DefaultValuedAttr<Arith_FastMathAttr,
89+
"::mlir::arith::FastMathFlags::none">:$fastmath);
7490
let results = (outs FloatLike:$result);
7591

76-
let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
92+
let assemblyFormat = [{ $a `,` $b `,` $c (`fastmath` `` $fastmath^)?
93+
attr-dict `:` type($result) }];
7794
}
7895

7996
//===----------------------------------------------------------------------===//
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
10+
11+
using namespace mlir;
12+
13+
// Map arithmetic fastmath enum values to LLVMIR enum values.
14+
LLVM::FastmathFlags
15+
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
16+
LLVM::FastmathFlags llvmFMF{};
17+
const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
18+
{arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
19+
{arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
20+
{arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
21+
{arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
22+
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
23+
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
24+
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
25+
for (auto fmfMap : flags) {
26+
if (bitEnumContainsAny(arithFMF, fmfMap.first))
27+
llvmFMF = llvmFMF | fmfMap.second;
28+
}
29+
return llvmFMF;
30+
}
31+
32+
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
33+
LLVM::FastmathFlagsAttr
34+
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
35+
arith::FastMathFlags arithFMF = fmfAttr.getValue();
36+
return LLVM::FastmathFlagsAttr::get(
37+
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
38+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
add_mlir_conversion_library(MLIRArithAttrToLLVMConversion
2+
AttrToLLVMConverter.cpp
3+
4+
LINK_COMPONENTS
5+
Core
6+
7+
LINK_LIBS PUBLIC
8+
MLIRArithDialect
9+
MLIRLLVMDialect
10+
)

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 22 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1010

11+
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
1112
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1213
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1314
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -24,93 +25,20 @@ using namespace mlir;
2425

2526
namespace {
2627

27-
// Map arithmetic fastmath enum values to LLVMIR enum values.
28-
static LLVM::FastmathFlags
29-
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
30-
LLVM::FastmathFlags llvmFMF{};
31-
const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
32-
{arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
33-
{arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
34-
{arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
35-
{arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
36-
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
37-
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
38-
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
39-
for (auto fmfMap : flags) {
40-
if (bitEnumContainsAny(arithFMF, fmfMap.first))
41-
llvmFMF = llvmFMF | fmfMap.second;
42-
}
43-
return llvmFMF;
44-
}
45-
46-
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
47-
static LLVM::FastmathFlagsAttr
48-
convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) {
49-
arith::FastMathFlags arithFMF = fmfAttr.getValue();
50-
return LLVM::FastmathFlagsAttr::get(
51-
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
52-
}
53-
54-
// Attribute converter that populates a NamedAttrList by removing the fastmath
55-
// attribute from the source operation attributes, and replacing it with an
56-
// equivalent LLVM fastmath attribute.
57-
template <typename SourceOp, typename TargetOp>
58-
class AttrConvertFastMath {
59-
public:
60-
AttrConvertFastMath(SourceOp srcOp) {
61-
// Copy the source attributes.
62-
convertedAttr = NamedAttrList{srcOp->getAttrs()};
63-
// Get the name of the arith fastmath attribute.
64-
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
65-
// Remove the source fastmath attribute.
66-
auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName)
67-
.template dyn_cast_or_null<arith::FastMathFlagsAttr>();
68-
if (arithFMFAttr) {
69-
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
70-
convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr));
71-
}
72-
}
73-
74-
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
75-
76-
private:
77-
NamedAttrList convertedAttr;
78-
};
79-
80-
// Attribute converter that populates a NamedAttrList by removing the fastmath
81-
// attribute from the source operation attributes. This may be useful for
82-
// target operations that do not require the fastmath attribute, or for targets
83-
// that do not yet support the LLVM fastmath attribute.
84-
template <typename SourceOp, typename TargetOp>
85-
class AttrDropFastMath {
86-
public:
87-
AttrDropFastMath(SourceOp srcOp) {
88-
// Copy the source attributes.
89-
convertedAttr = NamedAttrList{srcOp->getAttrs()};
90-
// Get the name of the arith fastmath attribute.
91-
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
92-
// Remove the source fastmath attribute.
93-
convertedAttr.erase(arithFMFAttrName);
94-
}
95-
96-
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
97-
98-
private:
99-
NamedAttrList convertedAttr;
100-
};
101-
10228
//===----------------------------------------------------------------------===//
10329
// Straightforward Op Lowerings
10430
//===----------------------------------------------------------------------===//
10531

106-
using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
107-
AttrConvertFastMath>;
32+
using AddFOpLowering =
33+
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
34+
arith::AttrConvertFastMathToLLVM>;
10835
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
10936
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
11037
using BitcastOpLowering =
11138
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
112-
using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
113-
AttrConvertFastMath>;
39+
using DivFOpLowering =
40+
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
41+
arith::AttrConvertFastMathToLLVM>;
11442
using DivSIOpLowering =
11543
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
11644
using DivUIOpLowering =
@@ -125,28 +53,30 @@ using FPToSIOpLowering =
12553
using FPToUIOpLowering =
12654
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
12755
// TODO: Add LLVM intrinsic support for fastmath
128-
using MaxFOpLowering =
129-
VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp, AttrDropFastMath>;
56+
using MaxFOpLowering = VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp,
57+
arith::AttrDropFastMath>;
13058
using MaxSIOpLowering =
13159
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
13260
using MaxUIOpLowering =
13361
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
13462
// TODO: Add LLVM intrinsic support for fastmath
135-
using MinFOpLowering =
136-
VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp, AttrDropFastMath>;
63+
using MinFOpLowering = VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp,
64+
arith::AttrDropFastMath>;
13765
using MinSIOpLowering =
13866
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
13967
using MinUIOpLowering =
14068
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
141-
using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
142-
AttrConvertFastMath>;
69+
using MulFOpLowering =
70+
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
71+
arith::AttrConvertFastMathToLLVM>;
14372
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
144-
using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
145-
AttrConvertFastMath>;
73+
using NegFOpLowering =
74+
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
75+
arith::AttrConvertFastMathToLLVM>;
14676
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
14777
// TODO: Add LLVM intrinsic support for fastmath
148-
using RemFOpLowering =
149-
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, AttrDropFastMath>;
78+
using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
79+
arith::AttrDropFastMath>;
15080
using RemSIOpLowering =
15181
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
15282
using RemUIOpLowering =
@@ -160,8 +90,9 @@ using ShRUIOpLowering =
16090
VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
16191
using SIToFPOpLowering =
16292
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
163-
using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
164-
AttrConvertFastMath>;
93+
using SubFOpLowering =
94+
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
95+
arith::AttrConvertFastMathToLLVM>;
16596
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
16697
using TruncFOpLowering =
16798
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;

mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToLLVM
1111
Core
1212

1313
LINK_LIBS PUBLIC
14+
MLIRArithAttrToLLVMConversion
1415
MLIRArithDialect
1516
MLIRLLVMCommonConversion
1617
MLIRLLVMDialect

0 commit comments

Comments
 (0)