[rocm6.4_internal_testing] [ROCm] Improvements for vectorized elementwise kernels (#143269) #1874

Merged
merged 1 commit on Jan 31, 2025
26 changes: 18 additions & 8 deletions aten/src/ATen/cuda/jiterator.cu
@@ -16,20 +16,25 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
DeviceIndex dev_idx, int64_t N, const std::string& f, const void* data_ptr,
const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

int nInputs = iter.ninputs();
int nOutputs = iter.noutputs();
const at::ScalarType common_dtype = iter.common_dtype();

int tws = at::cuda::jit::calc_thread_work_size(nInputs, nOutputs, common_dtype, common_dtype);
int vec_size = jitted_can_vectorize_up_to(iter);

int bws = tws * num_threads();
// N is still int64_t for the computation, but it's always safe to cast result to int
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
const uint32_t grid = (N + bws - 1) / bws;

const int vec_size = jitted_can_vectorize_up_to(iter);
bool vectorized = vec_size > 1;

// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
// fn_ptr is set to the appropriate function based on the vec size and GPU used
// TODO: Memory use can probably be optimized by re-using kernels across GPUs with
// the same compute capability

int nInputs = iter.ninputs();
int nOutputs = iter.noutputs();
const at::ScalarType common_dtype = iter.common_dtype();
std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
std::string result_type_str = at::cuda::jit::typeName(common_dtype);
@@ -59,6 +64,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
/*contiguous=*/true, /*dynamic_casting=*/false,
at::cuda::jit::BinaryFuncVariant::NoScalar,
extra_args_types,
tws,
vectorized, vec_size,
return_by_ref);
std::string kernel_name = vectorized ? name + "_vectorized" + std::to_string(vec_size) : name;
@@ -121,12 +127,16 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
const c10::SmallVector<at::Scalar>& extra_args, bool return_by_ref) {

TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
//casting result to int is always safe, intermediate is int64 and won't overflow
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

int nInputs = iter.ninputs();
int nOutputs = iter.noutputs();
const at::ScalarType common_dtype = iter.common_dtype();

int tws = at::cuda::jit::calc_thread_work_size(nInputs, nOutputs, common_dtype, common_dtype);
int bws = tws * num_threads();
//casting result to int is always safe, intermediate is int64 and won't overflow
const uint32_t grid = (N + bws - 1) / bws;

std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
std::string compute_type_str = at::cuda::jit::typeName(toOpMathType(common_dtype));
std::string result_type_str = at::cuda::jit::typeName(common_dtype);
@@ -153,7 +163,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
f_inputs_type_str, compute_type_str, result_type_str,
contiguous, dynamic_casting,
at::cuda::jit::BinaryFuncVariant::NoScalar,
extra_args_types, /*vectorized*/false, /*vec_size*/0, return_by_ref);
extra_args_types, tws, /*vectorized*/false, /*vec_size*/0, return_by_ref);
*fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
}
}
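
For reference, a minimal host-side sketch of the launch geometry these hunks switch to: the grid is now derived from a per-launch block work size bws = tws * num_threads() instead of the fixed block_work_size(). kBlockThreads and the example numbers below are illustrative assumptions; the real values come from num_threads() and at::cuda::jit::calc_thread_work_size().

#include <cassert>
#include <cstdint>
#include <cstdio>

// Illustrative thread count per block (assumption; the launchers query num_threads()).
constexpr int kBlockThreads = 128;

// With tws elements per thread, one block covers bws = tws * kBlockThreads elements,
// and the grid is ceil(N / bws) -- the expression the patch substitutes for
// (N + block_work_size() - 1) / block_work_size().
uint32_t grid_for(int64_t n, int tws) {
  assert(n > 0 && tws > 0);
  const int64_t bws = static_cast<int64_t>(tws) * kBlockThreads;
  return static_cast<uint32_t>((n + bws - 1) / bws);
}

int main() {
  std::printf("%u\n", grid_for(1000000, 4));  // bws = 512 -> 1954 blocks
}
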
30 changes: 25 additions & 5 deletions aten/src/ATen/native/cuda/CUDAJitLoops.cuh
@@ -50,6 +50,11 @@ struct JittedVecKernelCache {
at::cuda::jit::NvrtcFunction vec1;
at::cuda::jit::NvrtcFunction vec2;
at::cuda::jit::NvrtcFunction vec4;
#ifdef USE_ROCM
at::cuda::jit::NvrtcFunction vec8;
at::cuda::jit::NvrtcFunction vec16;
#endif

};

struct JittedKernelVariantCache {
@@ -89,16 +94,19 @@ void launch_jitted_unrolled_kernel(
c10::ArrayRef<const void*> extra_args) {

TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
int bws = tws * num_threads();
//casting result to int is always safe, intermediate is int64 and won't overflow
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
const uint32_t grid = (N + bws - 1) / bws;

if (!fn_cache.function) {
const std::lock_guard<std::mutex> lock{jiterator_mutex};
if (!fn_cache.function) {
constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
!std::is_same<decltype(s), memory::StoreWithoutCast>();
auto code = at::cuda::jit::generate_code(
desc, contiguous, dynamic_casting, scalar_pos);
desc, contiguous, dynamic_casting, scalar_pos, tws);
fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
}
}
@@ -115,14 +123,26 @@ void launch_jitted_vectorized_kernel(
at::cuda::jit::BinaryFuncVariant scalar_pos,
const void *scalar_val, c10::ArrayRef<const void*> extra_args) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
int bws = tws * num_threads();
// N is still int64_t for the computation, but it's always safe to cast result to int
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
const int vec_size = at::cuda::jit::can_vectorize_up_to(
const uint32_t grid = (N + bws - 1) / bws;

int vec_size = at::cuda::jit::can_vectorize_up_to(
desc, c10::ArrayRef<char*>(data.data(), data.size()));

// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
// fn_ptr is set to the appropriate function based on the vec size and GPU used
at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;

#ifdef USE_ROCM
if (vec_size == 16) {
fn_ptr = &fn_cache.vec16;
} else if (vec_size == 8) {
fn_ptr = &fn_cache.vec8;
} else
#endif
if (vec_size == 4) {
fn_ptr = &fn_cache.vec4;
} else if (vec_size == 2) {
@@ -142,7 +162,7 @@ void launch_jitted_vectorized_kernel(
// Generates program
auto code = at::cuda::jit::generate_code(
desc, /*contiguous=*/true, /*dynamic_casting=*/false,
scalar_pos, vectorized, vec_size);
scalar_pos, tws, vectorized, vec_size);
std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;

// Acquires the program
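
The slot selection that launch_jitted_vectorized_kernel performs on the enlarged cache can be summarized by the sketch below; NvrtcFunction is replaced by a stand-in struct for illustration, and only the ROCm-guarded vec8/vec16 branches are new in this PR.

// Stand-in for at::cuda::jit::NvrtcFunction (illustration only).
struct Fn { void* handle = nullptr; };

// Shape of JittedVecKernelCache after this patch: vec8/vec16 exist only on ROCm.
struct VecKernelCache {
  Fn vec1, vec2, vec4;
#ifdef USE_ROCM
  Fn vec8, vec16;
#endif
};

// Pick the cached kernel for a given vectorization width, widest first,
// falling back to the unvectorized vec1 entry.
Fn* slot_for(VecKernelCache& cache, int vec_size) {
#ifdef USE_ROCM
  if (vec_size == 16) return &cache.vec16;
  if (vec_size == 8)  return &cache.vec8;
#endif
  if (vec_size == 4) return &cache.vec4;
  if (vec_size == 2) return &cache.vec2;
  return &cache.vec1;
}
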
55 changes: 54 additions & 1 deletion aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -78,13 +78,49 @@ constexpr auto io_block_work_size() {
return num_threads() * elems_per_thread<io_sizes>();
}

#ifdef USE_ROCM
template <typename args_t, size_t... Is>
constexpr auto input_size(args_t args, std::index_sequence<Is...>) {
if constexpr (sizeof...(Is) == 0) {
return 0;
} else {
return sizeof(std::tuple_element_t<0, args_t>);
}
}

template <int vec_size, int io_size>
constexpr auto calc_optimal_vec_size() {
static_assert(vec_size != 0);
static_assert(io_size != 0);
if constexpr (io_size == 1 && vec_size >= 16) {
return 16;
} else if constexpr (io_size <= 2 && vec_size >= 8) {
return 8;
} else if constexpr (io_size <= 4 && vec_size >= 4) {
return 4;
} else if constexpr (vec_size >= 4) {
return 4;
} else if constexpr (vec_size >= 2) {
return 2;
} else {
return 1;
}
}
#endif

template <typename func_t>
constexpr auto calc_io_size(){
using traits = function_traits<func_t>;
using args_t = typename traits::ArgsTuple;
#ifdef USE_ROCM
constexpr auto input_size = at::native::input_size(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
constexpr auto output_size = sizeof(typename traits::result_type);
return (input_size > 0) ? ((input_size < output_size) ? input_size : output_size) : output_size;
#else
constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
constexpr auto output_size = sizeof(typename traits::result_type);
return input_size + output_size;
#endif
}

template <int vec_size, typename func_t, typename array_t>
@@ -111,8 +147,13 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
#ifdef USE_ROCM
constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
#else
constexpr auto optimal_vec_size = vec_size;
#endif
elementwise_kernel_helper(
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
}
}

@@ -154,6 +195,18 @@ static inline void launch_vectorized_kernel(
int vec_size = memory::can_vectorize_up_to<func_t>(data);

switch (vec_size) {
#ifdef USE_ROCM
case 16:
vectorized_elementwise_kernel<16, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
case 8:
vectorized_elementwise_kernel<8, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
#endif
case 4:
vectorized_elementwise_kernel<4, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
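
A runtime equivalent of the constexpr calc_optimal_vec_size<vec_size, io_size>() added above, with io_size in bytes (on ROCm, calc_io_size now passes the smaller of the input and output element sizes rather than their sum). The dtype sizes in main are assumptions chosen to show the capping behaviour.

#include <cstdio>

// Runtime sketch of calc_optimal_vec_size: the width actually used by the
// vectorized policy is capped so that 16- and 8-element vectors are reserved
// for 1-byte and <=2-byte element types respectively.
int optimal_vec_size(int vec_size, int io_size) {
  if (io_size == 1 && vec_size >= 16) return 16;
  if (io_size <= 2 && vec_size >= 8)  return 8;
  if (vec_size >= 4)                  return 4;
  if (vec_size >= 2)                  return 2;
  return 1;
}

int main() {
  std::printf("%d\n", optimal_vec_size(16, 4));  // fp32: capped at 4
  std::printf("%d\n", optimal_vec_size(16, 2));  // fp16/bf16: capped at 8
  std::printf("%d\n", optimal_vec_size(16, 1));  // int8: full 16
}
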
86 changes: 70 additions & 16 deletions aten/src/ATen/native/cuda/Dropout.cu
@@ -50,9 +50,6 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
at::cuda::detail::TensorInfo<mask_t, IndexType> c,
IndexType totalElements, accscalar_t p,
PhiloxCudaState philox_args) {
// make sure we don't break assumption that we can't have > 4 elements / thread
static_assert(VEC <= 4, "Value of VEC must be in [2, 4]");

using LoadT = memory::aligned_vector<scalar_t, VEC>;
using MaskLoadT = memory::aligned_vector<mask_t, VEC>;

@@ -66,7 +63,8 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
bool gridxvec_loop_state = 0;
accscalar_t scale = 1.0 / p;

float4 rand;
constexpr int RAND_SIZE = (VEC + 4 - 1) / 4;
float4 rand[RAND_SIZE];

// Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time
for (IndexType linearIndex = idx * VEC;
@@ -80,20 +78,31 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
// Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4)
// sets of rand.
if ((VEC == 4) || (gridxvec_loop_state == 0)) {
rand = curand_uniform4(&state);
if ((VEC >= 4) || (gridxvec_loop_state == 0)) {
#pragma unroll
for (int ii = 0; ii < RAND_SIZE; ii++) {
rand[ii] = curand_uniform4(&state);
}
} else {
// sets up the last two values we generated last iteration to be used this iteration.
rand.x = rand.z;
rand.y = rand.w;
rand[0].x = rand[0].z;
rand[0].y = rand[0].w;
gridxvec_loop_state ^= 1;
}

rand.x = rand.x < p;
rand.y = rand.y < p;
if (VEC == 4) {
rand.z = rand.z < p;
rand.w = rand.w < p;
rand[0].x = rand[0].x < p;
rand[0].y = rand[0].y < p;
if constexpr (VEC >= 4) {
rand[0].z = rand[0].z < p;
rand[0].w = rand[0].w < p;
}

#pragma unroll
for (int ii = 1; ii < RAND_SIZE; ii++) {
rand[ii].x = rand[ii].x < p;
rand[ii].y = rand[ii].y < p;
rand[ii].z = rand[ii].z < p;
rand[ii].w = rand[ii].w < p;
}

// Note: We explicitly check for is_contiguous() before launching the vectorized kernel
@@ -107,10 +116,14 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>

// Perform the actual computation
#pragma unroll
for (int ii = 0; ii < VEC; ii++) {
r[ii] = src[ii]*(&rand.x)[ii]*scale;
mask[ii] = (mask_t)(&rand.x)[ii];
for (int jj = 0; jj < RAND_SIZE; jj++) {
#pragma unroll
for (int ii = 0; ii < std::min(VEC, 4); ii++) {
r[jj * 4 + ii] = src[jj * 4 + ii]*(&rand[jj].x)[ii]*scale;
mask[jj * 4 + ii] = (mask_t)(&rand[jj].x)[ii];
}
}

// Vectorized writes for both mask & result
*(reinterpret_cast<LoadT*>(&b.data[linearIndex])) = *reinterpret_cast<LoadT*>(&r[0]);
*(reinterpret_cast<MaskLoadT*>(&c.data[linearIndex])) = *reinterpret_cast<MaskLoadT*>(&mask[0]);
@@ -200,6 +213,13 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
vec_size = 1;
} else {
vec_size = memory::can_vectorize_up_to<scalar_t>((const char*)self.const_data_ptr());
#ifdef USE_ROCM
// make sure we don't break assumption that we can't have > 16 elements / thread
TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
#else
// make sure we don't break assumption that we can't have > 4 elements / thread
TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
#endif
}

// check that we'd have no remainders - prefer a smaller vector size with no remainders over a larger vector and remainder.
@@ -244,6 +264,38 @@ inline void launcher(

if (vec_size > 1) {
switch (vec_size) {
case 16:
fused_dropout_kernel_vec<
scalar_t,
accscalar_t,
index_type,
1,
16>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
self_info,
ret_info,
mask_info,
nelem,
pa,
rng_engine_inputs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
case 8:
fused_dropout_kernel_vec<
scalar_t,
accscalar_t,
index_type,
1,
8>
<<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
self_info,
ret_info,
mask_info,
nelem,
pa,
rng_engine_inputs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
case 4:
fused_dropout_kernel_vec<
scalar_t,
@@ -276,6 +328,8 @@ inline void launcher(
rng_engine_inputs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
default:
TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
}
} else {
switch (self_info.dims) {
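
The Dropout.cu rewrite generalizes the random-number handling from a single float4 to RAND_SIZE = ceil(VEC / 4) of them, since curand_uniform4 yields four floats per call; element k of a VEC-wide chunk then uses component k % 4 of float4 number k / 4. A host-side sketch of that indexing follows (Float4 and the fixed random values stand in for CUDA's float4 and curand_uniform4 and are assumptions for illustration).

#include <cstdio>

struct Float4 { float x, y, z, w; };  // stand-in for CUDA's float4

// Host-side sketch of the per-thread dropout arithmetic after the patch.
template <int VEC>
void apply_dropout(const float (&src)[VEC], float p, float (&dst)[VEC]) {
  constexpr int RAND_SIZE = (VEC + 4 - 1) / 4;  // float4's needed for VEC elements
  Float4 rand[RAND_SIZE];
  for (int jj = 0; jj < RAND_SIZE; ++jj) {
    rand[jj] = {0.1f, 0.9f, 0.4f, 0.7f};  // would be curand_uniform4(&state)
  }
  const float scale = 1.0f / p;
  for (int jj = 0; jj < RAND_SIZE; ++jj) {
    for (int ii = 0; ii < 4 && jj * 4 + ii < VEC; ++ii) {
      const bool keep = (&rand[jj].x)[ii] < p;             // same component trick as the kernel
      dst[jj * 4 + ii] = src[jj * 4 + ii] * keep * scale;  // zero out dropped elements
    }
  }
}

int main() {
  float src[8] = {1, 1, 1, 1, 1, 1, 1, 1}, dst[8];
  apply_dropout<8>(src, 0.5f, dst);             // VEC = 8 -> RAND_SIZE = 2
  for (float v : dst) std::printf("%.1f ", v);  // kept elements scaled by 1/p
  std::printf("\n");
}
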
10 changes: 10 additions & 0 deletions aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -351,6 +351,16 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
#ifdef USE_ROCM
constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
constexpr int type_size = sizeof(scalar_t);
if (type_size == 1 && (address % vec16_alignment == 0)) {
return 16;
} else if (type_size <= 2 && (address % vec8_alignment == 0)) {
return 8;
} else
#endif
if (address % vec4_alignment == 0) {
return 4;
} else if (address % vec2_alignment == 0) {
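
The widened can_vectorize_up_to can be approximated on the host as below; the sketch assumes alignment_of_v<aligned_vector<scalar_t, N>> equals N * sizeof(scalar_t), which matches the alignas used by aligned_vector for these power-of-two widths.

#include <cstdint>
#include <cstdio>

// Host-side approximation of the ROCm branch added above: 16-element vectors
// for 1-byte types, 8-element vectors for <=2-byte types, then the existing
// 4/2/1 alignment checks.
template <typename scalar_t>
int can_vectorize_up_to(const void* ptr) {
  const uint64_t addr = reinterpret_cast<uint64_t>(ptr);
  constexpr int s = sizeof(scalar_t);
  if (s == 1 && addr % (16 * s) == 0) return 16;
  if (s <= 2 && addr % (8 * s) == 0)  return 8;
  if (addr % (4 * s) == 0)            return 4;
  if (addr % (2 * s) == 0)            return 2;
  return 1;
}

int main() {
  alignas(16) static unsigned char buf[64];
  std::printf("%d\n", can_vectorize_up_to<unsigned char>(buf));  // 16
  std::printf("%d\n", can_vectorize_up_to<short>(buf));          // 8
  std::printf("%d\n", can_vectorize_up_to<float>(buf));          // 4
  std::printf("%d\n", can_vectorize_up_to<float>(buf + 8));      // 2
}
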