8
8
9
9
#include " mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
10
10
11
+ #include " mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
11
12
#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
12
13
#include " mlir/Conversion/LLVMCommon/VectorPattern.h"
13
14
#include " mlir/Dialect/Arith/IR/Arith.h"
@@ -24,93 +25,20 @@ using namespace mlir;
24
25
25
26
namespace {
26
27
27
- // Map arithmetic fastmath enum values to LLVMIR enum values.
28
- static LLVM::FastmathFlags
29
- convertArithFastMathFlagsToLLVM (arith::FastMathFlags arithFMF) {
30
- LLVM::FastmathFlags llvmFMF{};
31
- const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
32
- {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
33
- {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
34
- {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
35
- {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
36
- {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
37
- {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
38
- {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
39
- for (auto fmfMap : flags) {
40
- if (bitEnumContainsAny (arithFMF, fmfMap.first ))
41
- llvmFMF = llvmFMF | fmfMap.second ;
42
- }
43
- return llvmFMF;
44
- }
45
-
46
- // Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
47
- static LLVM::FastmathFlagsAttr
48
- convertArithFastMathAttr (arith::FastMathFlagsAttr fmfAttr) {
49
- arith::FastMathFlags arithFMF = fmfAttr.getValue ();
50
- return LLVM::FastmathFlagsAttr::get (
51
- fmfAttr.getContext (), convertArithFastMathFlagsToLLVM (arithFMF));
52
- }
53
-
54
- // Attribute converter that populates a NamedAttrList by removing the fastmath
55
- // attribute from the source operation attributes, and replacing it with an
56
- // equivalent LLVM fastmath attribute.
57
- template <typename SourceOp, typename TargetOp>
58
- class AttrConvertFastMath {
59
- public:
60
- AttrConvertFastMath (SourceOp srcOp) {
61
- // Copy the source attributes.
62
- convertedAttr = NamedAttrList{srcOp->getAttrs ()};
63
- // Get the name of the arith fastmath attribute.
64
- llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName ();
65
- // Remove the source fastmath attribute.
66
- auto arithFMFAttr = convertedAttr.erase (arithFMFAttrName)
67
- .template dyn_cast_or_null <arith::FastMathFlagsAttr>();
68
- if (arithFMFAttr) {
69
- llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName ();
70
- convertedAttr.set (targetAttrName, convertArithFastMathAttr (arithFMFAttr));
71
- }
72
- }
73
-
74
- ArrayRef<NamedAttribute> getAttrs () const { return convertedAttr.getAttrs (); }
75
-
76
- private:
77
- NamedAttrList convertedAttr;
78
- };
79
-
80
- // Attribute converter that populates a NamedAttrList by removing the fastmath
81
- // attribute from the source operation attributes. This may be useful for
82
- // target operations that do not require the fastmath attribute, or for targets
83
- // that do not yet support the LLVM fastmath attribute.
84
- template <typename SourceOp, typename TargetOp>
85
- class AttrDropFastMath {
86
- public:
87
- AttrDropFastMath (SourceOp srcOp) {
88
- // Copy the source attributes.
89
- convertedAttr = NamedAttrList{srcOp->getAttrs ()};
90
- // Get the name of the arith fastmath attribute.
91
- llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName ();
92
- // Remove the source fastmath attribute.
93
- convertedAttr.erase (arithFMFAttrName);
94
- }
95
-
96
- ArrayRef<NamedAttribute> getAttrs () const { return convertedAttr.getAttrs (); }
97
-
98
- private:
99
- NamedAttrList convertedAttr;
100
- };
101
-
102
28
// ===----------------------------------------------------------------------===//
103
29
// Straightforward Op Lowerings
104
30
// ===----------------------------------------------------------------------===//
105
31
106
- using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
107
- AttrConvertFastMath>;
32
+ using AddFOpLowering =
33
+ VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
34
+ arith::AttrConvertFastMathToLLVM>;
108
35
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
109
36
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
110
37
using BitcastOpLowering =
111
38
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
112
- using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
113
- AttrConvertFastMath>;
39
+ using DivFOpLowering =
40
+ VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
41
+ arith::AttrConvertFastMathToLLVM>;
114
42
using DivSIOpLowering =
115
43
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
116
44
using DivUIOpLowering =
@@ -125,28 +53,30 @@ using FPToSIOpLowering =
125
53
using FPToUIOpLowering =
126
54
VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
127
55
// TODO: Add LLVM intrinsic support for fastmath
128
- using MaxFOpLowering =
129
- VectorConvertToLLVMPattern< arith::MaxFOp, LLVM::MaxNumOp, AttrDropFastMath>;
56
+ using MaxFOpLowering = VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp,
57
+ arith::AttrDropFastMath>;
130
58
using MaxSIOpLowering =
131
59
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
132
60
using MaxUIOpLowering =
133
61
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
134
62
// TODO: Add LLVM intrinsic support for fastmath
135
- using MinFOpLowering =
136
- VectorConvertToLLVMPattern< arith::MinFOp, LLVM::MinNumOp, AttrDropFastMath>;
63
+ using MinFOpLowering = VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp,
64
+ arith::AttrDropFastMath>;
137
65
using MinSIOpLowering =
138
66
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
139
67
using MinUIOpLowering =
140
68
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
141
- using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
142
- AttrConvertFastMath>;
69
+ using MulFOpLowering =
70
+ VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
71
+ arith::AttrConvertFastMathToLLVM>;
143
72
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
144
- using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
145
- AttrConvertFastMath>;
73
+ using NegFOpLowering =
74
+ VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
75
+ arith::AttrConvertFastMathToLLVM>;
146
76
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
147
77
// TODO: Add LLVM intrinsic support for fastmath
148
- using RemFOpLowering =
149
- VectorConvertToLLVMPattern< arith::RemFOp, LLVM::FRemOp, AttrDropFastMath>;
78
+ using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
79
+ arith::AttrDropFastMath>;
150
80
using RemSIOpLowering =
151
81
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
152
82
using RemUIOpLowering =
@@ -160,8 +90,9 @@ using ShRUIOpLowering =
160
90
VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
161
91
using SIToFPOpLowering =
162
92
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
163
- using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
164
- AttrConvertFastMath>;
93
+ using SubFOpLowering =
94
+ VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
95
+ arith::AttrConvertFastMathToLLVM>;
165
96
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
166
97
using TruncFOpLowering =
167
98
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
0 commit comments