
Commit 97d23f6

micmelesse authored and jeffdaily committed
[ROCM] Navi21 Enablement 9: Range and Multinomial Kernels (pytorch#73550)
Summary: This PR is a follow-up to pytorch#69942, pytorch#72682, pytorch#72809, pytorch#73543, pytorch#73545, pytorch#73546, pytorch#73548, and pytorch#73549. We are adding support for Navi21 GPUs, which have a warp size of 32. We cannot rely on a constant, so we have to look up the warp size dynamically when launching the kernel on the host side. Inside device functions this is not needed, and the compiler can detect the correct warp size to replace the C10_WARP_SIZE constant.

Pull Request resolved: pytorch#73550
Reviewed By: malfet
Differential Revision: D35444958
Pulled By: ngimel
fbshipit-source-id: c65f06d3227c23bb097a71fc6c86e3f884114e04
(cherry picked from commit 7f3ba52)
1 parent 73da8b8 commit 97d23f6
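
The host/device split the summary describes can be sketched with the plain CUDA runtime API. The following is an illustrative example only, not code from this PR: scale_kernel, the 1000-element problem size, and the 256-thread cap are assumptions; the actual change queries at::cuda::warp_size() on the host and keeps C10_WARP_SIZE in device code.

// Illustrative sketch of the host/device warp-size split (not the PR's code).
// Host code reads the warp size from the device properties at run time, so the
// same binary sizes its launches correctly on warp-32 (Navi21) and warp-64 GPUs.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale_kernel(float* data, int n) {
  // Inside device code the compiler targets a specific architecture, so a
  // compile-time warp-size constant (C10_WARP_SIZE in PyTorch) stays valid here.
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) data[idx] *= 2.0f;
}

int main() {
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, 0);
  // Host side: query the warp size instead of hard-coding 32 or 64.
  // PyTorch wraps the same query as at::cuda::warp_size().
  const int warp_size = props.warpSize;

  const int n = 1000;                                            // assumed problem size
  int threads = warp_size * ((n + warp_size - 1) / warp_size);   // round up to whole warps
  if (threads > 256) threads = 256;                              // assumed block-size cap
  const int blocks = (n + threads - 1) / threads;

  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  scale_kernel<<<blocks, threads>>>(d, n);
  cudaDeviceSynchronize();
  cudaFree(d);
  printf("warp_size=%d threads=%d blocks=%d\n", warp_size, threads, blocks);
  return 0;
}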

2 files changed: +20 −10 lines

aten/src/ATen/native/cuda/MultinomialKernel.cu

Lines changed: 6 additions & 4 deletions
@@ -74,12 +74,13 @@ void renormRows(Tensor& t) {
   const int64_t maxThreads = std::min(
       props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);
 
+  int warp_size = at::cuda::warp_size();
   dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
-  dim3 block(std::min(maxThreads, C10_WARP_SIZE * ceil_div(cols, int64_t{C10_WARP_SIZE})));
+  dim3 block(std::min(maxThreads, warp_size * ceil_div(cols, int64_t{warp_size})));
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(), "renormRows_cuda", [&] {
     renormRowsL1<scalar_t>
-        <<<grid, block, (block.x / C10_WARP_SIZE) * sizeof(scalar_t),
+        <<<grid, block, (block.x / warp_size) * sizeof(scalar_t),
         at::cuda::getCurrentCUDAStream()>>>(t.data_ptr<scalar_t>(),
             rows, cols);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -335,8 +336,9 @@ void multinomial_with_replacement_kernel_impl(
   int maxThreads = props->maxThreadsPerBlock;
   int maxShared = props->sharedMemPerBlock;
 
-  int requiredWarps = at::ceil_div(numCategories, C10_WARP_SIZE);
-  int requiredThreads = std::min(maxThreads, requiredWarps * C10_WARP_SIZE);
+  int warp_size = at::cuda::warp_size();
+  int requiredWarps = at::ceil_div(numCategories, warp_size);
+  int requiredThreads = std::min(maxThreads, requiredWarps * warp_size);
   int requiredShared = requiredThreads * sizeof(accscalar_t);
 
   if (n_sample == 1 && maxShared >= requiredShared) {
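
To make the launch arithmetic above concrete, here is a worked example with assumed numbers (numCategories = 1000 and float inputs are not from the PR):

  requiredWarps   = ceil_div(1000, warp_size=32)   = 32
  requiredThreads = min(maxThreads, 32 * 32)       = 1024   (assuming maxThreads >= 1024)
  requiredShared  = 1024 * sizeof(accscalar_t=4B)  = 4096 bytes

On a warp-64 GPU the same category count gives 16 warps of 64 threads, so the totals come out the same; the dynamic warp_size lookup simply keeps this arithmetic correct on both kinds of device.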

aten/src/ATen/native/cuda/RangeFactories.cu

Lines changed: 14 additions & 6 deletions
@@ -12,16 +12,24 @@
 
 namespace {
 
-constexpr int num_threads = C10_WARP_SIZE * 2;
+#if defined(USE_ROCM)
+constexpr int num_threads() {
+  return 128;
+}
+#else
+constexpr int num_threads() {
+  return C10_WARP_SIZE * 2;
+}
+#endif
 constexpr int thread_work_size = 1;
-constexpr int block_work_size = thread_work_size * num_threads;
+constexpr int block_work_size = thread_work_size * num_threads();
 
 template<typename index_t, typename func_t>
-C10_LAUNCH_BOUNDS_1(num_threads)
+C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void elementwise_kernel_with_index(index_t N, func_t f, typename function_traits<func_t>::result_type *data) {
   #pragma unroll
   for (int i = 0; i < thread_work_size; i++) {
-    index_t idx = block_work_size * blockIdx.x + num_threads * i + threadIdx.x;
+    index_t idx = block_work_size * blockIdx.x + num_threads() * i + threadIdx.x;
     if (idx < N) {
       data[idx] = f(idx);
     }
@@ -38,10 +46,10 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) {
   auto stream = at::cuda::getCurrentCUDAStream();
   using scalar_t = typename function_traits<func_t>::result_type;
   if (N <= std::numeric_limits<int>::max()) {
-    elementwise_kernel_with_index<int><<<grid, num_threads, 0, stream>>>(N, f, output.data_ptr<scalar_t>());
+    elementwise_kernel_with_index<int><<<grid, num_threads(), 0, stream>>>(N, f, output.data_ptr<scalar_t>());
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
-    elementwise_kernel_with_index<int64_t><<<grid, num_threads, 0, stream>>>(N, f, output.data_ptr<scalar_t>());
+    elementwise_kernel_with_index<int64_t><<<grid, num_threads(), 0, stream>>>(N, f, output.data_ptr<scalar_t>());
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
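
RangeFactories.cu takes a different route from the multinomial change: num_threads feeds C10_LAUNCH_BOUNDS_1, which requires a compile-time constant, so a runtime at::cuda::warp_size() call is not an option there. The commit instead turns num_threads into a constexpr function and pins the ROCm branch to 128 threads. Below is a minimal standalone reading of the two branches; the fallback #define and the static_assert are added for illustration and are not part of the PR.

// Standalone sketch of the num_threads() split introduced above.
#ifndef C10_WARP_SIZE
#define C10_WARP_SIZE 32   // fallback so this sketch compiles on its own; PyTorch's c10 headers define the real macro
#endif

#if defined(USE_ROCM)
constexpr int num_threads() { return 128; }                // fixed block size on ROCm, independent of warp width
#else
constexpr int num_threads() { return C10_WARP_SIZE * 2; }  // two warps per block on CUDA
#endif

// Either value is a whole number of 32- or 64-lane warps.
static_assert(num_threads() % 32 == 0, "block size must be a multiple of the warp size");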
