@@ -286,6 +286,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
286
286
}
287
287
};
288
288
289
+ struct IsNaNOpLowering : public ConvertOpToLLVMPattern <math::IsNaNOp> {
290
+ using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
291
+
292
+ LogicalResult
293
+ matchAndRewrite (math::IsNaNOp op, OpAdaptor adaptor,
294
+ ConversionPatternRewriter &rewriter) const override {
295
+ auto operandType = adaptor.getOperand ().getType ();
296
+
297
+ if (!operandType || !LLVM::isCompatibleType (operandType))
298
+ return failure ();
299
+
300
+ rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(op, op.getType (),
301
+ adaptor.getOperand (), 3 );
302
+ return success ();
303
+ }
304
+ };
305
+
306
+ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern <math::IsFiniteOp> {
307
+ using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
308
+
309
+ LogicalResult
310
+ matchAndRewrite (math::IsFiniteOp op, OpAdaptor adaptor,
311
+ ConversionPatternRewriter &rewriter) const override {
312
+ auto operandType = adaptor.getOperand ().getType ();
313
+
314
+ if (!operandType || !LLVM::isCompatibleType (operandType))
315
+ return failure ();
316
+
317
+ rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(op, op.getType (),
318
+ adaptor.getOperand (), 504 );
319
+ return success ();
320
+ }
321
+ };
322
+
289
323
struct ConvertMathToLLVMPass
290
324
: public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
291
325
using Base::Base;
@@ -307,6 +341,8 @@ void mlir::populateMathToLLVMConversionPatterns(
307
341
bool approximateLog1p, PatternBenefit benefit) {
308
342
if (approximateLog1p)
309
343
patterns.add <Log1pOpLowering>(converter, benefit);
344
+ patterns.add <IsNaNOpLowering, IsFiniteOpLowering>(converter);
345
+
310
346
// clang-format off
311
347
patterns.add <
312
348
AbsFOpLowering,
0 commit comments