Skip to content

Commit 79d8a34

Browse files
authored
[mlir] add some FP classification ops and their lowering to libdevice (#127322)
Introduce a subset of floating point classification ops to the Math dialect. These ops mirror functions provided by the C math library and, similarly to the existing `math.copysign`, belong to the math dialect. Add a lowering of those ops to Nvidia libdevice calls when possible as the first mechanism to exercise them.
1 parent 552e465 commit 79d8a34

File tree

6 files changed

+211
-5
lines changed

6 files changed

+211
-5
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,23 @@ class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
3434
let assemblyFormat = "$operand attr-dict `:` type($result)";
3535
}
3636

37+
// Base class for floating point classification ops. Require an operand and
38+
// result of the same shape, which can be a floating point scalar, a vector or a
39+
// tensor thereof.
40+
class Math_FloatClassificationOp<string mnemonic, list<Trait> traits = []> :
41+
Math_Op<mnemonic,
42+
traits # [DeclareOpInterfaceMethods<ArithFastMathInterface>,
43+
TypesMatchWith<
44+
"result type has i1 element type and same shape as operands",
45+
"operand", "result", "::getI1SameShape($_self)">]> {
46+
let arguments = (ins FloatLike:$operand,
47+
DefaultValuedAttr<Arith_FastMathAttr,
48+
"::mlir::arith::FastMathFlags::none">:$fastmath);
49+
let results = (outs BoolLike:$result);
50+
51+
let assemblyFormat = "$operand attr-dict `:` type($operand)";
52+
}
53+
3754
// Base class for unary math operations on floating point types. Require an
3855
// operand and result of the same type. This type can be a floating point type,
3956
// vector or tensor thereof.
@@ -678,6 +695,79 @@ def Math_IPowIOp : Math_IntegerBinaryOp<"ipowi"> {
678695
let hasFolder = 1;
679696
}
680697

698+
//===----------------------------------------------------------------------===//
699+
// IsFiniteOp
700+
//===----------------------------------------------------------------------===//
701+
702+
def Math_IsFiniteOp : Math_FloatClassificationOp<"isfinite"> {
703+
let summary = "returns true if the operand classifies as finite";
704+
let description = [{
705+
Determines if the given floating-point number has a finite value, i.e. it
706+
is normal, subnormal or zero, but not infinite or NaN.
707+
708+
Example:
709+
710+
```mlir
711+
%f = math.isfinite %a : f32
712+
```
713+
}];
714+
}
715+
716+
//===----------------------------------------------------------------------===//
717+
// IsInfOp
718+
//===----------------------------------------------------------------------===//
719+
720+
def Math_IsInfOp : Math_FloatClassificationOp<"isinf"> {
721+
let summary = "returns true if the operand classifies as infinite";
722+
let description = [{
723+
Determines if the given floating-point number is positive or negative
724+
infinity.
725+
726+
Example:
727+
728+
```mlir
729+
%f = math.isinf %a : f32
730+
```
731+
}];
732+
}
733+
734+
//===----------------------------------------------------------------------===//
735+
// IsNaNOp
736+
//===----------------------------------------------------------------------===//
737+
738+
def Math_IsNaNOp : Math_FloatClassificationOp<"isnan"> {
739+
let summary = "returns true if the operand classifies as NaN";
740+
let description = [{
741+
Determines if the given floating-point number is a not-a-number (NaN)
742+
value.
743+
744+
Example:
745+
746+
```mlir
747+
%f = math.isnan %a : f32
748+
```
749+
}];
750+
}
751+
752+
753+
//===----------------------------------------------------------------------===//
754+
// IsNormalOp
755+
//===----------------------------------------------------------------------===//
756+
757+
def Math_IsNormalOp : Math_FloatClassificationOp<"isnormal"> {
758+
let summary = "returns true if the operand classifies as normal";
759+
let description = [{
760+
Determines if the given floating-point number is normal, i.e. is neither
761+
zero, subnormal, infinite, nor NaN.
762+
763+
Example:
764+
765+
```mlir
766+
%f = math.isnormal %a : f32
767+
```
768+
}];
769+
}
770+
681771
//===----------------------------------------------------------------------===//
682772
// LogOp
683773
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
7171
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
7272
"expected single result op");
7373

74+
bool isResultBool = op->getResultTypes().front().isInteger(1);
7475
if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
7576
SourceOp>::value) {
7677
assert(op->getNumOperands() > 0 &&
7778
"expected op to take at least one operand");
78-
assert(op->getResultTypes().front() == op->getOperand(0).getType() &&
79+
assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
80+
isResultBool) &&
7981
"expected op with same operand and result types");
8082
}
8183

@@ -88,10 +90,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
8890
for (Value operand : adaptor.getOperands())
8991
castedOperands.push_back(maybeCast(operand, rewriter));
9092

91-
Type resultType = castedOperands.front().getType();
93+
Type castedOperandType = castedOperands.front().getType();
94+
95+
// At ABI level, booleans are treated as i32.
96+
Type resultType =
97+
isResultBool ? rewriter.getIntegerType(32) : castedOperandType;
9298
Type funcType = getFunctionType(resultType, castedOperands);
93-
StringRef funcName = getFunctionName(
94-
cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
99+
StringRef funcName = getFunctionName(castedOperandType, op);
95100
if (funcName.empty())
96101
return failure();
97102

@@ -104,6 +109,20 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
104109
return success();
105110
}
106111

112+
// Boolean results are mapped to i32 at the ABI level, with zero values being
113+
// interpreted as false and non-zero values being interpreted as true. Since
114+
// there is no guarantee of a specific value being used to indicate true,
115+
// compare for inequality with zero (rather than truncate or shift).
116+
if (isResultBool) {
117+
Value zero = rewriter.create<LLVM::ConstantOp>(
118+
op->getLoc(), rewriter.getIntegerType(32),
119+
rewriter.getI32IntegerAttr(0));
120+
Value truncated = rewriter.create<LLVM::ICmpOp>(
121+
op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
122+
rewriter.replaceOp(op, {truncated});
123+
return success();
124+
}
125+
107126
assert(callOp.getResult().getType().isF32() &&
108127
"only f32 types are supposed to be truncated back");
109128
Value truncated = rewriter.create<LLVM::FPTruncOp>(
@@ -118,7 +137,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
118137
if (!isa<Float16Type, BFloat16Type>(type))
119138
return operand;
120139

121-
// if there's a f16 function, no need to cast f16 values
140+
// If there's an f16 function, no need to cast f16 values.
122141
if (!f16Func.empty() && isa<Float16Type>(type))
123142
return operand;
124143

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,13 @@ void mlir::populateGpuToNVVMConversionPatterns(
595595
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
596596
"__nv_floor");
597597
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
598+
// Note: libdevice does not provide `__nv_isfinitef` as of the moment of writing.
599+
populateOpPatterns<math::IsFiniteOp>(converter, patterns, "",
600+
"__nv_isfinited");
601+
populateOpPatterns<math::IsInfOp>(converter, patterns, "__nv_isinff",
602+
"__nv_isinfd");
603+
populateOpPatterns<math::IsNaNOp>(converter, patterns, "__nv_isnanf",
604+
"__nv_isnand");
598605
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
599606
"__nv_fast_logf");
600607
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@
1616
using namespace mlir;
1717
using namespace mlir::math;
1818

19+
//===----------------------------------------------------------------------===//
20+
// Common helpers
21+
//===----------------------------------------------------------------------===//
22+
23+
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
24+
static Type getI1SameShape(Type type) {
25+
auto i1Type = IntegerType::get(type.getContext(), 1);
26+
if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
27+
return shapedType.cloneWith(std::nullopt, i1Type);
28+
if (llvm::isa<UnrankedTensorType>(type))
29+
return UnrankedTensorType::get(i1Type);
30+
return i1Type;
31+
}
32+
1933
//===----------------------------------------------------------------------===//
2034
// TableGen'd op method definitions
2135
//===----------------------------------------------------------------------===//

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,3 +1058,40 @@ gpu.module @test_module_53 {
10581058
func.return %result32, %result64 : f32, f64
10591059
}
10601060
}
1061+
1062+
gpu.module @test_module_54 {
1063+
// CHECK: llvm.func @__nv_isinff(f32) -> i32
1064+
// CHECK: llvm.func @__nv_isinfd(f64) -> i32
1065+
// CHECK: llvm.func @__nv_isnanf(f32) -> i32
1066+
// CHECK: llvm.func @__nv_isnand(f64) -> i32
1067+
// CHECK: llvm.func @__nv_isfinited(f64) -> i32
1068+
// CHECK-LABEL: @fpclassify
1069+
func.func @fpclassify(%f32: f32, %f64: f64) -> (i1, i1, i1, i1, i1, i1) {
1070+
// CHECK: %[[INFF:.+]] = llvm.call @__nv_isinff(%{{.*}}) : (f32) -> i32
1071+
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i32) : i32
1072+
// CHECK: %[[R0:.+]] = llvm.icmp "ne" %[[INFF]], %[[ZERO]]
1073+
%0 = math.isinf %f32 : f32
1074+
// CHECK: llvm.call @__nv_isinfd(%{{.*}}) : (f64) -> i32
1075+
// CHECK: llvm.mlir.constant(0
1076+
// CHECK: llvm.icmp "ne"
1077+
%1 = math.isinf %f64 : f64
1078+
// CHECK: llvm.call @__nv_isnanf(%{{.*}}) : (f32) -> i32
1079+
// CHECK: llvm.mlir.constant(0
1080+
// CHECK: llvm.icmp "ne"
1081+
%2 = math.isnan %f32 : f32
1082+
// CHECK: llvm.call @__nv_isnand(%{{.*}}) : (f64) -> i32
1083+
// CHECK: llvm.mlir.constant(0
1084+
// CHECK: llvm.icmp "ne"
1085+
%3 = math.isnan %f64 : f64
1086+
// Note: for some reason, libdevice does not provide isfinite for f32, so
1087+
// this should fail to convert.
1088+
// CHECK: math.isfinite {{.*}} : f32
1089+
%4 = math.isfinite %f32 : f32
1090+
// CHECK: llvm.call @__nv_isfinited(%{{.*}}) : (f64) -> i32
1091+
// CHECK: llvm.mlir.constant(0
1092+
// CHECK: llvm.icmp "ne"
1093+
%5 = math.isfinite %f64 : f64
1094+
// CHECK: llvm.return %[[R0]]
1095+
return %0, %1, %2, %3, %4, %5 : i1, i1, i1, i1, i1, i1
1096+
}
1097+
}

mlir/test/Dialect/Math/ops.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,42 @@ func.func @fastmath(%f: f32, %i: i32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>)
298298
%4 = math.fpowi %f, %i fastmath<fast> : f32, i32
299299
return
300300
}
301+
302+
// CHECK-LABEL: func @fpclassify(
303+
// CHECK-SAME: %[[F:.+]]: f32, %[[D:.+]]: f64,
304+
// CHECK-SAME: %[[V:.+]]: vector<4xf32>, %[[T:.+]]: tensor<4x?xf32>
305+
func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>) {
306+
// CHECK: math.isfinite %[[F]] : f32
307+
// CHECK: math.isfinite %[[D]] : f64
308+
// CHECK: math.isfinite %[[V]] : vector<4xf32>
309+
// CHECK: math.isfinite %[[T]] : tensor<4x?xf32>
310+
math.isfinite %f : f32
311+
math.isfinite %d : f64
312+
math.isfinite %v : vector<4xf32>
313+
math.isfinite %t : tensor<4x?xf32>
314+
// CHECK: math.isinf %[[F]] : f32
315+
// CHECK: math.isinf %[[D]] : f64
316+
// CHECK: math.isinf %[[V]] : vector<4xf32>
317+
// CHECK: math.isinf %[[T]] : tensor<4x?xf32>
318+
math.isinf %f : f32
319+
math.isinf %d : f64
320+
math.isinf %v : vector<4xf32>
321+
math.isinf %t : tensor<4x?xf32>
322+
// CHECK: math.isnan %[[F]] : f32
323+
// CHECK: math.isnan %[[D]] : f64
324+
// CHECK: math.isnan %[[V]] : vector<4xf32>
325+
// CHECK: math.isnan %[[T]] : tensor<4x?xf32>
326+
math.isnan %f : f32
327+
math.isnan %d : f64
328+
math.isnan %v : vector<4xf32>
329+
math.isnan %t : tensor<4x?xf32>
330+
// CHECK: math.isnormal %[[F]] : f32
331+
// CHECK: math.isnormal %[[D]] : f64
332+
// CHECK: math.isnormal %[[V]] : vector<4xf32>
333+
// CHECK: math.isnormal %[[T]] : tensor<4x?xf32>
334+
math.isnormal %f : f32
335+
math.isnormal %d : f64
336+
math.isnormal %v : vector<4xf32>
337+
math.isnormal %t : tensor<4x?xf32>
338+
return
339+
}

0 commit comments

Comments
 (0)