@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309
309
target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
310
310
target.addLegalDialect <::mlir::NVVM::NVVMDialect>();
311
311
target.addIllegalDialect <gpu::GPUDialect>();
312
- target.addIllegalOp <LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313
- LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314
- LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315
- LLVM::SqrtOp>();
312
+ target.addIllegalOp <LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
313
+ LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
314
+ LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
315
+ LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
316
+ LLVM::SinOp, LLVM::SqrtOp>();
316
317
317
318
// TODO: Remove once we support replacing non-root ops.
318
319
target.addLegalOp <gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
321
322
template <typename OpTy>
322
323
static void populateOpPatterns (LLVMTypeConverter &converter,
323
324
RewritePatternSet &patterns, StringRef f32Func,
324
- StringRef f64Func) {
325
+ StringRef f64Func,
326
+ StringRef f32ApproxFunc = " " ) {
325
327
patterns.add <ScalarizeVectorOpLowering<OpTy>>(converter);
326
- patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
328
+ patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
329
+ f32ApproxFunc);
327
330
}
328
331
329
332
void mlir::populateGpuSubgroupReduceOpLoweringPattern (
@@ -370,42 +373,68 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
370
373
StringAttr::get (&converter.getContext (),
371
374
NVVM::NVVMDialect::getMaxntidAttrName ()));
372
375
376
+ populateOpPatterns<arith::RemFOp>(converter, patterns, " __nv_fmodf" ,
377
+ " __nv_fmod" );
373
378
populateOpPatterns<math::AbsFOp>(converter, patterns, " __nv_fabsf" ,
374
379
" __nv_fabs" );
380
+ populateOpPatterns<math::AcosOp>(converter, patterns, " __nv_acosf" ,
381
+ " __nv_acos" );
382
+ populateOpPatterns<math::AcoshOp>(converter, patterns, " __nv_acoshf" ,
383
+ " __nv_acosh" );
384
+ populateOpPatterns<math::AsinOp>(converter, patterns, " __nv_asinf" ,
385
+ " __nv_asin" );
386
+ populateOpPatterns<math::AsinhOp>(converter, patterns, " __nv_asinhf" ,
387
+ " __nv_asinh" );
375
388
populateOpPatterns<math::AtanOp>(converter, patterns, " __nv_atanf" ,
376
389
" __nv_atan" );
377
390
populateOpPatterns<math::Atan2Op>(converter, patterns, " __nv_atan2f" ,
378
391
" __nv_atan2" );
392
+ populateOpPatterns<math::AtanhOp>(converter, patterns, " __nv_atanhf" ,
393
+ " __nv_atanh" );
379
394
populateOpPatterns<math::CbrtOp>(converter, patterns, " __nv_cbrtf" ,
380
395
" __nv_cbrt" );
381
396
populateOpPatterns<math::CeilOp>(converter, patterns, " __nv_ceilf" ,
382
397
" __nv_ceil" );
383
- populateOpPatterns<math::CosOp>(converter, patterns, " __nv_cosf" , " __nv_cos" );
398
+ populateOpPatterns<math::CopySignOp>(converter, patterns, " __nv_copysignf" ,
399
+ " __nv_copysign" );
400
+ populateOpPatterns<math::CosOp>(converter, patterns, " __nv_cosf" , " __nv_cos" ,
401
+ " __nv_fast_cosf" );
402
+ populateOpPatterns<math::CoshOp>(converter, patterns, " __nv_coshf" ,
403
+ " __nv_cosh" );
384
404
populateOpPatterns<math::ErfOp>(converter, patterns, " __nv_erff" , " __nv_erf" );
385
- populateOpPatterns<math::ExpOp>(converter, patterns, " __nv_expf" , " __nv_exp" );
405
+ populateOpPatterns<math::ExpOp>(converter, patterns, " __nv_expf" , " __nv_exp" ,
406
+ " __nv_fast_expf" );
386
407
populateOpPatterns<math::Exp2Op>(converter, patterns, " __nv_exp2f" ,
387
408
" __nv_exp2" );
388
409
populateOpPatterns<math::ExpM1Op>(converter, patterns, " __nv_expm1f" ,
389
410
" __nv_expm1" );
390
411
populateOpPatterns<math::FloorOp>(converter, patterns, " __nv_floorf" ,
391
412
" __nv_floor" );
392
- populateOpPatterns<arith::RemFOp>(converter, patterns, " __nv_fmodf" ,
393
- " __nv_fmod" );
394
- populateOpPatterns<math::LogOp>(converter, patterns, " __nv_logf" , " __nv_log" );
413
+ populateOpPatterns<math::FmaOp>(converter, patterns, " __nv_fmaf" , " __nv_fma" );
414
+ populateOpPatterns<math::LogOp>(converter, patterns, " __nv_logf" , " __nv_log" ,
415
+ " __nv_fast_logf" );
416
+ populateOpPatterns<math::Log10Op>(converter, patterns, " __nv_log10f" ,
417
+ " __nv_log10" , " __nv_fast_log10f" );
395
418
populateOpPatterns<math::Log1pOp>(converter, patterns, " __nv_log1pf" ,
396
419
" __nv_log1p" );
397
- populateOpPatterns<math::Log10Op>(converter, patterns, " __nv_log10f" ,
398
- " __nv_log10" );
399
420
populateOpPatterns<math::Log2Op>(converter, patterns, " __nv_log2f" ,
400
- " __nv_log2" );
401
- populateOpPatterns<math::PowFOp>(converter, patterns, " __nv_powf" ,
402
- " __nv_pow" );
421
+ " __nv_log2" , " __nv_fast_log2f" );
422
+ populateOpPatterns<math::PowFOp>(converter, patterns, " __nv_powf" , " __nv_pow" ,
423
+ " __nv_fast_powf" );
424
+ populateOpPatterns<math::RoundOp>(converter, patterns, " __nv_roundf" ,
425
+ " __nv_round" );
426
+ populateOpPatterns<math::RoundEvenOp>(converter, patterns, " __nv_rintf" ,
427
+ " __nv_rint" );
403
428
populateOpPatterns<math::RsqrtOp>(converter, patterns, " __nv_rsqrtf" ,
404
429
" __nv_rsqrt" );
405
- populateOpPatterns<math::SinOp>(converter, patterns, " __nv_sinf" , " __nv_sin" );
430
+ populateOpPatterns<math::SinOp>(converter, patterns, " __nv_sinf" , " __nv_sin" ,
431
+ " __nv_fast_sinf" );
432
+ populateOpPatterns<math::SinhOp>(converter, patterns, " __nv_sinhf" ,
433
+ " __nv_sinh" );
406
434
populateOpPatterns<math::SqrtOp>(converter, patterns, " __nv_sqrtf" ,
407
435
" __nv_sqrt" );
436
+ populateOpPatterns<math::TanOp>(converter, patterns, " __nv_tanf" , " __nv_tan" ,
437
+ " __nv_fast_tanf" );
408
438
populateOpPatterns<math::TanhOp>(converter, patterns, " __nv_tanhf" ,
409
439
" __nv_tanh" );
410
- populateOpPatterns<math::TanOp>(converter, patterns, " __nv_tanf" , " __nv_tan" );
411
440
}
0 commit comments