Skip to content

Commit d4c5100

Browse files
committed
[MLIR][Math] Add lowering for isnan and isfinite
1 parent 22f5268 commit d4c5100

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
286286
}
287287
};
288288

289+
struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
290+
using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
291+
292+
LogicalResult
293+
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
294+
ConversionPatternRewriter &rewriter) const override {
295+
auto operandType = adaptor.getOperand().getType();
296+
297+
if (!operandType || !LLVM::isCompatibleType(operandType))
298+
return failure();
299+
300+
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(op, op.getType(),
301+
adaptor.getOperand(), 3);
302+
return success();
303+
}
304+
};
305+
306+
struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
307+
using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
308+
309+
LogicalResult
310+
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
311+
ConversionPatternRewriter &rewriter) const override {
312+
auto operandType = adaptor.getOperand().getType();
313+
314+
if (!operandType || !LLVM::isCompatibleType(operandType))
315+
return failure();
316+
317+
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(op, op.getType(),
318+
adaptor.getOperand(), 504);
319+
return success();
320+
}
321+
};
322+
289323
struct ConvertMathToLLVMPass
290324
: public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
291325
using Base::Base;
@@ -307,6 +341,8 @@ void mlir::populateMathToLLVMConversionPatterns(
307341
bool approximateLog1p, PatternBenefit benefit) {
308342
if (approximateLog1p)
309343
patterns.add<Log1pOpLowering>(converter, benefit);
344+
patterns.add<IsNaNOpLowering, IsFiniteOpLowering>(converter);
345+
310346
// clang-format off
311347
patterns.add<
312348
AbsFOpLowering,

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,26 @@ func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
263263

264264
// -----
265265

266+
// CHECK-LABEL: func @isnan_double(
267+
// CHECK-SAME: f64
268+
func.func @isnan_double(%arg0 : f64) {
269+
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 3 : i32}> : (f64) -> i1
270+
%0 = math.isnan %arg0 : f64
271+
func.return
272+
}
273+
274+
// -----
275+
276+
// CHECK-LABEL: func @isfinite_double(
277+
// CHECK-SAME: f64
278+
func.func @isfinite_double(%arg0 : f64) {
279+
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 504 : i32}> : (f64) -> i1
280+
%0 = math.isfinite %arg0 : f64
281+
func.return
282+
}
283+
284+
// -----
285+
266286
// CHECK-LABEL: func @rsqrt_double(
267287
// CHECK-SAME: f64
268288
func.func @rsqrt_double(%arg0 : f64) {

0 commit comments

Comments
 (0)