@@ -445,6 +445,44 @@ __global__ void elementwise_kernel(int N, func_t f) {
   }
 }
 
+#ifdef USE_ROCM
+template <int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
+  int tid = threadIdx.x;
+  int nv = nt * vt;
+  int idx = nv * blockIdx.x + tid;
+  if ((idx + nt * (vt - 1)) < N) {
+    f(idx, true);
+  } else {
+#pragma unroll
+    for (int i = 0; i < vt; i++) {
+      if (idx < N) {
+        f(idx, false);
+        idx += nt;
+      }
+    }
+  }
+}
+
+template <int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel_strided(int N, func_t f) {
+  int tid = threadIdx.x;
+  int idx = nt * vt * blockIdx.x + tid;
+  int step = nt * vt * gridDim.x;
+  while (idx < N) {
+#pragma unroll
+    for (int i = 0; i < vt; i++) {
+      if ((idx + nt * i) < N) {
+        f(idx + nt * i);
+      }
+    }
+    idx += step;
+  }
+}
+#endif
+
 template <int nt, int vt, typename func_t>
 static void launch_legacy_kernel(int64_t N, const func_t& f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
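A host-side C++ sketch (illustration only, not part of the diff) of the index arithmetic in elementwise_kernel_manual_unroll: a thread takes the fully unrolled fast path only when its last strided element, idx + nt*(vt-1), is still in bounds, and falls back to the bounds-checked loop otherwise. The nt, vt and N values below are just examples.

// Host-side illustration only: which blocks of elementwise_kernel_manual_unroll
// can take the fully unrolled path, for nt = 128, vt = 4 and an example N.
#include <cstdio>

int main() {
  constexpr int nt = 128;        // threads per block
  constexpr int vt = 4;          // elements per thread
  constexpr int nv = nt * vt;    // elements covered by one block
  const int N = 1000;            // example element count (hypothetical)

  const int grid = (N + nv - 1) / nv;
  for (int block = 0; block < grid; ++block) {
    // The kernel's test is per thread; the last thread (tid = nt - 1) decides
    // whether the whole block stays on the unrolled path.
    int first = nv * block;                        // idx of tid == 0
    int last = first + (nt - 1) + nt * (vt - 1);   // last element the block touches
    bool all_unrolled = last < N;
    std::printf("block %d: elements [%d, %d] -> %s\n", block, first, last,
                all_unrolled ? "unrolled fast path" : "bounds-checked tail");
  }
  return 0;
}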
@@ -458,6 +496,37 @@ static void launch_legacy_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+#ifdef USE_ROCM
+template <int nt, int vt, typename func_t>
+static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+  auto stream = at::cuda::getCurrentCUDAStream();
+  elementwise_kernel_manual_unroll<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <int nt, int vt, typename func_t>
+static void launch_legacy_kernel_strided(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid(8192);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int ub_idx = nt * vt;
+  ub_idx = ub_idx * (grid.x - 1) + (block.x - 1);
+  ub_idx = ub_idx + nt * vt;
+  elementwise_kernel_strided<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+#endif
+
 template <typename traits, typename func_t, typename index_t, size_t... INDEX>
 C10_HOST_DEVICE typename traits::result_type invoke_impl(
     const func_t& f,
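The two launchers above size their grids differently: launch_legacy_kernel_manual_unroll covers N in a single pass, so it needs ceil(N / (nt * vt)) blocks, while launch_legacy_kernel_strided always launches a fixed 8192-block grid and lets the grid-stride loop in elementwise_kernel_strided absorb any N. A host-side sketch of that arithmetic (illustration only; the nt, vt and N values are hypothetical):

// Host-side illustration only: grid sizing of the two ROCm launchers.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t nt = 512;          // example block size
  constexpr int64_t vt = 4;            // example elements per thread
  const int64_t N = 100000000;         // example element count (hypothetical)

  // launch_legacy_kernel_manual_unroll: one pass, so the grid must cover N.
  const int64_t grid_unroll = (N + nt * vt - 1) / (nt * vt);

  // launch_legacy_kernel_strided: fixed 8192-block grid; each trip of the
  // while loop in elementwise_kernel_strided advances by nt * vt * gridDim.x.
  const int64_t grid_strided = 8192;
  const int64_t per_pass = nt * vt * grid_strided;
  const int64_t passes = (N + per_pass - 1) / per_pass;

  std::printf("manual unroll: %lld blocks, 1 pass\n", (long long)grid_unroll);
  std::printf("strided: %lld blocks, %lld elements per pass, %lld pass(es)\n",
              (long long)grid_strided, (long long)per_pass, (long long)passes);
  return 0;
}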
@@ -536,12 +605,84 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
     return launch_vectorized_kernel(numel, f, data);
   }
   auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+#ifndef USE_ROCM
   constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
   launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
     auto offsets = offset_calc.get(idx);
     arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
     *out = invoke(f, &data[1], &offsets[1], 1);
   });
+#else
+  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8;
+  constexpr int grp_sz = 128;
+  launch_legacy_kernel_manual_unroll<grp_sz, unroll_factor>(numel, [=] GPU_LAMBDA(int idx, bool unrl4x) {
+    if constexpr (unroll_factor == 4) {
+      if (unrl4x) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+        auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+        auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+        auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+        auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+        *out0 = tmp0;
+        *out1 = tmp1;
+        *out2 = tmp2;
+        *out3 = tmp3;
+      }
+      else {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = invoke(f, &data[1], &offsets[1], 1);
+      }
+    } else {
+      if (unrl4x) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        auto offsets4 = offset_calc.get(idx + grp_sz * 4);
+        auto offsets5 = offset_calc.get(idx + grp_sz * 5);
+        auto offsets6 = offset_calc.get(idx + grp_sz * 6);
+        auto offsets7 = offset_calc.get(idx + grp_sz * 7);
+        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+        arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]);
+        arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]);
+        arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]);
+        arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]);
+        auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+        auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+        auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+        auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+        auto tmp4 = invoke(f, &data[1], &offsets4[1], 1);
+        auto tmp5 = invoke(f, &data[1], &offsets5[1], 1);
+        auto tmp6 = invoke(f, &data[1], &offsets6[1], 1);
+        auto tmp7 = invoke(f, &data[1], &offsets7[1], 1);
+        *out0 = tmp0;
+        *out1 = tmp1;
+        *out2 = tmp2;
+        *out3 = tmp3;
+        *out4 = tmp4;
+        *out5 = tmp5;
+        *out6 = tmp6;
+        *out7 = tmp7;
+      }
+      else {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = invoke(f, &data[1], &offsets[1], 1);
+      }
+    }
+  });
+#endif
 }
 
 #ifdef USE_ROCM
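The ROCm branch above spells out the unroll_factor == 4 and == 8 cases by hand; in both, every offset and result is computed before any store is issued. A compact host-side stand-in for that gather-then-store pattern (illustration only; the array add and the names below are placeholders, not the TensorIterator machinery):

// Host-side illustration only: the "compute all offsets and results, then do
// all stores" shape of the unrolled lambda, with a plain array add standing in
// for offset_calc / invoke.
#include <cstdio>

constexpr int kUnroll = 4;    // mirrors unroll_factor for >= 4-byte outputs
constexpr int kGroup = 128;   // mirrors grp_sz (the launch block size)

// One thread's unrolled body at base index idx; the kernel only takes this
// path when all kUnroll strided elements are known to be in bounds.
void unrolled_body(int idx, const float* in, float* out) {
  int offs[kUnroll];
  float tmp[kUnroll];
  for (int i = 0; i < kUnroll; ++i) {   // gather offsets, compute results
    offs[i] = idx + kGroup * i;
    tmp[i] = in[offs[i]] + 1.0f;        // stands in for invoke(f, ...)
  }
  for (int i = 0; i < kUnroll; ++i) {   // then issue the stores together
    out[offs[i]] = tmp[i];
  }
}

int main() {
  constexpr int n = kGroup * kUnroll;   // exactly one block's worth of work
  float in[n], out[n];
  for (int i = 0; i < n; ++i) in[i] = float(i);
  for (int tid = 0; tid < kGroup; ++tid) unrolled_body(tid, in, out);
  std::printf("out[0] = %g, out[%d] = %g\n", out[0], n - 1, out[n - 1]);
  return 0;
}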
@@ -759,7 +900,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       dtypes[i] = iter.dtype(i);
       strides[i] = inner_strides[i];
     }
-    launch_legacy_kernel<512, 1>(numel, [=] GPU_LAMBDA(int idx) {
+    launch_legacy_kernel_strided<512, 4>(numel, [=] GPU_LAMBDA(int idx) {
       void* out = data[0] + strides[0] * idx;
       arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
       c10::cast_and_store<arg0_t>(dtypes[0], out, result);
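This last hunk moves the non-contiguous dynamic-casting path from launch_legacy_kernel<512, 1> to the new strided launcher with vt = 4, so each thread now handles four nt-strided elements per grid-stride iteration instead of one element total. A host-side sketch of the resulting access pattern for one example thread (illustration only; all values hypothetical):

// Host-side illustration only: the indices one thread of
// elementwise_kernel_strided<512, 4> visits after this change.
#include <cstdio>

int main() {
  constexpr int nt = 512, vt = 4;   // block size and unroll used in the new call
  constexpr int grid = 8192;        // fixed grid from launch_legacy_kernel_strided
  const int N = 40000000;           // example element count

  const int block = 1, tid = 3;     // an arbitrary thread for illustration
  const int step = nt * vt * grid;  // grid-stride advance per while-loop trip
  int idx = nt * vt * block + tid;

  int trips = 0;
  while (idx < N && trips < 3) {    // show the first few loop trips
    for (int i = 0; i < vt; ++i)
      if (idx + nt * i < N) std::printf("element %d\n", idx + nt * i);
    idx += step;
    ++trips;
  }
  return 0;
}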