
Commit 6686d55

jerrymannil authored and pruthvistony committed
[ROCm] Improvements for vectorized elementwise kernels (pytorch#143269) (#1874)
* Make the io_size calculation the minimum of the input size and output size, rather than the summation of all sizes
  * e.g., for torch.add() on half dtypes (bfloat16/float16), calc_io_size() returns 6, causing elems_per_thread to be 4
  * but elems_per_thread = 8 works better on half dtypes on AMD GPUs
* Enable *_load_dwordx4 ISA for 16-bit and 8-bit dtypes on AMD GPUs by using vector sizes of 8 and 16 respectively

Co-author: @akadutta

Pull Request resolved: pytorch#143269
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony

Co-authored-by: Pruthvi Madugundu <[email protected]>
(cherry picked from commit 4686828)
1 parent c743d68 commit 6686d55
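For context on the io_size change above, a minimal sketch (assumed and simplified; calc_io_size() and elems_per_thread() here are stand-ins for the real PyTorch helpers, and the threshold is illustrative):

// Hypothetical sketch of the heuristic described in the commit message;
// not the actual PyTorch implementation.
#include <algorithm>
#include <cstddef>

// Before: io_size summed all input and output element sizes, so a
// half-precision binary op (2 + 2 + 2 bytes) reported io_size = 6.
// After: io_size is the minimum of input and output element sizes, so
// the same op reports io_size = 2.
constexpr size_t calc_io_size(size_t input_elem_size, size_t output_elem_size) {
  return std::min(input_elem_size, output_elem_size);
}

constexpr int elems_per_thread(size_t io_size) {
  // Illustrative threshold: smaller per-element traffic allows more
  // elements per thread; 8 works better than 4 for half dtypes on AMD GPUs.
  return io_size <= 2 ? 8 : 4;
}

static_assert(elems_per_thread(calc_io_size(2, 2)) == 8, "fp16/bf16 binary op");
static_assert(elems_per_thread(calc_io_size(4, 4)) == 4, "fp32 binary op");

With the old summation, a half-dtype torch.add() reported io_size = 6 and fell back to 4 elements per thread; taking the minimum lets the same op qualify for 8.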

File tree: 3 files changed (+5, −6 lines)


aten/src/ATen/native/cuda/CUDAJitLoops.cuh

Lines changed: 4 additions & 4 deletions
@@ -49,8 +49,8 @@ struct JittedVecKernelCache {
   at::cuda::jit::NvrtcFunction vec1;
   at::cuda::jit::NvrtcFunction vec2;
   at::cuda::jit::NvrtcFunction vec4;
-  at::cuda::jit::NvrtcFunction vec8;
 #ifdef USE_ROCM
+  at::cuda::jit::NvrtcFunction vec8;
   at::cuda::jit::NvrtcFunction vec16;
 #endif

@@ -150,11 +150,11 @@ void launch_jitted_vectorized_kernel(
 #ifdef USE_ROCM
   if (vec_size == 16) {
     fn_ptr = &fn_cache.vec16;
+  } else if (vec_size == 8) {
+    fn_ptr = &fn_cache.vec8;
   } else
 #endif
-  if (vec_size == 8) {
-    fn_ptr = &fn_cache.vec8;
-  } else if (vec_size == 4) {
+  if (vec_size == 4) {
     fn_ptr = &fn_cache.vec4;
   } else if (vec_size == 2) {
     fn_ptr = &fn_cache.vec2;

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 1 addition & 1 deletion
@@ -240,12 +240,12 @@ static inline void launch_vectorized_kernel(
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
-#endif
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
+#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);

aten/src/ATen/native/cuda/jit_utils.h

Lines changed: 0 additions & 1 deletion
@@ -60,7 +60,6 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
   if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
     return 8;
   }
-#else
   if (ip % (8 * default_alignment) == 0) {
     return 8;
   }
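For context, a simplified sketch of the alignment rule this hunk generalizes (assumed; max_vector_width is a hypothetical stand-in, not the actual can_vectorize_up_to): a pointer can be loaded as a vector of V elements only when its address is a multiple of V * sizeof(element), and for 8-bit and 16-bit dtypes a 16- or 8-wide vector keeps each access at 16 bytes, which maps onto the *_load_dwordx4 instructions mentioned in the commit message.

// Simplified, hypothetical stand-in for the alignment check; not the
// PyTorch source.
#include <cstdint>
#include <cstddef>

inline int max_vector_width(size_t elem_size, const void* pointer) {
  auto ip = reinterpret_cast<uintptr_t>(pointer);
  // Target one 16-byte access: 16 x 1-byte or 8 x 2-byte elements map
  // onto a single dwordx4-style load on AMD GPUs.
  if (elem_size == 1 && ip % 16 == 0) return 16;
  if (elem_size == 2 && ip % 16 == 0) return 8;
  if (ip % (8 * elem_size) == 0) return 8;
  if (ip % (4 * elem_size) == 0) return 4;
  if (ip % (2 * elem_size) == 0) return 2;
  return 1;
}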
