@@ -20,6 +20,9 @@ namespace mlir {
20
20
// / depending on the element type that Op operates upon. The function
21
21
// / declaration is added in case it was not added before.
22
22
// /
23
+ // / If the input values are of f16 type, the value is first casted to f32, the
24
+ // / function called and then the result casted back.
25
+ // /
23
26
// / Example with NVVM:
24
27
// / %exp_f32 = std.exp %arg_f32 : f32
25
28
// /
@@ -44,21 +47,48 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
44
47
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
45
48
" expected single result op" );
46
49
47
- LLVMType resultType = typeConverter.convertType (op->getResult (0 ).getType ())
48
- .template cast <LLVM::LLVMType>();
49
- LLVMType funcType = getFunctionType (resultType, operands);
50
- StringRef funcName = getFunctionName (resultType);
50
+ static_assert (std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
51
+ SourceOp>::value,
52
+ " expected op with same operand and result types" );
53
+
54
+ SmallVector<Value, 1 > castedOperands;
55
+ for (Value operand : operands)
56
+ castedOperands.push_back (maybeCast (operand, rewriter));
57
+
58
+ LLVMType resultType =
59
+ castedOperands.front ().getType ().cast <LLVM::LLVMType>();
60
+ LLVMType funcType = getFunctionType (resultType, castedOperands);
61
+ StringRef funcName = getFunctionName (funcType.getFunctionResultType ());
51
62
if (funcName.empty ())
52
63
return failure ();
53
64
54
65
LLVMFuncOp funcOp = appendOrGetFuncOp (funcName, funcType, op);
55
66
auto callOp = rewriter.create <LLVM::CallOp>(
56
- op->getLoc (), resultType, rewriter.getSymbolRefAttr (funcOp), operands);
57
- rewriter.replaceOp (op, {callOp.getResult (0 )});
67
+ op->getLoc (), resultType, rewriter.getSymbolRefAttr (funcOp),
68
+ castedOperands);
69
+
70
+ if (resultType == operands.front ().getType ()) {
71
+ rewriter.replaceOp (op, {callOp.getResult (0 )});
72
+ return success ();
73
+ }
74
+
75
+ Value truncated = rewriter.create <LLVM::FPTruncOp>(
76
+ op->getLoc (), operands.front ().getType (), callOp.getResult (0 ));
77
+ rewriter.replaceOp (op, {truncated});
58
78
return success ();
59
79
}
60
80
61
81
private:
82
+ Value maybeCast (Value operand, PatternRewriter &rewriter) const {
83
+ LLVM::LLVMType type = operand.getType ().cast <LLVM::LLVMType>();
84
+ if (!type.isHalfTy ())
85
+ return operand;
86
+
87
+ return rewriter.create <LLVM::FPExtOp>(
88
+ operand.getLoc (), LLVM::LLVMType::getFloatTy (&type.getDialect ()),
89
+ operand);
90
+ }
91
+
62
92
LLVM::LLVMType getFunctionType (LLVM::LLVMType resultType,
63
93
ArrayRef<Value> operands) const {
64
94
using LLVM::LLVMType;
0 commit comments