
Commit 6424cdb

jerrymannil authored and root committed
[ROCm] Improvements to non-vectorized elementwise kernels
1 parent 4686828 commit 6424cdb

File tree

1 file changed: +142 -1 lines changed


aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 142 additions & 1 deletion
@@ -271,6 +271,44 @@ __global__ void elementwise_kernel(int N, func_t f) {
   }
 }
 
+#ifdef USE_ROCM
+template <int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
+  int tid = threadIdx.x;
+  int nv = nt * vt;
+  int idx = nv * blockIdx.x + tid;
+  if ((idx + nt * (vt - 1)) < N) {
+    f(idx, true);
+  } else {
+    #pragma unroll
+    for (int i = 0; i < vt; i++) {
+      if (idx < N) {
+        f(idx, false);
+        idx += nt;
+      }
+    }
+  }
+}
+
+template <int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel_strided(int N, func_t f) {
+  int tid = threadIdx.x;
+  int idx = nt * vt * blockIdx.x + tid;
+  int step = nt * vt * gridDim.x;
+  while (idx < N) {
+    #pragma unroll
+    for (int i = 0; i < vt; i++) {
+      if ((idx + nt * i) < N) {
+        f(idx + nt * i);
+      }
+    }
+    idx += step;
+  }
+}
+#endif
+
 template <int nt, int vt, typename func_t>
 static void launch_legacy_kernel(int64_t N, const func_t& f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
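The first of the two new kernels, elementwise_kernel_manual_unroll, takes a callable with an extra bool: when every one of a thread's vt strided elements is known to be in range it calls f(idx, true) once and lets the callable handle the whole chunk without per-element bounds checks; otherwise it falls back to a bounds-checked per-element loop with f(idx, false). Below is a minimal self-contained sketch of the same two-path pattern, with the payload inlined instead of passed as a functor; the block size, unroll factor, and add-one payload are illustrative assumptions, not values from the commit.

// Toy version of the manual-unroll pattern above (illustrative only).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kThreads = 128;  // nt
constexpr int kUnroll  = 4;    // vt

__global__ void add_one_manual_unroll(int n, float* x) {
  int idx = kThreads * kUnroll * blockIdx.x + threadIdx.x;
  if ((idx + kThreads * (kUnroll - 1)) < n) {
    // Fast path: all kUnroll strided elements are in range, so no per-element
    // bounds checks (the real kernel delegates this to a single f(idx, true) call).
    #pragma unroll
    for (int i = 0; i < kUnroll; i++) {
      x[idx + kThreads * i] += 1.0f;
    }
  } else {
    // Tail path: only the last, partially filled chunk pays per-element checks.
    #pragma unroll
    for (int i = 0; i < kUnroll; i++) {
      if (idx < n) {
        x[idx] += 1.0f;
        idx += kThreads;
      }
    }
  }
}

int main() {
  const int n = 1000;
  float* x;
  cudaMallocManaged(&x, n * sizeof(float));
  for (int i = 0; i < n; i++) x[i] = 0.0f;
  int grid = (n + kThreads * kUnroll - 1) / (kThreads * kUnroll);  // same ceil-div as the launcher below
  add_one_manual_unroll<<<grid, kThreads>>>(n, x);
  cudaDeviceSynchronize();
  printf("x[0]=%g x[%d]=%g\n", x[0], n - 1, x[n - 1]);  // both 1
  cudaFree(x);
  return 0;
}

Built as a .cu with nvcc (or the HIP equivalent on ROCm), each element is touched exactly once. elementwise_kernel_strided applies the same per-thread unroll inside a grid-stride loop so a fixed-size grid can cover any N.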
@@ -284,6 +322,37 @@ static void launch_legacy_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+#ifdef USE_ROCM
+template <int nt, int vt, typename func_t>
+static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+  auto stream = at::cuda::getCurrentCUDAStream();
+  elementwise_kernel_manual_unroll<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <int nt, int vt, typename func_t>
+static void launch_legacy_kernel_strided(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid(8192);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int ub_idx = nt * vt;
+  ub_idx = ub_idx * (grid.x - 1) + (block.x - 1);
+  ub_idx = ub_idx + nt * vt;
+  elementwise_kernel_strided<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+#endif
+
 template <typename traits, typename func_t, typename index_t, size_t... INDEX>
 C10_HOST_DEVICE typename traits::result_type invoke_impl(
     const func_t& f,
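The two launchers size their grids differently: launch_legacy_kernel_manual_unroll uses one block per nt * vt elements (ceil division), while launch_legacy_kernel_strided always launches 8192 blocks and relies on the kernel's grid-stride loop to cover whatever remains. A small host-side sketch of that arithmetic follows; N and the template arguments are illustrative values, not anything prescribed by the commit.

// Grid-size arithmetic for the two ROCm launchers (illustrative values).
#include <cstdio>

int main() {
  const long long N = 1000000;        // example element count
  const long long nt = 128, vt = 4;   // e.g. launch_legacy_kernel_manual_unroll<128, 4>

  // Manual-unroll launcher: one block per nt * vt elements.
  long long grid_unroll = (N + nt * vt - 1) / (nt * vt);  // ceil(1e6 / 512) = 1954 blocks

  // Strided launcher: fixed 8192-block grid; each pass of the kernel's while
  // loop covers nt * vt * 8192 elements.
  const long long nt_s = 512, vt_s = 4;                   // e.g. launch_legacy_kernel_strided<512, 4>
  long long per_sweep = nt_s * vt_s * 8192;               // 16,777,216 elements per sweep
  long long sweeps = (N + per_sweep - 1) / per_sweep;     // 1 sweep for N = 1e6

  printf("manual-unroll grid: %lld blocks\n", grid_unroll);
  printf("strided grid: 8192 blocks, %lld elements/sweep, %lld sweep(s)\n", per_sweep, sweeps);
  return 0;
}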
@@ -362,12 +431,84 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
     return launch_vectorized_kernel(numel, f, data);
   }
   auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+#ifndef USE_ROCM
   constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
   launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
     auto offsets = offset_calc.get(idx);
     arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
     *out = invoke(f, &data[1], &offsets[1], 1);
   });
+#else
+  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8;
+  constexpr int grp_sz = 128;
+  launch_legacy_kernel_manual_unroll<grp_sz, unroll_factor>(numel, [=] GPU_LAMBDA(int idx, bool unrl4x) {
+    if constexpr (unroll_factor == 4) {
+      if (unrl4x) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+        auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+        auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+        auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+        auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+        *out0 = tmp0;
+        *out1 = tmp1;
+        *out2 = tmp2;
+        *out3 = tmp3;
+      }
+      else {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = invoke(f, &data[1], &offsets[1], 1);
+      }
+    } else {
+      if (unrl4x) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        auto offsets4 = offset_calc.get(idx + grp_sz * 4);
+        auto offsets5 = offset_calc.get(idx + grp_sz * 5);
+        auto offsets6 = offset_calc.get(idx + grp_sz * 6);
+        auto offsets7 = offset_calc.get(idx + grp_sz * 7);
+        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
+        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
+        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
+        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
+        arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]);
+        arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]);
+        arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]);
+        arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]);
+        auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
+        auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
+        auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
+        auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
+        auto tmp4 = invoke(f, &data[1], &offsets4[1], 1);
+        auto tmp5 = invoke(f, &data[1], &offsets5[1], 1);
+        auto tmp6 = invoke(f, &data[1], &offsets6[1], 1);
+        auto tmp7 = invoke(f, &data[1], &offsets7[1], 1);
+        *out0 = tmp0;
+        *out1 = tmp1;
+        *out2 = tmp2;
+        *out3 = tmp3;
+        *out4 = tmp4;
+        *out5 = tmp5;
+        *out6 = tmp6;
+        *out7 = tmp7;
+      }
+      else {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = invoke(f, &data[1], &offsets[1], 1);
+      }
+    }
+  });
+#endif
 }
 
 template <typename func_t>
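In gpu_kernel_impl_nocast the ROCm branch doubles the unroll factor relative to the CUDA branch (4 instead of 2 for outputs of at least 4 bytes, 8 instead of 4 for narrower types) and gives the lambda a second bool parameter so the kernel can tell it whether all unrolled elements are in range; on the fast path the lambda computes every offset and temporary before writing any output. Below is a compile-time sketch of just the sizeof rule, using a hypothetical helper name and illustrative types.

// Unroll factor chosen by the ROCm branch, per output type (hypothetical helper).
#include <cstdint>

template <typename arg0_t>
constexpr int rocm_unroll_factor() {
  return sizeof(arg0_t) >= 4 ? 4 : 8;  // mirrors: sizeof(arg0_t) >= 4 ? 4 : 8
}

static_assert(rocm_unroll_factor<double>() == 4, "8-byte outputs unroll 4x");
static_assert(rocm_unroll_factor<float>() == 4, "4-byte outputs unroll 4x");
static_assert(rocm_unroll_factor<int16_t>() == 8, "2-byte outputs (half/bfloat16-sized) unroll 8x");
static_assert(rocm_unroll_factor<int8_t>() == 8, "1-byte outputs unroll 8x");

int main() { return 0; }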
@@ -401,7 +542,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
     dtypes[i] = iter.dtype(i);
     strides[i] = inner_strides[i];
   }
-  launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) {
+  launch_legacy_kernel_strided<512, 4>(numel, [=]GPU_LAMBDA(int idx) {
     void* out = data[0] + strides[0] * idx;
     arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
     c10::cast_and_store<arg0_t>(dtypes[0], out, result);
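The final hunk switches this casting path in gpu_kernel_impl from launch_legacy_kernel<512, 1> to launch_legacy_kernel_strided<512, 4>: a fixed-size grid whose threads loop over the data in grid-sized strides, unrolled 4x per step. A self-contained toy with the same shape as elementwise_kernel_strided follows; the 8192-block grid matches the launcher above, while the scale-by-2 payload and element count are illustrative assumptions.

// Toy grid-stride kernel mirroring elementwise_kernel_strided (illustrative only).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int nt = 512;  // block size, as in launch_legacy_kernel_strided<512, 4>
constexpr int vt = 4;    // per-thread unroll within each stride step

__global__ void scale_strided(int N, float* x) {
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  int step = nt * vt * gridDim.x;          // whole-grid stride
  while (idx < N) {
    #pragma unroll
    for (int i = 0; i < vt; i++) {
      if ((idx + nt * i) < N) {
        x[idx + nt * i] *= 2.0f;
      }
    }
    idx += step;                           // next chunk handled by this thread
  }
}

int main() {
  const int N = 10000000;
  float* x;
  cudaMallocManaged(&x, N * sizeof(float));
  for (int i = 0; i < N; i++) x[i] = 1.0f;
  scale_strided<<<8192, nt>>>(N, x);       // fixed grid; the while loop covers any N
  cudaDeviceSynchronize();
  printf("x[0]=%g x[%d]=%g\n", x[0], N - 1, x[N - 1]);  // both 2
  cudaFree(x);
  return 0;
}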
