
Commit 4686828

[ROCm] Improvements for vectorized elementwise kernels (pytorch#143269) (#1874)
* Compute io_size as the minimum of the input size and the output size, rather than the sum of all sizes
* e.g., for torch.add() on half dtypes (bfloat16/float16), calc_io_size() returns 6, causing elems_per_thread to be 4
* But elems_per_thread = 8 works better on half dtypes for AMD GPUs
* Enable *_load_dwordx4 ISA for 16-bit and 8-bit dtypes on AMD GPUs by using vector sizes of 8 and 16 respectively

Co-author: @akadutta

Pull Request resolved: pytorch#143269
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony

Co-authored-by: Pruthvi Madugundu <[email protected]>
1 parent f7ad58f commit 4686828
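As a sanity check on the commit-message arithmetic, here is a minimal host-only sketch. The helper names (io_size_sum, io_size_min, elems_per_thread) and the io_size-to-elements mapping are illustrative stand-ins, not the actual PyTorch helpers; they only mirror the numbers quoted above (sum-based io_size of 6 giving 4 elements per thread, min-based io_size of 2 giving 8).

// Illustrative sketch of the io_size change described in the commit message.
// For torch.add on half dtypes, both inputs and the output are 2 bytes each.
#include <algorithm>
#include <cstdio>

// Old behaviour (as described): io_size = sum of all input and output sizes.
constexpr int io_size_sum(int in0, int in1, int out) { return in0 + in1 + out; }

// New ROCm behaviour (as described): io_size = min(first input size, output size).
constexpr int io_size_min(int in0, int /*in1*/, int out) { return std::min(in0, out); }

// Hypothetical mapping from io_size to elements per thread, mirroring the idea
// that a smaller io_size lets each thread process more elements.
constexpr int elems_per_thread(int io_size) { return io_size <= 2 ? 8 : 4; }

int main() {
  constexpr int half = 2;  // sizeof(at::Half) / sizeof(at::BFloat16)
  std::printf("sum-based io_size = %d -> elems/thread = %d\n",
              io_size_sum(half, half, half),
              elems_per_thread(io_size_sum(half, half, half)));
  std::printf("min-based io_size = %d -> elems/thread = %d\n",
              io_size_min(half, half, half),
              elems_per_thread(io_size_min(half, half, half)));
}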

File tree: 7 files changed, +255 -36 lines changed


aten/src/ATen/cuda/jiterator.cu

Lines changed: 18 additions & 8 deletions

@@ -16,20 +16,25 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
     DeviceIndex dev_idx, int64_t N, const std::string& f, const void* data_ptr,
     const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+
+  int nInputs = iter.ninputs();
+  int nOutputs = iter.noutputs();
+  const at::ScalarType common_dtype = iter.common_dtype();
+
+  int tws = at::cuda::jit::calc_thread_work_size(nInputs, nOutputs, common_dtype, common_dtype);
+  int vec_size = jitted_can_vectorize_up_to(iter);
+
+  int bws = tws * num_threads();
   // N is still int64_t for the computation, but it's always safe to cast result to int
-  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
+  const uint32_t grid = (N + bws - 1) / bws;

-  const int vec_size = jitted_can_vectorize_up_to(iter);
   bool vectorized = vec_size > 1;

   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
   // fn_ptr is set to the appropriate function based on the vec size and GPU used
   // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
   // the same compute capability

-  int nInputs = iter.ninputs();
-  int nOutputs = iter.noutputs();
-  const at::ScalarType common_dtype = iter.common_dtype();
   std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
   std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
   std::string result_type_str = at::cuda::jit::typeName(common_dtype);
@@ -59,6 +64,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
       /*contiguous=*/true, /*dynamic_casting=*/false,
       at::cuda::jit::BinaryFuncVariant::NoScalar,
       extra_args_types,
+      tws,
       vectorized, vec_size,
       return_by_ref);
   std::string kernel_name = vectorized ? name + "_vectorized" + std::to_string(vec_size) : name;
@@ -121,12 +127,16 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
     const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {

   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
-  //casting result to int is always safe, intermediate is int64 and won't overflow
-  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

   int nInputs = iter.ninputs();
   int nOutputs = iter.noutputs();
   const at::ScalarType common_dtype = iter.common_dtype();
+
+  int tws = at::cuda::jit::calc_thread_work_size(nInputs, nOutputs, common_dtype, common_dtype);
+  int bws = tws * num_threads();
+  //casting result to int is always safe, intermediate is int64 and won't overflow
+  const uint32_t grid = (N + bws - 1) / bws;
+
   std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
   std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
   std::string result_type_str = at::cuda::jit::typeName(common_dtype);
@@ -153,7 +163,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
       f_inputs_type_str, compute_type_str, result_type_str,
       contiguous, dynamic_casting,
       at::cuda::jit::BinaryFuncVariant::NoScalar,
-      extra_args_types, /*vectorized*/false, /*vec_size*/0, return_by_ref);
+      extra_args_types, tws, /*vectorized*/false, /*vec_size*/0, return_by_ref);
   *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
 }
 }
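The hunks above replace the fixed block_work_size() with a per-dtype block work size (tws * num_threads()). A minimal host-only model of that grid arithmetic, assuming 128 threads per block (the real value comes from num_threads() and calc_thread_work_size()):

// Host-only model of the launch-grid arithmetic used above (assumed 128 threads/block).
#include <cstdint>
#include <cstdio>

constexpr int num_threads_model = 128;

// Ceil-divide N elements by the block work size (threads * elements per thread).
constexpr uint32_t grid_size(int64_t N, int thread_work_size) {
  const int64_t bws = static_cast<int64_t>(thread_work_size) * num_threads_model;
  return static_cast<uint32_t>((N + bws - 1) / bws);
}

int main() {
  const int64_t N = 1 << 20;
  std::printf("tws=4 -> grid=%u\n", grid_size(N, 4));  // previous fixed work size
  std::printf("tws=8 -> grid=%u\n", grid_size(N, 8));  // larger per-thread work on half dtypes
}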

aten/src/ATen/native/cuda/CUDAJitLoops.cuh

Lines changed: 25 additions & 5 deletions

@@ -50,6 +50,11 @@ struct JittedVecKernelCache {
   at::cuda::jit::NvrtcFunction vec1;
   at::cuda::jit::NvrtcFunction vec2;
   at::cuda::jit::NvrtcFunction vec4;
+#ifdef USE_ROCM
+  at::cuda::jit::NvrtcFunction vec8;
+  at::cuda::jit::NvrtcFunction vec16;
+#endif
+
 };

 struct JittedKernelVariantCache {
@@ -89,16 +94,19 @@ void launch_jitted_unrolled_kernel(
     c10::ArrayRef<const void*> extra_args) {

   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+
+  int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
+  int bws = tws * num_threads();
   //casting result to int is always safe, intermediate is int64 and won't overflow
-  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
+  const uint32_t grid = (N + bws - 1) / bws;

   if (!fn_cache.function) {
     const std::lock_guard<std::mutex> lock{jiterator_mutex};
     if (!fn_cache.function) {
       constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
                                        !std::is_same<decltype(s), memory::StoreWithoutCast>();
       auto code = at::cuda::jit::generate_code(
-          desc, contiguous, dynamic_casting, scalar_pos);
+          desc, contiguous, dynamic_casting, scalar_pos, tws);
       fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
     }
   }
@@ -115,14 +123,26 @@ void launch_jitted_vectorized_kernel(
     at::cuda::jit::BinaryFuncVariant scalar_pos,
     const void *scalar_val, c10::ArrayRef<const void*> extra_args) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+
+  int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
+  int bws = tws * num_threads();
   // N is still int64_t for the computation, but it's always safe to cast result to int
-  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
-  const int vec_size = at::cuda::jit::can_vectorize_up_to(
+  const uint32_t grid = (N + bws - 1) / bws;
+
+  int vec_size = at::cuda::jit::can_vectorize_up_to(
       desc, c10::ArrayRef<char*>(data.data(), data.size()));

   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
   // fn_ptr is set to the appropriate function based on the vec size and GPU used
   at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;
+
+#ifdef USE_ROCM
+  if (vec_size == 16) {
+    fn_ptr = &fn_cache.vec16;
+  } else if (vec_size == 8) {
+    fn_ptr = &fn_cache.vec8;
+  } else
+#endif
   if (vec_size == 4) {
     fn_ptr = &fn_cache.vec4;
   } else if (vec_size == 2) {
@@ -142,7 +162,7 @@ void launch_jitted_vectorized_kernel(
     // Generates program
     auto code = at::cuda::jit::generate_code(
         desc, /*contiguous=*/true, /*dynamic_casting=*/false,
-        scalar_pos, vectorized, vec_size);
+        scalar_pos, tws, vectorized, vec_size);
     std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;

     // Acquires the program
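The cache-slot selection above gains ROCm-only vec8/vec16 entries. A standalone sketch of that pattern, with a placeholder FakeFunction type standing in for at::cuda::jit::NvrtcFunction:

// Sketch of the vec-size -> cache-slot dispatch used above; only the control
// flow is modeled, the types are simplified placeholders.
#include <cassert>

struct FakeFunction { void* handle = nullptr; };

struct VecKernelCache {
  FakeFunction vec1, vec2, vec4;
#ifdef USE_ROCM
  FakeFunction vec8, vec16;
#endif
};

FakeFunction* select_slot(VecKernelCache& cache, int vec_size) {
#ifdef USE_ROCM
  if (vec_size == 16) return &cache.vec16;
  if (vec_size == 8)  return &cache.vec8;
#endif
  if (vec_size == 4)  return &cache.vec4;
  if (vec_size == 2)  return &cache.vec2;
  return &cache.vec1;
}

int main() {
  VecKernelCache cache;
  assert(select_slot(cache, 4) == &cache.vec4);  // common path on all GPUs
  assert(select_slot(cache, 1) == &cache.vec1);  // unaligned fallback
}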

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 54 additions & 1 deletion

@@ -78,13 +78,49 @@ constexpr auto io_block_work_size() {
   return num_threads() * elems_per_thread<io_sizes>();
 }

+#ifdef USE_ROCM
+template <typename args_t, size_t... Is>
+constexpr auto input_size(args_t args, std::index_sequence<Is...>) {
+  if constexpr (sizeof...(Is) == 0) {
+    return 0;
+  } else {
+    return sizeof(std::tuple_element_t<0, args_t>);
+  }
+}
+
+template <int vec_size, int io_size>
+constexpr auto calc_optimal_vec_size() {
+  static_assert(vec_size != 0);
+  static_assert(io_size != 0);
+  if constexpr (io_size == 1 && vec_size >= 16) {
+    return 16;
+  } else if constexpr (io_size <= 2 && vec_size >= 8) {
+    return 8;
+  } else if constexpr (io_size <= 4 && vec_size >= 4) {
+    return 4;
+  } else if constexpr (vec_size >= 4) {
+    return 4;
+  } else if constexpr (vec_size >= 2) {
+    return 2;
+  } else {
+    return 1;
+  }
+}
+#endif
+
 template <typename func_t>
 constexpr auto calc_io_size(){
   using traits = function_traits<func_t>;
   using args_t = typename traits::ArgsTuple;
+#ifdef USE_ROCM
+  constexpr auto input_size = at::native::input_size(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
+  constexpr auto output_size = sizeof(typename traits::result_type);
+  return (input_size > 0) ? ((input_size < output_size) ? input_size : output_size) : output_size;
+#else
   constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
   constexpr auto output_size = sizeof(typename traits::result_type);
   return input_size + output_size;
+#endif
 }

 template <int vec_size, typename func_t, typename array_t>
@@ -111,8 +147,13 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
+#ifdef USE_ROCM
+    constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
+#else
+    constexpr auto optimal_vec_size = vec_size;
+#endif
     elementwise_kernel_helper(
-        f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+        f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
   }
 }

@@ -154,6 +195,18 @@ static inline void launch_vectorized_kernel(
   int vec_size = memory::can_vectorize_up_to<func_t>(data);

   switch (vec_size) {
+#ifdef USE_ROCM
+    case 16:
+      vectorized_elementwise_kernel<16, func_t, array_t>
+          <<<grid, num_threads(), 0, stream>>>(N, f, data);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 8:
+      vectorized_elementwise_kernel<8, func_t, array_t>
+          <<<grid, num_threads(), 0, stream>>>(N, f, data);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
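To see why the new vector sizes reach the 16-byte loads the commit message associates with *_load_dwordx4, the sketch below copies the calc_optimal_vec_size mapping from the hunk above into a standalone host program and prints the resulting load width per dtype. The two branches that both return 4 in the original are collapsed here since they are equivalent; everything else is assumption-free arithmetic.

// Standalone copy of the calc_optimal_vec_size mapping from the hunk above,
// used only to show the resulting per-thread load width in bytes.
#include <cstdio>

template <int vec_size, int io_size>
constexpr int calc_optimal_vec_size() {
  static_assert(vec_size != 0);
  static_assert(io_size != 0);
  if constexpr (io_size == 1 && vec_size >= 16) {
    return 16;
  } else if constexpr (io_size <= 2 && vec_size >= 8) {
    return 8;
  } else if constexpr (vec_size >= 4) {  // collapses the two 4-wide branches
    return 4;
  } else if constexpr (vec_size >= 2) {
    return 2;
  } else {
    return 1;
  }
}

int main() {
  // 2-byte dtype (half/bfloat16), hardware allows vec_size 8: 8 * 2 = 16-byte loads.
  std::printf("half:  vec=%d -> %d bytes/load\n",
              calc_optimal_vec_size<8, 2>(), calc_optimal_vec_size<8, 2>() * 2);
  // 1-byte dtype (int8/uint8), hardware allows vec_size 16: 16 * 1 = 16-byte loads.
  std::printf("int8:  vec=%d -> %d bytes/load\n",
              calc_optimal_vec_size<16, 1>(), calc_optimal_vec_size<16, 1>() * 1);
  // 4-byte dtype stays at vec 4 -> 16-byte loads, as before.
  std::printf("float: vec=%d -> %d bytes/load\n",
              calc_optimal_vec_size<4, 4>(), calc_optimal_vec_size<4, 4>() * 4);
}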

aten/src/ATen/native/cuda/Dropout.cu

Lines changed: 70 additions & 16 deletions

@@ -50,9 +50,6 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
                          at::cuda::detail::TensorInfo<mask_t, IndexType> c,
                          IndexType totalElements, accscalar_t p,
                          PhiloxCudaState philox_args) {
-  // make sure we don't break assumption that we can't have > 4 elements / thread
-  static_assert(VEC <= 4, "Value of VEC must be in [2, 4]");
-
   using LoadT = memory::aligned_vector<scalar_t, VEC>;
   using MaskLoadT = memory::aligned_vector<mask_t, VEC>;

@@ -66,7 +63,8 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
   bool gridxvec_loop_state = 0;
   accscalar_t scale = 1.0 / p;

-  float4 rand;
+  constexpr int RAND_SIZE = (VEC + 4 - 1) / 4;
+  float4 rand[RAND_SIZE];

   // Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time
   for (IndexType linearIndex = idx * VEC;
@@ -80,20 +78,31 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
       //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
       // Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4)
       // sets of rand.
-      if ((VEC == 4) || (gridxvec_loop_state == 0)) {
-        rand = curand_uniform4(&state);
+      if ((VEC >= 4) || (gridxvec_loop_state == 0)) {
+        #pragma unroll
+        for (int ii = 0; ii < RAND_SIZE; ii++) {
+          rand[ii] = curand_uniform4(&state);
+        }
       } else {
         // sets up the last two values we generated last iteration to be used this iteration.
-        rand.x = rand.z;
-        rand.y = rand.w;
+        rand[0].x = rand[0].z;
+        rand[0].y = rand[0].w;
         gridxvec_loop_state ^= 1;
       }

-      rand.x = rand.x < p;
-      rand.y = rand.y < p;
-      if (VEC == 4) {
-        rand.z = rand.z < p;
-        rand.w = rand.w < p;
+      rand[0].x = rand[0].x < p;
+      rand[0].y = rand[0].y < p;
+      if constexpr (VEC >= 4) {
+        rand[0].z = rand[0].z < p;
+        rand[0].w = rand[0].w < p;
+      }
+
+      #pragma unroll
+      for (int ii = 1; ii < RAND_SIZE; ii++) {
+        rand[ii].x = rand[ii].x < p;
+        rand[ii].y = rand[ii].y < p;
+        rand[ii].z = rand[ii].z < p;
+        rand[ii].w = rand[ii].w < p;
       }

       // Note: We explicitly check for is_contiguous() before launching the vectorized kernel
@@ -107,10 +116,14 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>

       // Perform the actual computation
       #pragma unroll
-      for (int ii = 0; ii < VEC; ii++) {
-        r[ii] = src[ii]*(&rand.x)[ii]*scale;
-        mask[ii] = (mask_t)(&rand.x)[ii];
+      for (int jj = 0; jj < RAND_SIZE; jj++) {
+        #pragma unroll
+        for (int ii = 0; ii < std::min(VEC, 4); ii++) {
+          r[jj * 4 + ii] = src[jj * 4 + ii]*(&rand[jj].x)[ii]*scale;
+          mask[jj * 4 + ii] = (mask_t)(&rand[jj].x)[ii];
+        }
       }
+
       // Vectorized writes for both mask & result
       *(reinterpret_cast<LoadT*>(&b.data[linearIndex])) = *reinterpret_cast<LoadT*>(&r[0]);
       *(reinterpret_cast<MaskLoadT*>(&c.data[linearIndex])) = *reinterpret_cast<MaskLoadT*>(&mask[0]);
@@ -200,6 +213,13 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
     vec_size = 1;
   } else {
     vec_size = memory::can_vectorize_up_to<scalar_t>((const char*)self.const_data_ptr());
+#ifdef USE_ROCM
+    // make sure we don't break assumption that we can't have > 16 elements / thread
+    TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
+#else
+    // make sure we don't break assumption that we can't have > 4 elements / thread
+    TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
+#endif
   }

   // check that we'd have no remainders - prefer a smaller vector size with no remainders over a larger vector and remainder.
@@ -244,6 +264,38 @@ inline void launcher(

   if (vec_size > 1) {
     switch (vec_size) {
+      case 16:
+        fused_dropout_kernel_vec<
+            scalar_t,
+            accscalar_t,
+            index_type,
+            1,
+            16>
+            <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                self_info,
+                ret_info,
+                mask_info,
+                nelem,
+                pa,
+                rng_engine_inputs);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        break;
+      case 8:
+        fused_dropout_kernel_vec<
+            scalar_t,
+            accscalar_t,
+            index_type,
+            1,
+            8>
+            <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                self_info,
+                ret_info,
+                mask_info,
+                nelem,
+                pa,
+                rng_engine_inputs);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        break;
       case 4:
         fused_dropout_kernel_vec<
             scalar_t,
@@ -276,6 +328,8 @@ inline void launcher(
                 rng_engine_inputs);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
         break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
     }
   } else {
     switch (self_info.dims) {
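The dropout change above hinges on the RAND_SIZE bookkeeping: each curand_uniform4 call yields four floats, so a thread handling VEC elements needs ceil(VEC / 4) float4 packets. A host-side check of that arithmetic (illustrative only; no CUDA RNG is called here):

// Verifies the RAND_SIZE = ceil(VEC / 4) relationship used in the kernel above.
#include <cstdio>
#include <initializer_list>

constexpr int rand_size(int vec) { return (vec + 4 - 1) / 4; }

int main() {
  for (int vec : {2, 4, 8, 16}) {
    std::printf("VEC=%2d -> %d curand_uniform4 call(s), %2d random floats\n",
                vec, rand_size(vec), rand_size(vec) * 4);
  }
}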

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 10 additions & 0 deletions

@@ -351,6 +351,16 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
+#ifdef USE_ROCM
+  constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
+  constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
+  constexpr int type_size = sizeof(scalar_t);
+  if (type_size == 1 && (address % vec16_alignment == 0)) {
+    return 16;
+  } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
+    return 8;
+  } else
+#endif
   if (address % vec4_alignment == 0) {
     return 4;
   } else if (address % vec2_alignment == 0) {
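A host-only model of the alignment check above: the widest vector is chosen from the pointer's alignment, with the ROCm-style 8/16-element cases gated on the element size so loads stay at or below 16 bytes. Here sizeof(T) * N stands in for the alignment_of_v<aligned_vector<T, N>> used in the real code, which is an assumption of this sketch.

// Model of can_vectorize_up_to: pick the widest vector the pointer's alignment allows.
#include <cstdint>
#include <cstdio>

template <typename scalar_t>
int can_vectorize_up_to_model(const char* pointer, bool rocm_like) {
  const auto address = reinterpret_cast<uintptr_t>(pointer);
  constexpr int type_size = sizeof(scalar_t);
  if (rocm_like) {
    if (type_size == 1 && address % (16 * type_size) == 0) return 16;  // 16-byte loads for 1-byte dtypes
    if (type_size <= 2 && address % (8 * type_size) == 0) return 8;    // 16-byte loads for 2-byte dtypes
  }
  if (address % (4 * type_size) == 0) return 4;
  if (address % (2 * type_size) == 0) return 2;
  return 1;
}

int main() {
  alignas(16) char buf[32];
  std::printf("int8  @ 16B-aligned: %d\n", can_vectorize_up_to_model<int8_t>(buf, true));
  std::printf("int16 @ 16B-aligned: %d\n", can_vectorize_up_to_model<int16_t>(buf, true));
  std::printf("float @ 16B-aligned: %d\n", can_vectorize_up_to_model<float>(buf, false));
}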
