@@ -309,11 +309,10 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
309
309
target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
310
310
target.addLegalDialect <::mlir::NVVM::NVVMDialect>();
311
311
target.addIllegalDialect <gpu::GPUDialect>();
312
- target.addIllegalOp <LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
313
- LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
314
- LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
315
- LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
316
- LLVM::SinOp, LLVM::SqrtOp>();
312
+ target.addIllegalOp <LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
315
+ LLVM::SqrtOp>();
317
316
318
317
// TODO: Remove once we support replacing non-root ops.
319
318
target.addLegalOp <gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -322,11 +321,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
322
321
template <typename OpTy>
323
322
static void populateOpPatterns (LLVMTypeConverter &converter,
324
323
RewritePatternSet &patterns, StringRef f32Func,
325
- StringRef f64Func,
326
- StringRef f32ApproxFunc = " " ) {
324
+ StringRef f64Func) {
327
325
patterns.add <ScalarizeVectorOpLowering<OpTy>>(converter);
328
- patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
329
- f32ApproxFunc);
326
+ patterns.add <OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
330
327
}
331
328
332
329
void mlir::populateGpuSubgroupReduceOpLoweringPattern (
@@ -373,68 +370,42 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
373
370
StringAttr::get (&converter.getContext (),
374
371
NVVM::NVVMDialect::getMaxntidAttrName ()));
375
372
376
- populateOpPatterns<arith::RemFOp>(converter, patterns, " __nv_fmodf" ,
377
- " __nv_fmod" );
378
373
populateOpPatterns<math::AbsFOp>(converter, patterns, " __nv_fabsf" ,
379
374
" __nv_fabs" );
380
- populateOpPatterns<math::AcosOp>(converter, patterns, " __nv_acosf" ,
381
- " __nv_acos" );
382
- populateOpPatterns<math::AcoshOp>(converter, patterns, " __nv_acoshf" ,
383
- " __nv_acosh" );
384
- populateOpPatterns<math::AsinOp>(converter, patterns, " __nv_asinf" ,
385
- " __nv_asin" );
386
- populateOpPatterns<math::AsinhOp>(converter, patterns, " __nv_asinhf" ,
387
- " __nv_asinh" );
388
375
populateOpPatterns<math::AtanOp>(converter, patterns, " __nv_atanf" ,
389
376
" __nv_atan" );
390
377
populateOpPatterns<math::Atan2Op>(converter, patterns, " __nv_atan2f" ,
391
378
" __nv_atan2" );
392
- populateOpPatterns<math::AtanhOp>(converter, patterns, " __nv_atanhf" ,
393
- " __nv_atanh" );
394
379
populateOpPatterns<math::CbrtOp>(converter, patterns, " __nv_cbrtf" ,
395
380
" __nv_cbrt" );
396
381
populateOpPatterns<math::CeilOp>(converter, patterns, " __nv_ceilf" ,
397
382
" __nv_ceil" );
398
- populateOpPatterns<math::CopySignOp>(converter, patterns, " __nv_copysignf" ,
399
- " __nv_copysign" );
400
- populateOpPatterns<math::CosOp>(converter, patterns, " __nv_cosf" , " __nv_cos" ,
401
- " __nv_fast_cosf" );
402
- populateOpPatterns<math::CoshOp>(converter, patterns, " __nv_coshf" ,
403
- " __nv_cosh" );
383
+ populateOpPatterns<math::CosOp>(converter, patterns, " __nv_cosf" , " __nv_cos" );
404
384
populateOpPatterns<math::ErfOp>(converter, patterns, " __nv_erff" , " __nv_erf" );
405
- populateOpPatterns<math::ExpOp>(converter, patterns, " __nv_expf" , " __nv_exp" ,
406
- " __nv_fast_expf" );
385
+ populateOpPatterns<math::ExpOp>(converter, patterns, " __nv_expf" , " __nv_exp" );
407
386
populateOpPatterns<math::Exp2Op>(converter, patterns, " __nv_exp2f" ,
408
387
" __nv_exp2" );
409
388
populateOpPatterns<math::ExpM1Op>(converter, patterns, " __nv_expm1f" ,
410
389
" __nv_expm1" );
411
390
populateOpPatterns<math::FloorOp>(converter, patterns, " __nv_floorf" ,
412
391
" __nv_floor" );
413
- populateOpPatterns<math::FmaOp>(converter, patterns, " __nv_fmaf" , " __nv_fma" );
414
- populateOpPatterns<math::LogOp>(converter, patterns, " __nv_logf" , " __nv_log" ,
415
- " __nv_fast_logf" );
416
- populateOpPatterns<math::Log10Op>(converter, patterns, " __nv_log10f" ,
417
- " __nv_log10" , " __nv_fast_log10f" );
392
+ populateOpPatterns<arith::RemFOp>(converter, patterns, " __nv_fmodf" ,
393
+ " __nv_fmod" );
394
+ populateOpPatterns<math::LogOp>(converter, patterns, " __nv_logf" , " __nv_log" );
418
395
populateOpPatterns<math::Log1pOp>(converter, patterns, " __nv_log1pf" ,
419
396
" __nv_log1p" );
397
+ populateOpPatterns<math::Log10Op>(converter, patterns, " __nv_log10f" ,
398
+ " __nv_log10" );
420
399
populateOpPatterns<math::Log2Op>(converter, patterns, " __nv_log2f" ,
421
- " __nv_log2" , " __nv_fast_log2f" );
422
- populateOpPatterns<math::PowFOp>(converter, patterns, " __nv_powf" , " __nv_pow" ,
423
- " __nv_fast_powf" );
424
- populateOpPatterns<math::RoundOp>(converter, patterns, " __nv_roundf" ,
425
- " __nv_round" );
426
- populateOpPatterns<math::RoundEvenOp>(converter, patterns, " __nv_rintf" ,
427
- " __nv_rint" );
400
+ " __nv_log2" );
401
+ populateOpPatterns<math::PowFOp>(converter, patterns, " __nv_powf" ,
402
+ " __nv_pow" );
428
403
populateOpPatterns<math::RsqrtOp>(converter, patterns, " __nv_rsqrtf" ,
429
404
" __nv_rsqrt" );
430
- populateOpPatterns<math::SinOp>(converter, patterns, " __nv_sinf" , " __nv_sin" ,
431
- " __nv_fast_sinf" );
432
- populateOpPatterns<math::SinhOp>(converter, patterns, " __nv_sinhf" ,
433
- " __nv_sinh" );
405
+ populateOpPatterns<math::SinOp>(converter, patterns, " __nv_sinf" , " __nv_sin" );
434
406
populateOpPatterns<math::SqrtOp>(converter, patterns, " __nv_sqrtf" ,
435
407
" __nv_sqrt" );
436
- populateOpPatterns<math::TanOp>(converter, patterns, " __nv_tanf" , " __nv_tan" ,
437
- " __nv_fast_tanf" );
438
408
populateOpPatterns<math::TanhOp>(converter, patterns, " __nv_tanhf" ,
439
409
" __nv_tanh" );
410
+ populateOpPatterns<math::TanOp>(converter, patterns, " __nv_tanf" , " __nv_tan" );
440
411
}
0 commit comments