Skip to content

Commit 73da8b8

Browse files
micmelesse
authored and jeffdaily committed
[ROCM] Navi21 Enablement 8: Index, Repeat and Sort kernels
This PR is a follow-up to the following PRs: pytorch#69942, pytorch#72682, pytorch#72809, pytorch#73543, pytorch#73545, pytorch#73546, pytorch#73548. We are adding support for Navi21 GPUs, which have a warp size of 32. We cannot rely on a compile-time constant, so we have to dynamically look up the warp size when launching the kernel on the host side. Inside device functions this is not needed: the compiler can detect the correct warp size and substitute it for the C10_WARP_SIZE constant. Pull Request resolved: pytorch#73549 Approved by: https://github.com/malfet
1 parent c4ad64b commit 73da8b8

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,11 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
268268
linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
269269
const int UNROLL = 4;
270270
const int indices_per_block = 4;
271+
const int warp_size = at::cuda::warp_size();
271272
dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
272-
std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (C10_WARP_SIZE*UNROLL))),
273+
std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size*UNROLL))),
273274
std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
274-
dim3 block(C10_WARP_SIZE, indices_per_block);
275+
dim3 block(warp_size, indices_per_block);
275276

276277
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
277278
expandedValue.scalar_type(), "indexing_backward", [&] {

aten/src/ATen/native/cuda/Repeat.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ static void compute_cuda(
3333
int64_t size,
3434
int64_t result_size) {
3535
int64_t block = 512;
36-
int64_t warps_per_block = block / C10_WARP_SIZE;
36+
int64_t warps_per_block = block / at::cuda::warp_size();
3737
int64_t grid =
3838
std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);
3939

aten/src/ATen/native/cuda/Sorting.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <ATen/Dispatch.h>
66
#include <ATen/NumericUtils.h>
77
#include <c10/macros/Macros.h>
8+
#include <ATen/cuda/CUDAContext.h>
89
#include <ATen/cuda/detail/TensorInfo.cuh>
910
#include <ATen/native/cuda/SortingCommon.cuh>
1011
#include <ATen/native/cuda/SortingRadixSelect.cuh>
@@ -189,7 +190,7 @@ struct KthValueLauncher {
189190
}
190191

191192
dim3 block(std::min(
192-
round_up(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024));
193+
round_up(slice_size, (int64_t)at::cuda::warp_size()), (int64_t)1024));
193194
auto stream = at::cuda::getCurrentCUDAStream();
194195
gatherKthValue<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
195196
self_info,
@@ -228,7 +229,7 @@ struct MedianLauncher {
228229
}
229230

230231
dim3 block(std::min(
231-
round_up(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024));
232+
round_up(slice_size, (int64_t)at::cuda::warp_size()), (int64_t)1024));
232233
auto stream = at::cuda::getCurrentCUDAStream();
233234
gatherMedian<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
234235
values_info,

0 commit comments

Comments (0)