
Commit 08ceb35

jerrymannil authored and dnikolaev-amd committed
[ROCm] Improvements to non-vectorized elementwise kernels (#1875)
* Unroll loops manually to hide memory access latency
* Strided access for coalesced memory accesses

Co-authors: @akadutta @doru1004 @amd-hhashemi @carlobertolli

(cherry picked from commit 2e48656)
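The two ideas named in the message can be illustrated with a minimal standalone sketch. The kernel names add_unrolled and add_strided are hypothetical; nt and vt mirror the threads-per-block and elements-per-thread parameters used in the patch, everything else is assumed.

#include <cuda_runtime.h>

// Toy illustration only; not code from CUDALoops.cuh.
template <int nt, int vt>
__global__ void add_unrolled(int N, const float* a, const float* b, float* c) {
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  float va[vt], vb[vt];
  // Manual unroll: issue all vt loads up front so the latency of one load
  // overlaps with the others instead of serializing behind each use.
#pragma unroll
  for (int i = 0; i < vt; i++) {
    int j = idx + nt * i;
    if (j < N) { va[i] = a[j]; vb[i] = b[j]; }
  }
#pragma unroll
  for (int i = 0; i < vt; i++) {
    int j = idx + nt * i;
    if (j < N) { c[j] = va[i] + vb[i]; }
  }
}

template <int nt, int vt>
__global__ void add_strided(int N, const float* a, const float* b, float* c) {
  // Grid-stride loop: consecutive threads always touch consecutive elements,
  // so each warp/wavefront issues coalesced accesses on every iteration.
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  int step = nt * vt * gridDim.x;
  for (; idx < N; idx += step) {
#pragma unroll
    for (int i = 0; i < vt; i++) {
      int j = idx + nt * i;
      if (j < N) { c[j] = a[j] + b[j]; }
    }
  }
}

A launch such as add_strided<256, 4><<<8192, 256>>>(N, a, b, c); would mirror the fixed-grid strided configuration the patch adopts below.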
1 parent 6686d55 commit 08ceb35

1 file changed: aten/src/ATen/native/cuda/CUDALoops.cuh (+142 lines, -1 line)
@@ -445,6 +445,44 @@ __global__ void elementwise_kernel(int N, func_t f) {
   }
 }
 
+#ifdef USE_ROCM
+template <int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
+  int tid = threadIdx.x;
+  int nv = nt * vt;
+  int idx = nv * blockIdx.x + tid;
+  if ((idx + nt * (vt - 1)) < N) {
+    f(idx, true);
+  } else {
+#pragma unroll
+    for (int i = 0; i < vt; i++) {
+      if (idx < N) {
+        f(idx, false);
+        idx += nt;
+      }
+    }
+  }
+}
+
+template <int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel_strided(int N, func_t f) {
+  int tid = threadIdx.x;
+  int idx = nt * vt * blockIdx.x + tid;
+  int step = nt * vt * gridDim.x;
+  while (idx < N) {
+#pragma unroll
+    for (int i = 0; i < vt; i++) {
+      if ((idx + nt * i) < N) {
+        f(idx + nt * i);
+      }
+    }
+    idx += step;
+  }
+}
+#endif
+
 template <int nt, int vt, typename func_t>
 static void launch_legacy_kernel(int64_t N, const func_t& f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
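A note on the contract of elementwise_kernel_manual_unroll above: the functor receives a second boolean argument. When it is true, the kernel has already verified that idx, idx + nt, ..., idx + nt*(vt-1) are all in range and the functor is expected to process all vt strided elements in one call; when it is false, it must handle only the single index idx. A condensed, self-contained sketch of a conforming functor follows; the launch-bounds macro is dropped, scale_inplace and op are hypothetical names, and the device lambda needs hipcc or nvcc with extended lambdas.

#include <cuda_runtime.h>

// Condensed copy of the kernel from the hunk above (launch bounds omitted).
template <int nt, int vt, typename func_t>
__global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  if ((idx + nt * (vt - 1)) < N) {
    f(idx, true);              // fast path: all vt strided elements in range
  } else {
#pragma unroll
    for (int i = 0; i < vt; i++) {
      if (idx < N) {
        f(idx, false);         // tail path: one element at a time
        idx += nt;
      }
    }
  }
}

// Hypothetical host wrapper: scale N floats in place.
void scale_inplace(float* d_x, int N, float s) {
  constexpr int nt = 128, vt = 4;
  auto op = [=] __device__ (int idx, bool full_tile) {
    if (full_tile) {
#pragma unroll
      for (int i = 0; i < vt; i++) d_x[idx + nt * i] *= s;  // vt elements per call
    } else {
      d_x[idx] *= s;                                        // single element
    }
  };
  int grid = (N + nt * vt - 1) / (nt * vt);
  elementwise_kernel_manual_unroll<nt, vt><<<grid, nt>>>(N, op);
}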
@@ -458,6 +496,37 @@ static void launch_legacy_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+#ifdef USE_ROCM
+template <int nt, int vt, typename func_t>
+static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+  auto stream = at::cuda::getCurrentCUDAStream();
+  elementwise_kernel_manual_unroll<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <int nt, int vt, typename func_t>
+static void launch_legacy_kernel_strided(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid(8192);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int ub_idx = nt * vt;
+  ub_idx = ub_idx * (grid.x - 1) + (block.x - 1);
+  ub_idx = ub_idx + nt * vt;
+  elementwise_kernel_strided<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+#endif
+
 template <typename traits, typename func_t, typename index_t, size_t... INDEX>
 C10_HOST_DEVICE typename traits::result_type invoke_impl(
     const func_t& f,
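A quick sanity check on the two launch shapes above, with illustrative numbers that are not taken from the patch: launch_legacy_kernel_manual_unroll sizes the grid as the ceiling of N over the nt*vt tile, while launch_legacy_kernel_strided always launches 8192 blocks and relies on the grid-stride loop in elementwise_kernel_strided to cover whatever remains.

#include <cstdio>

int main() {
  // Assumed example values: nt = 128 threads/block, vt = 4 elements/thread.
  const long long N = 1'000'000, nt = 128, vt = 4;

  // launch_legacy_kernel_manual_unroll: one tile of nt*vt elements per block.
  long long tile = nt * vt;                  // 512
  long long grid = (N + tile - 1) / tile;    // ceil(1e6 / 512) = 1954 blocks
  printf("manual unroll: %lld blocks of %lld threads\n", grid, nt);

  // launch_legacy_kernel_strided: fixed 8192 blocks; one grid pass covers
  // 8192 * nt * vt = 4,194,304 indices, so N = 1e6 fits in a single pass,
  // while larger N simply takes more iterations of the while loop.
  long long per_pass = 8192 * tile;
  printf("strided: covers %lld indices per pass\n", per_pass);
  return 0;
}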
@@ -536,12 +605,84 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
     return launch_vectorized_kernel(numel, f, data);
   }
   auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+#ifndef USE_ROCM
   constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
   launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
     auto offsets = offset_calc.get(idx);
     arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
     *out = invoke(f, &data[1], &offsets[1], 1);
   });
+#else
+  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8;
+  constexpr int grp_sz = 128;
+  launch_legacy_kernel_manual_unroll<grp_sz, unroll_factor>(numel, [=] GPU_LAMBDA(int idx, bool unrl4x) {
+    if constexpr (unroll_factor == 4) {
+      if (unrl4x) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+        auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+        auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+        auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+        auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+        *out0 = tmp0;
+        *out1 = tmp1;
+        *out2 = tmp2;
+        *out3 = tmp3;
+      }
+      else {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = invoke(f, &data[1], &offsets[1], 1);
+      }
+    } else {
+      if (unrl4x) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        auto offsets4 = offset_calc.get(idx + grp_sz * 4);
+        auto offsets5 = offset_calc.get(idx + grp_sz * 5);
+        auto offsets6 = offset_calc.get(idx + grp_sz * 6);
+        auto offsets7 = offset_calc.get(idx + grp_sz * 7);
+        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+        arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]);
+        arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]);
+        arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]);
+        arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]);
+        auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+        auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+        auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+        auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+        auto tmp4 = invoke(f, &data[1], &offsets4[1], 1);
+        auto tmp5 = invoke(f, &data[1], &offsets5[1], 1);
+        auto tmp6 = invoke(f, &data[1], &offsets6[1], 1);
+        auto tmp7 = invoke(f, &data[1], &offsets7[1], 1);
+        *out0 = tmp0;
+        *out1 = tmp1;
+        *out2 = tmp2;
+        *out3 = tmp3;
+        *out4 = tmp4;
+        *out5 = tmp5;
+        *out6 = tmp6;
+        *out7 = tmp7;
+      }
+      else {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = invoke(f, &data[1], &offsets[1], 1);
+      }
+    }
+  });
+#endif
 }
 
 #ifdef USE_ROCM
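The unrolled fast paths above follow one structure: compute all unroll_factor offsets and results into registers first, and only then issue the stores, so the loads are not serialized behind earlier writes. A self-contained toy kernel showing the same shape (axpy_unrolled is a hypothetical name, not code from the patch):

#include <cuda_runtime.h>

// Toy kernel mirroring the compute-first, store-last structure of the
// ROCm branch's unrolled lambda.
template <int nt, int vt>
__global__ void axpy_unrolled(int N, float a, const float* x, float* y) {
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  if (idx + nt * (vt - 1) < N) {        // fast path: full tile in range
    float tmp[vt];
#pragma unroll
    for (int i = 0; i < vt; i++)        // all loads and math first
      tmp[i] = a * x[idx + nt * i] + y[idx + nt * i];
#pragma unroll
    for (int i = 0; i < vt; i++)        // all stores last
      y[idx + nt * i] = tmp[i];
  } else {                              // tail: one element at a time
    for (int i = 0; i < vt && idx < N; i++, idx += nt)
      y[idx] = a * x[idx] + y[idx];
  }
}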
@@ -759,7 +900,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       dtypes[i] = iter.dtype(i);
       strides[i] = inner_strides[i];
     }
-    launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) {
+    launch_legacy_kernel_strided<512, 4>(numel, [=]GPU_LAMBDA(int idx) {
       void* out = data[0] + strides[0] * idx;
      arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
      c10::cast_and_store<arg0_t>(dtypes[0], out, result);
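The one-line swap above routes the dynamic-casting fallback through the strided kernel with 512 threads per block and vt = 4. A small host-only replay of that kernel's index arithmetic (illustrative and standalone, with assumed N) shows that every index in [0, N) is visited exactly once:

#include <cassert>
#include <vector>

int main() {
  // Assumed launch shape from the diff: 512 threads/block, vt = 4, 8192 blocks.
  const int nt = 512, vt = 4, blocks = 8192;
  const int N = 3'000'000;                 // arbitrary test size
  std::vector<int> hits(N, 0);

  // Replay the index pattern of elementwise_kernel_strided on the CPU.
  for (int b = 0; b < blocks; b++) {
    for (int t = 0; t < nt; t++) {
      long long idx = (long long)nt * vt * b + t;
      long long step = (long long)nt * vt * blocks;
      while (idx < N) {
        for (int i = 0; i < vt; i++)
          if (idx + (long long)nt * i < N)
            hits[idx + nt * i]++;
        idx += step;
      }
    }
  }
  for (int i = 0; i < N; i++) assert(hits[i] == 1);  // each element exactly once
  return 0;
}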
