Skip to content

[mlir] add some FP classification ops and their lowering to libdevice #127322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions mlir/include/mlir/Dialect/Math/IR/MathOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "$operand attr-dict `:` type($result)";
}

// Base class for floating-point classification ops. The single operand is a
// floating-point scalar, or a vector or tensor thereof; the result has i1
// element type and the same shape as the operand. The shape/type relation is
// enforced by the TypesMatchWith constraint below, which delegates to the C++
// helper ::getI1SameShape so the result type can be inferred from the operand.
class Math_FloatClassificationOp<string mnemonic, list<Trait> traits = []> :
    Math_Op<mnemonic,
        traits # [DeclareOpInterfaceMethods<ArithFastMathInterface>,
                  TypesMatchWith<
                      "result type has i1 element type and same shape as operands",
                      "operand", "result", "::getI1SameShape($_self)">]> {
  // Like other floating-point math ops, classification ops carry an optional
  // fastmath attribute that defaults to `none`.
  let arguments = (ins FloatLike:$operand,
      DefaultValuedAttr<Arith_FastMathAttr,
                        "::mlir::arith::FastMathFlags::none">:$fastmath);
  let results = (outs BoolLike:$result);

  // Only the operand type is spelled out in the assembly; the result type is
  // inferred from it via the TypesMatchWith constraint above.
  let assemblyFormat = "$operand attr-dict `:` type($operand)";
}

// Base class for unary math operations on floating point types. Require an
// operand and result of the same type. This type can be a floating point type,
// vector or tensor thereof.
Expand Down Expand Up @@ -678,6 +695,79 @@ def Math_IPowIOp : Math_IntegerBinaryOp<"ipowi"> {
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// IsFiniteOp
//===----------------------------------------------------------------------===//

def Math_IsFiniteOp : Math_FloatClassificationOp<"isfinite"> {
  let summary = "returns true if the operand classifies as finite";
  let description = [{
    Determines if the given floating-point number has a finite value, i.e. it
    is normal, subnormal or zero, but not infinite or NaN. The result is an i1
    scalar for a scalar operand, or a vector or tensor of i1 with the same
    shape as a vector or tensor operand.

    Example:

    ```mlir
    %f = math.isfinite %a : f32
    ```
  }];
}

//===----------------------------------------------------------------------===//
// IsInfOp
//===----------------------------------------------------------------------===//

def Math_IsInfOp : Math_FloatClassificationOp<"isinf"> {
  let summary = "returns true if the operand classifies as infinite";
  let description = [{
    Determines if the given floating-point number is positive or negative
    infinity. The result is an i1 scalar for a scalar operand, or a vector or
    tensor of i1 with the same shape as a vector or tensor operand.

    Example:

    ```mlir
    %f = math.isinf %a : f32
    ```
  }];
}

//===----------------------------------------------------------------------===//
// IsNaNOp
//===----------------------------------------------------------------------===//

def Math_IsNaNOp : Math_FloatClassificationOp<"isnan"> {
  let summary = "returns true if the operand classifies as NaN";
  let description = [{
    Determines if the given floating-point number is a not-a-number (NaN)
    value. The result is an i1 scalar for a scalar operand, or a vector or
    tensor of i1 with the same shape as a vector or tensor operand.

    Example:

    ```mlir
    %f = math.isnan %a : f32
    ```
  }];
}


//===----------------------------------------------------------------------===//
// IsNormalOp
//===----------------------------------------------------------------------===//

def Math_IsNormalOp : Math_FloatClassificationOp<"isnormal"> {
  let summary = "returns true if the operand classifies as normal";
  let description = [{
    Determines if the given floating-point number is normal, i.e. is neither
    zero, subnormal, infinite, nor NaN. The result is an i1 scalar for a
    scalar operand, or a vector or tensor of i1 with the same shape as a
    vector or tensor operand.

    Example:

    ```mlir
    %f = math.isnormal %a : f32
    ```
  }];
}

//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 24 additions & 5 deletions mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");

bool isResultBool = op->getResultTypes().front().isInteger(1);
if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
SourceOp>::value) {
assert(op->getNumOperands() > 0 &&
"expected op to take at least one operand");
assert(op->getResultTypes().front() == op->getOperand(0).getType() &&
assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
isResultBool) &&
"expected op with same operand and result types");
}

Expand All @@ -88,10 +90,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
for (Value operand : adaptor.getOperands())
castedOperands.push_back(maybeCast(operand, rewriter));

Type resultType = castedOperands.front().getType();
Type castedOperandType = castedOperands.front().getType();

// At ABI level, booleans are treated as i32.
Type resultType =
isResultBool ? rewriter.getIntegerType(32) : castedOperandType;
Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName = getFunctionName(
cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
StringRef funcName = getFunctionName(castedOperandType, op);
if (funcName.empty())
return failure();

Expand All @@ -104,6 +109,20 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return success();
}

// Boolean results are mapped to i32 at the ABI level, with zero interpreted
// as false and any non-zero value interpreted as true. Since there is no
// guarantee of a specific value being used to indicate true, compare for
// inequality with zero (rather than truncate or shift).
if (isResultBool) {
Value zero = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), rewriter.getIntegerType(32),
rewriter.getI32IntegerAttr(0));
Value truncated = rewriter.create<LLVM::ICmpOp>(
op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
rewriter.replaceOp(op, {truncated});
return success();
}

assert(callOp.getResult().getType().isF32() &&
"only f32 types are supposed to be truncated back");
Value truncated = rewriter.create<LLVM::FPTruncOp>(
Expand All @@ -118,7 +137,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
if (!isa<Float16Type, BFloat16Type>(type))
return operand;

// if there's a f16 function, no need to cast f16 values
// If there's an f16 function, no need to cast f16 values.
if (!f16Func.empty() && isa<Float16Type>(type))
return operand;

Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,13 @@ void mlir::populateGpuToNVVMConversionPatterns(
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
"__nv_floor");
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
// Note: libdevice does not provide `__nv_isfinitef` at the time of writing.
populateOpPatterns<math::IsFiniteOp>(converter, patterns, "",
"__nv_isfinited");
populateOpPatterns<math::IsInfOp>(converter, patterns, "__nv_isinff",
"__nv_isinfd");
populateOpPatterns<math::IsNaNOp>(converter, patterns, "__nv_isnanf",
"__nv_isnand");
populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
"__nv_fast_logf");
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/Math/IR/MathOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@
using namespace mlir;
using namespace mlir::math;

//===----------------------------------------------------------------------===//
// Common helpers
//===----------------------------------------------------------------------===//

/// Returns a type carrying i1 elements with the same shape as `type`: the
/// shaped (vector/tensor) counterpart for shaped inputs, an unranked tensor of
/// i1 for unranked tensor inputs, and plain i1 for scalars.
static Type getI1SameShape(Type type) {
  auto boolType = IntegerType::get(type.getContext(), /*width=*/1);
  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
    return shapedType.cloneWith(/*shape=*/std::nullopt, boolType);
  if (llvm::isa<UnrankedTensorType>(type))
    return UnrankedTensorType::get(boolType);
  return boolType;
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1058,3 +1058,40 @@ gpu.module @test_module_53 {
func.return %result32, %result64 : f32, f64
}
}

// Checks that the math dialect FP classification ops lower to libdevice calls
// returning i32, followed by an `icmp ne 0` that converts the i32 ABI result
// back to i1. `math.isfinite` on f32 has no libdevice counterpart and must
// remain unconverted.
gpu.module @test_module_54 {
  // CHECK: llvm.func @__nv_isinff(f32) -> i32
  // CHECK: llvm.func @__nv_isinfd(f64) -> i32
  // CHECK: llvm.func @__nv_isnanf(f32) -> i32
  // CHECK: llvm.func @__nv_isnand(f64) -> i32
  // CHECK: llvm.func @__nv_isfinited(f64) -> i32
  // CHECK-LABEL: @fpclassify
  func.func @fpclassify(%f32: f32, %f64: f64) -> (i1, i1, i1, i1, i1, i1) {
    // CHECK: %[[INFF:.+]] = llvm.call @__nv_isinff(%{{.*}}) : (f32) -> i32
    // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[R0:.+]] = llvm.icmp "ne" %[[INFF]], %[[ZERO]]
    %0 = math.isinf %f32 : f32
    // CHECK: llvm.call @__nv_isinfd(%{{.*}}) : (f64) -> i32
    // CHECK: llvm.mlir.constant(0
    // CHECK: llvm.icmp "ne"
    %1 = math.isinf %f64 : f64
    // CHECK: llvm.call @__nv_isnanf(%{{.*}}) : (f32) -> i32
    // CHECK: llvm.mlir.constant(0
    // CHECK: llvm.icmp "ne"
    %2 = math.isnan %f32 : f32
    // CHECK: llvm.call @__nv_isnand(%{{.*}}) : (f64) -> i32
    // CHECK: llvm.mlir.constant(0
    // CHECK: llvm.icmp "ne"
    %3 = math.isnan %f64 : f64
    // Note: for some reason, libdevice does not provide isfinite for f32, so
    // this should fail to convert.
    // CHECK: math.isfinite {{.*}} : f32
    %4 = math.isfinite %f32 : f32
    // CHECK: llvm.call @__nv_isfinited(%{{.*}}) : (f64) -> i32
    // CHECK: llvm.mlir.constant(0
    // CHECK: llvm.icmp "ne"
    %5 = math.isfinite %f64 : f64
    // CHECK: llvm.return %[[R0]]
    return %0, %1, %2, %3, %4, %5 : i1, i1, i1, i1, i1, i1
  }
}
39 changes: 39 additions & 0 deletions mlir/test/Dialect/Math/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,42 @@ func.func @fastmath(%f: f32, %i: i32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>)
%4 = math.fpowi %f, %i fastmath<fast> : f32, i32
return
}

// Round-trip test: each FP classification op parses and prints with only the
// operand type spelled out, for scalar (f32/f64), vector, and tensor operands.
// CHECK-LABEL: func @fpclassify(
// CHECK-SAME: %[[F:.+]]: f32, %[[D:.+]]: f64,
// CHECK-SAME: %[[V:.+]]: vector<4xf32>, %[[T:.+]]: tensor<4x?xf32>
func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>) {
  // CHECK: math.isfinite %[[F]] : f32
  // CHECK: math.isfinite %[[D]] : f64
  // CHECK: math.isfinite %[[V]] : vector<4xf32>
  // CHECK: math.isfinite %[[T]] : tensor<4x?xf32>
  math.isfinite %f : f32
  math.isfinite %d : f64
  math.isfinite %v : vector<4xf32>
  math.isfinite %t : tensor<4x?xf32>
  // CHECK: math.isinf %[[F]] : f32
  // CHECK: math.isinf %[[D]] : f64
  // CHECK: math.isinf %[[V]] : vector<4xf32>
  // CHECK: math.isinf %[[T]] : tensor<4x?xf32>
  math.isinf %f : f32
  math.isinf %d : f64
  math.isinf %v : vector<4xf32>
  math.isinf %t : tensor<4x?xf32>
  // CHECK: math.isnan %[[F]] : f32
  // CHECK: math.isnan %[[D]] : f64
  // CHECK: math.isnan %[[V]] : vector<4xf32>
  // CHECK: math.isnan %[[T]] : tensor<4x?xf32>
  math.isnan %f : f32
  math.isnan %d : f64
  math.isnan %v : vector<4xf32>
  math.isnan %t : tensor<4x?xf32>
  // CHECK: math.isnormal %[[F]] : f32
  // CHECK: math.isnormal %[[D]] : f64
  // CHECK: math.isnormal %[[V]] : vector<4xf32>
  // CHECK: math.isnormal %[[T]] : tensor<4x?xf32>
  math.isnormal %f : f32
  math.isnormal %d : f64
  math.isnormal %v : vector<4xf32>
  math.isnormal %t : tensor<4x?xf32>
  return
}
}