
Commit 2c8afe1

Author: Stephan Herhut
[mlir][gpu] Add support for f16 when lowering to nvvm intrinsics
Summary: The NVVM target only provides implementations for tanh etc. on f32 and f64 operands. To also support f16, we now insert operations to extend to f32 and truncate back to f16 around the intrinsic call.

Differential Revision: https://reviews.llvm.org/D81473
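For concreteness, here is a sketch of the resulting IR for an f16 tanh, assembled from the test expectations added below (SSA value names are illustrative, not from the commit):

    // Before lowering (standard dialect):
    %result16 = std.tanh %arg_f16 : f16

    // After lowering to NVVM: there is no f16 variant of __nv_tanh, so the
    // operand is extended to f32, the f32 intrinsic is called, and the
    // result is truncated back to f16.
    %0 = llvm.fpext %arg_f16 : !llvm.half to !llvm.float
    %1 = llvm.call @__nv_tanhf(%0) : (!llvm.float) -> !llvm.float
    %2 = llvm.fptrunc %1 : !llvm.float to !llvm.half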

2 files changed: +42 -8 lines

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 36 additions & 6 deletions
@@ -20,6 +20,9 @@ namespace mlir {
 /// depending on the element type that Op operates upon. The function
 /// declaration is added in case it was not added before.
 ///
+/// If the input values are of f16 type, the value is first casted to f32, the
+/// function called and then the result casted back.
+///
 /// Example with NVVM:
 ///   %exp_f32 = std.exp %arg_f32 : f32
 ///
@@ -44,21 +47,48 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
-                              .template cast<LLVM::LLVMType>();
-    LLVMType funcType = getFunctionType(resultType, operands);
-    StringRef funcName = getFunctionName(resultType);
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected op with same operand and result types");
+
+    SmallVector<Value, 1> castedOperands;
+    for (Value operand : operands)
+      castedOperands.push_back(maybeCast(operand, rewriter));
+
+    LLVMType resultType =
+        castedOperands.front().getType().cast<LLVM::LLVMType>();
+    LLVMType funcType = getFunctionType(resultType, castedOperands);
+    StringRef funcName = getFunctionName(funcType.getFunctionResultType());
     if (funcName.empty())
       return failure();
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
-    rewriter.replaceOp(op, {callOp.getResult(0)});
+        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
+        castedOperands);
+
+    if (resultType == operands.front().getType()) {
+      rewriter.replaceOp(op, {callOp.getResult(0)});
+      return success();
+    }
+
+    Value truncated = rewriter.create<LLVM::FPTruncOp>(
+        op->getLoc(), operands.front().getType(), callOp.getResult(0));
+    rewriter.replaceOp(op, {truncated});
     return success();
   }
 
 private:
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+    if (!type.isHalfTy())
+      return operand;
+
+    return rewriter.create<LLVM::FPExtOp>(
+        operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+        operand);
+  }
+
   LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
                                  ArrayRef<Value> operands) const {
     using LLVM::LLVMType;
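To make the new doc comment concrete, here is a hedged sketch of both paths for the header's own std.exp example, assuming the pattern is instantiated with the __nv_expf intrinsic as in the existing NVVM lowering (value names are illustrative):

    // f32 operand: maybeCast is a no-op, and the call result replaces the
    // original op directly (the early-return branch above).
    %r32 = llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float

    // f16 operand: extend to f32, call the f32 intrinsic, truncate back.
    %0 = llvm.fpext %arg_f16 : !llvm.half to !llvm.float
    %1 = llvm.call @__nv_expf(%0) : (!llvm.float) -> !llvm.float
    %r16 = llvm.fptrunc %1 : !llvm.float to !llvm.half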

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 6 additions & 2 deletions
@@ -219,12 +219,16 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = std.tanh %arg_f16 : f16
+    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
+    // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return %result32, %result64 : f32, f64
+    std.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
