Skip to content

Commit f0134e6

Browse files
wsmosesivanradanov
andauthored
[MLIR][Math] Add lowering for isnan and isfinite (#128125)
Co-authored-by: Ivan R. Ivanov <[email protected]>
1 parent cc675c6 commit f0134e6

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "mlir/IR/TypeUtilities.h"
1919
#include "mlir/Pass/Pass.h"
2020

21+
#include "llvm/ADT/FloatingPointMode.h"
22+
2123
namespace mlir {
2224
#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
2325
#include "mlir/Conversion/Passes.h.inc"
@@ -286,6 +288,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
286288
}
287289
};
288290

291+
struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
292+
using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
293+
294+
LogicalResult
295+
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
296+
ConversionPatternRewriter &rewriter) const override {
297+
auto operandType = adaptor.getOperand().getType();
298+
299+
if (!operandType || !LLVM::isCompatibleType(operandType))
300+
return failure();
301+
302+
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
303+
op, op.getType(), adaptor.getOperand(), llvm::fcNan);
304+
return success();
305+
}
306+
};
307+
308+
struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
309+
using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
310+
311+
LogicalResult
312+
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
313+
ConversionPatternRewriter &rewriter) const override {
314+
auto operandType = adaptor.getOperand().getType();
315+
316+
if (!operandType || !LLVM::isCompatibleType(operandType))
317+
return failure();
318+
319+
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
320+
op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
321+
return success();
322+
}
323+
};
324+
289325
struct ConvertMathToLLVMPass
290326
: public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
291327
using Base::Base;
@@ -309,6 +345,8 @@ void mlir::populateMathToLLVMConversionPatterns(
309345
patterns.add<Log1pOpLowering>(converter, benefit);
310346
// clang-format off
311347
patterns.add<
348+
IsNaNOpLowering,
349+
IsFiniteOpLowering,
312350
AbsFOpLowering,
313351
AbsIOpLowering,
314352
CeilOpLowering,

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,26 @@ func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
263263

264264
// -----
265265

266+
// CHECK-LABEL: func @isnan_double(
267+
// CHECK-SAME: f64
268+
func.func @isnan_double(%arg0 : f64) {
269+
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 3 : i32}> : (f64) -> i1
270+
%0 = math.isnan %arg0 : f64
271+
func.return
272+
}
273+
274+
// -----
275+
276+
// CHECK-LABEL: func @isfinite_double(
277+
// CHECK-SAME: f64
278+
func.func @isfinite_double(%arg0 : f64) {
279+
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 504 : i32}> : (f64) -> i1
280+
%0 = math.isfinite %arg0 : f64
281+
func.return
282+
}
283+
284+
// -----
285+
266286
// CHECK-LABEL: func @rsqrt_double(
267287
// CHECK-SAME: f64
268288
func.func @rsqrt_double(%arg0 : f64) {

0 commit comments

Comments
 (0)