Skip to content

Commit c4ad64b

Browse files
micmelesse authored and jeffdaily committed
[ROCM] Navi21 Enablement 7: Sparse kernels
This PR is a follow-up to the following PRs: pytorch#69942, pytorch#72682, pytorch#72809, pytorch#73543, pytorch#73545, and pytorch#73546. We are adding support for Navi21 GPUs, which have a warp size of 32. Because we cannot rely on a compile-time constant for the warp size, the host side has to look it up dynamically when launching a kernel. Inside device functions this is not needed: the compiler can determine the correct warp size to substitute for the C10_WARP_SIZE constant. Pull Request resolved: pytorch#73548. Approved by: https://github.com/ngimel
1 parent bc60ee5 commit c4ad64b

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
#include <ATen/cuda/detail/TensorInfo.cuh>
44
#include <ATen/cuda/CUDAApplyUtils.cuh>
5+
#include <ATen/native/cuda/thread_constants.h>
56
#include <c10/macros/Macros.h>
67

78
namespace at { namespace native {
@@ -297,7 +298,7 @@ __global__ void indexSparseIntersectionKernel(
297298
// }
298299

299300
template <typename Dtype, typename Acctype>
300-
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4)
301+
C10_LAUNCH_BOUNDS_1(num_threads())
301302
__global__ void coalesceValuesKernel(
302303
int64_t *segment_offsets, int64_t *value_indices,
303304
Dtype *values, Dtype *newValues,

aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,8 +142,9 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) {
142142
const int SZ = 4;
143143
values = values.contiguous();
144144
int64_t stride = c10::multiply_integers(values.sizes().slice(1));
145-
dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) C10_WARP_SIZE*SZ));
146-
dim3 block(C10_WARP_SIZE, SZ);
145+
int warp_size = at::cuda::warp_size();
146+
dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) warp_size*SZ));
147+
dim3 block(warp_size, SZ);
147148
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
148149
at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(), "coalesce_sparse_cuda", [&] {
149150
using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;

0 commit comments

Comments
 (0)