16
16
17
17
namespace mlir {
18
18
19
- // / Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20
- // / `f32ApproxFunc` or `f16Func` depending on the element type and the
21
- // / fastMathFlag of that Op. The function declaration is added in case it was
22
- // / not added before.
19
+ namespace {
20
+ // / Detection trait tor the `getFastmath` instance method.
21
+ template <typename T>
22
+ using has_get_fastmath_t = decltype (std::declval<T>().getFastmath());
23
+ } // namespace
24
+
25
+ // / Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
26
+ // / `f32ApproxFunc` or `f16Func` or `i32Type` depending on the element type and
27
+ // / the fastMathFlag of that Op, if present. The function declaration is added
28
+ // / in case it was not added before.
23
29
// /
24
30
// / If the input values are of bf16 type (or f16 type if f16Func is empty), the
25
31
// / value is first casted to f32, the function called and then the result casted
@@ -39,14 +45,22 @@ namespace mlir {
39
45
// /
40
46
// / will be transformed into
41
47
// / llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
48
+ // /
49
+ // / Final example with NVVM:
50
+ // / %pow_f32 = math.fpowi %arg_f32, %arg_i32
51
+ // /
52
+ // / will be transformed into
53
+ // / llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
42
54
template <typename SourceOp>
43
55
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern <SourceOp> {
44
56
public:
45
57
explicit OpToFuncCallLowering (const LLVMTypeConverter &lowering,
46
58
StringRef f32Func, StringRef f64Func,
47
- StringRef f32ApproxFunc, StringRef f16Func)
59
+ StringRef f32ApproxFunc, StringRef f16Func,
60
+ StringRef i32Func = " " )
48
61
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
49
- f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
62
+ f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
63
+ i32Func(i32Func) {}
50
64
51
65
LogicalResult
52
66
matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -76,9 +90,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
76
90
77
91
Type resultType = castedOperands.front ().getType ();
78
92
Type funcType = getFunctionType (resultType, castedOperands);
79
- StringRef funcName =
80
- getFunctionName (cast<LLVM::LLVMFunctionType>(funcType).getReturnType (),
81
- op.getFastmath ());
93
+ StringRef funcName = getFunctionName (
94
+ cast<LLVM::LLVMFunctionType>(funcType).getReturnType (), op);
82
95
if (funcName.empty ())
83
96
return failure ();
84
97
@@ -91,14 +104,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
91
104
return success ();
92
105
}
93
106
107
+ assert (callOp.getResult ().getType ().isF32 () &&
108
+ " only f32 types are supposed to be truncated back" );
94
109
Value truncated = rewriter.create <LLVM::FPTruncOp>(
95
110
op->getLoc (), adaptor.getOperands ().front ().getType (),
96
111
callOp.getResult ());
97
112
rewriter.replaceOp (op, {truncated});
98
113
return success ();
99
114
}
100
115
101
- private:
102
116
Value maybeCast (Value operand, PatternRewriter &rewriter) const {
103
117
Type type = operand.getType ();
104
118
if (!isa<Float16Type, BFloat16Type>(type))
@@ -117,38 +131,50 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
117
131
return LLVM::LLVMFunctionType::get (resultType, operandTypes);
118
132
}
119
133
120
- StringRef getFunctionName (Type type, arith::FastMathFlags flag) const {
121
- if (isa<Float16Type>(type))
122
- return f16Func;
123
- if (isa<Float32Type>(type)) {
124
- if (((uint32_t )arith::FastMathFlags::afn & (uint32_t )flag) &&
125
- !f32ApproxFunc.empty ())
126
- return f32ApproxFunc;
127
- else
128
- return f32Func;
129
- }
130
- if (isa<Float64Type>(type))
131
- return f64Func;
132
- return " " ;
133
- }
134
-
135
134
LLVM::LLVMFuncOp appendOrGetFuncOp (StringRef funcName, Type funcType,
136
135
Operation *op) const {
137
136
using LLVM::LLVMFuncOp;
138
137
139
138
auto funcAttr = StringAttr::get (op->getContext (), funcName);
140
- Operation *funcOp = SymbolTable::lookupNearestSymbolFrom (op, funcAttr);
139
+ auto funcOp =
140
+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
141
141
if (funcOp)
142
- return cast<LLVMFuncOp>(* funcOp) ;
142
+ return funcOp;
143
143
144
- mlir::OpBuilder b (op->getParentOfType <FunctionOpInterface>());
144
+ auto parentFunc = op->getParentOfType <FunctionOpInterface>();
145
+ assert (parentFunc && " expected there to be a parent function" );
146
+ OpBuilder b (parentFunc);
145
147
return b.create <LLVMFuncOp>(op->getLoc (), funcName, funcType);
146
148
}
147
149
150
+ StringRef getFunctionName (Type type, SourceOp op) const {
151
+ bool useApprox = false ;
152
+ if constexpr (llvm::is_detected<has_get_fastmath_t , SourceOp>::value) {
153
+ arith::FastMathFlags flag = op.getFastmath ();
154
+ useApprox = ((uint32_t )arith::FastMathFlags::afn & (uint32_t )flag) &&
155
+ !f32ApproxFunc.empty ();
156
+ }
157
+
158
+ if (isa<Float16Type>(type))
159
+ return f16Func;
160
+ if (isa<Float32Type>(type)) {
161
+ if (useApprox)
162
+ return f32ApproxFunc;
163
+ return f32Func;
164
+ }
165
+ if (isa<Float64Type>(type))
166
+ return f64Func;
167
+
168
+ if (type.isInteger (32 ))
169
+ return i32Func;
170
+ return " " ;
171
+ }
172
+
148
173
const std::string f32Func;
149
174
const std::string f64Func;
150
175
const std::string f32ApproxFunc;
151
176
const std::string f16Func;
177
+ const std::string i32Func;
152
178
};
153
179
154
180
} // namespace mlir
0 commit comments