@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309
309
target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
310
310
target.addLegalDialect <::mlir::NVVM::NVVMDialect>();
311
311
target.addIllegalDialect <gpu::GPUDialect>();
312
- target.addIllegalOp <LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313
- LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314
- LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315
- LLVM::SqrtOp>();
312
+ target.addIllegalOp <LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
313
+ LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
314
+ LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
315
+ LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
316
+ LLVM::SinOp, LLVM::SqrtOp>();
316
317
317
318
// TODO: Remove once we support replacing non-root ops.
318
319
target.addLegalOp <gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,10 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
321
322
template <typename OpTy>
322
323
static void populateOpPatterns (LLVMTypeConverter &converter,
323
324
RewritePatternSet &patterns, StringRef f32Func,
324
- StringRef f64Func) {
325
+ StringRef f64Func, StringRef f32FastFunc = " " ) {
325
326
patterns.add <ScalarizeVectorOpLowering<OpTy>>(converter);
326
- patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
327
+ patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
328
+ f32FastFunc);
327
329
}
328
330
329
331
void mlir::populateGpuSubgroupReduceOpLoweringPattern (
@@ -370,42 +372,68 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
370
372
StringAttr::get (&converter.getContext (),
371
373
NVVM::NVVMDialect::getMaxntidAttrName ()));
372
374
375
+ populateOpPatterns<arith::RemFOp>(converter, patterns, " __nv_fmodf" ,
376
+ " __nv_fmod" );
373
377
populateOpPatterns<math::AbsFOp>(converter, patterns, " __nv_fabsf" ,
374
378
" __nv_fabs" );
379
+ populateOpPatterns<math::AcosOp>(converter, patterns, " __nv_acosf" ,
380
+ " __nv_acos" );
381
+ populateOpPatterns<math::AcoshOp>(converter, patterns, " __nv_acoshf" ,
382
+ " __nv_acosh" );
383
+ populateOpPatterns<math::AsinOp>(converter, patterns, " __nv_asinf" ,
384
+ " __nv_asin" );
385
+ populateOpPatterns<math::AsinhOp>(converter, patterns, " __nv_asinhf" ,
386
+ " __nv_asinh" );
375
387
populateOpPatterns<math::AtanOp>(converter, patterns, " __nv_atanf" ,
376
388
" __nv_atan" );
377
389
populateOpPatterns<math::Atan2Op>(converter, patterns, " __nv_atan2f" ,
378
390
" __nv_atan2" );
391
+ populateOpPatterns<math::AtanhOp>(converter, patterns, " __nv_atanhf" ,
392
+ " __nv_atanh" );
379
393
populateOpPatterns<math::CbrtOp>(converter, patterns, " __nv_cbrtf" ,
380
394
" __nv_cbrt" );
381
395
populateOpPatterns<math::CeilOp>(converter, patterns, " __nv_ceilf" ,
382
396
" __nv_ceil" );
383
- populateOpPatterns<math::CosOp>(converter, patterns, " __nv_cosf" , " __nv_cos" );
397
+ populateOpPatterns<math::CopySignOp>(converter, patterns, " __nv_copysignf" ,
398
+ " __nv_copysign" );
399
+ populateOpPatterns<math::CosOp>(converter, patterns, " __nv_cosf" , " __nv_cos" ,
400
+ " __nv_fast_cosf" );
401
+ populateOpPatterns<math::CoshOp>(converter, patterns, " __nv_coshf" ,
402
+ " __nv_cosh" );
384
403
populateOpPatterns<math::ErfOp>(converter, patterns, " __nv_erff" , " __nv_erf" );
385
- populateOpPatterns<math::ExpOp>(converter, patterns, " __nv_expf" , " __nv_exp" );
404
+ populateOpPatterns<math::ExpOp>(converter, patterns, " __nv_expf" , " __nv_exp" ,
405
+ " __nv_fast_expf" );
386
406
populateOpPatterns<math::Exp2Op>(converter, patterns, " __nv_exp2f" ,
387
407
" __nv_exp2" );
388
408
populateOpPatterns<math::ExpM1Op>(converter, patterns, " __nv_expm1f" ,
389
409
" __nv_expm1" );
390
410
populateOpPatterns<math::FloorOp>(converter, patterns, " __nv_floorf" ,
391
411
" __nv_floor" );
392
- populateOpPatterns<arith::RemFOp>(converter, patterns, " __nv_fmodf" ,
393
- " __nv_fmod" );
394
- populateOpPatterns<math::LogOp>(converter, patterns, " __nv_logf" , " __nv_log" );
412
+ populateOpPatterns<math::FmaOp>(converter, patterns, " __nv_fmaf" , " __nv_fma" );
413
+ populateOpPatterns<math::LogOp>(converter, patterns, " __nv_logf" , " __nv_log" ,
414
+ " __nv_fast_logf" );
415
+ populateOpPatterns<math::Log10Op>(converter, patterns, " __nv_log10f" ,
416
+ " __nv_log10" , " __nv_fast_log10f" );
395
417
populateOpPatterns<math::Log1pOp>(converter, patterns, " __nv_log1pf" ,
396
418
" __nv_log1p" );
397
- populateOpPatterns<math::Log10Op>(converter, patterns, " __nv_log10f" ,
398
- " __nv_log10" );
399
419
populateOpPatterns<math::Log2Op>(converter, patterns, " __nv_log2f" ,
400
- " __nv_log2" );
401
- populateOpPatterns<math::PowFOp>(converter, patterns, " __nv_powf" ,
402
- " __nv_pow" );
420
+ " __nv_log2" , " __nv_fast_log2f" );
421
+ populateOpPatterns<math::PowFOp>(converter, patterns, " __nv_powf" , " __nv_pow" ,
422
+ " __nv_fast_powf" );
423
+ populateOpPatterns<math::RoundOp>(converter, patterns, " __nv_roundf" ,
424
+ " __nv_round" );
425
+ populateOpPatterns<math::RoundEvenOp>(converter, patterns, " __nv_rintf" ,
426
+ " __nv_rint" );
403
427
populateOpPatterns<math::RsqrtOp>(converter, patterns, " __nv_rsqrtf" ,
404
428
" __nv_rsqrt" );
405
- populateOpPatterns<math::SinOp>(converter, patterns, " __nv_sinf" , " __nv_sin" );
429
+ populateOpPatterns<math::SinOp>(converter, patterns, " __nv_sinf" , " __nv_sin" ,
430
+ " __nv_fast_sinf" );
431
+ populateOpPatterns<math::SinhOp>(converter, patterns, " __nv_sinhf" ,
432
+ " __nv_sinh" );
406
433
populateOpPatterns<math::SqrtOp>(converter, patterns, " __nv_sqrtf" ,
407
434
" __nv_sqrt" );
435
+ populateOpPatterns<math::TanOp>(converter, patterns, " __nv_tanf" , " __nv_tan" ,
436
+ " __nv_fast_tanf" );
408
437
populateOpPatterns<math::TanhOp>(converter, patterns, " __nv_tanhf" ,
409
438
" __nv_tanh" );
410
- populateOpPatterns<math::TanOp>(converter, patterns, " __nv_tanf" , " __nv_tan" );
411
439
}
0 commit comments