
Commit a851703

alugorey authored and pruthvistony committed
Integrate new batched linalg drivers (#1163)
* Integrate new batched linalg drivers

* Skip test_qr_batched; ROCM doesn't support QR decomp for complex dtype

* Skip complex types, hipsolver does not support

* Skip complex types in other batched tests as well
1 parent fe7cc73 commit a851703

File tree

9 files changed: +577 -101 lines changed


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 317 additions & 17 deletions
Large diffs are not rendered by default.
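
Note: because the CUDABlas.cpp diff is collapsed above, the new wrapper bodies are not visible here. Purely as an illustration of the pattern such a ROCm-backed wrapper typically follows (a sketch under assumptions, not the contents of CUDABlas.cpp; the function name and the way the handle is supplied are hypothetical), a float getrfBatched built on hipBLAS could look roughly like this:

#include <hipblas.h>
#include <ATen/cuda/Exceptions.h>  // for TORCH_HIPBLAS_CHECK, added in this commit

// Illustrative sketch only -- not the code from CUDABlas.cpp.
void getrfBatched_float_sketch(
    hipblasHandle_t handle,  // assumed: a valid handle tied to the current stream
    int n,
    float** dA_array,        // device array of batchsize pointers to n*n matrices
    int ldda,
    int* ipiv_array,         // device buffer of batchsize * n pivot indices
    int* info_array,         // device buffer of batchsize per-matrix status codes
    int batchsize) {
  // hipblasSgetrfBatched LU-factorizes every matrix in the batch in one call,
  // writing pivots to ipiv_array and singularity flags to info_array.
  TORCH_HIPBLAS_CHECK(hipblasSgetrfBatched(
      handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}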

aten/src/ATen/cuda/CUDABlas.h

Lines changed: 149 additions & 9 deletions
@@ -16,6 +16,12 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/OpMathType.h>
 
+#ifdef USE_ROCM
+#include <hipblas.h>
+#include <hipsolver.h>
+#endif
+
+
 namespace at {
 namespace cuda {
 namespace blas {
@@ -221,8 +227,30 @@ void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
 template <>
 void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
 
-// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
-#ifdef CUDART_VERSION
+#ifdef USE_ROCM
+
+
+#define HIPBLAS_GETRS_ARGTYPES(Dtype)  \
+  hipblasHandle_t handle, hipblasOperation_t trans, \
+  int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
+  Dtype** dB_array, int ldb, int* info_array, int batchsize
+
+template<class Dtype>
+void getrsBatched(HIPBLAS_GETRS_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ",
+    typeid(Dtype).name());
+}
+template<>
+TORCH_CUDA_CU_API void getrsBatched<float>(HIPBLAS_GETRS_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void getrsBatched<double>(HIPBLAS_GETRS_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(HIPBLAS_GETRS_ARGTYPES(c10::complex<float>));
+template<>
+TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(HIPBLAS_GETRS_ARGTYPES(c10::complex<double>));
+
+
+#else
 
 #define CUDABLAS_GETRS_ARGTYPES(Dtype)  \
   cublasHandle_t handle, cublasOperation_t trans, \
@@ -243,6 +271,31 @@ TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES
 template<>
 TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
 
+#endif
+
+#ifdef USE_ROCM
+#define HIPBLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                     \
+  hipblasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
+      Dtype **tau_array, int *info, int batchsize
+
+template <class Dtype>
+void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::blas::geqrfBatched: not implemented for ",
+      typeid(Dtype).name());
+}
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<float>(HIPBLAS_GEQRF_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<double>(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
+    HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
+    HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
+#else
 #define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                   \
   cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
       Dtype **tau_array, int *info, int batchsize
@@ -264,22 +317,107 @@ TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
 template <>
 TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
     CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
+#endif
+
+#ifdef USE_ROCM
+#define HIPBLAS_GETRF_BATCHED_ARGTYPES(Dtype) \
+  int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
+template<class Dtype>
+void getrfBatched(HIPBLAS_GETRF_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
+}
+template<>
+TORCH_CUDA_CU_API void getrfBatched<float>(HIPBLAS_GETRF_BATCHED_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void getrfBatched<double>(HIPBLAS_GETRF_BATCHED_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex<double>));
+template<>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex<float>));
+
+#else
 
-#define CUDABLAS_GETRF_ARGTYPES(Dtype) \
+#define CUDABLAS_GETRF_BATCHED_ARGTYPES(Dtype) \
   int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
 
 template<class Dtype>
-void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
+void getrfBatched(CUDABLAS_GETRF_BATCHED_ARGTYPES(Dtype)) {
   TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
 }
 template<>
-TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
+TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_BATCHED_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_BATCHED_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex<double>));
+template<>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex<float>));
+#endif
+
+
+#ifdef USE_ROCM
+#define HIPBLAS_GETRI_BATCHED_ARGTYPES(Dtype) \
+  int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize
+
+template<class Dtype>
+void getriBatched(HIPBLAS_GETRI_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_CHECK(false, "at::cuda::blas::getriBatched: not implemented for ", typeid(Dtype).name());
+}
+template<>
+TORCH_CUDA_CU_API void getriBatched<float>(HIPBLAS_GETRI_BATCHED_ARGTYPES(float));
 template<>
-TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
+TORCH_CUDA_CU_API void getriBatched<double>(HIPBLAS_GETRI_BATCHED_ARGTYPES(double));
 template<>
-TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
+TORCH_CUDA_CU_API void getriBatched<c10::complex<double>>(HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex<double>));
 template<>
-TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
+TORCH_CUDA_CU_API void getriBatched<c10::complex<float>>(HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex<float>));
+
+
+#else
+
+
+#define CUDABLAS_GETRI_BATCHED_ARGTYPES(Dtype) \
+  int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize
+
+template<class Dtype>
+void getriBatched(CUDABLAS_GETRI_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_CHECK(false, "at::cuda::blas::getriBatched: not implemented for ", typeid(Dtype).name());
+}
+template<>
+TORCH_CUDA_CU_API void getriBatched<float>(CUDABLAS_GETRI_BATCHED_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void getriBatched<double>(CUDABLAS_GETRI_BATCHED_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void getriBatched<c10::complex<double>>(CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex<double>));
+template<>
+TORCH_CUDA_CU_API void getriBatched<c10::complex<float>>(CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex<float>));
+
+#endif
+
+
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 50400)
+
+#define HIPBLAS_GELS_BATCHED_ARGTYPES(Dtype)  \
+  hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
+
+template <class Dtype>
+void gelsBatched(HIPBLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name());
+}
+
+template<>
+TORCH_CUDA_CU_API void gelsBatched<double>(HIPBLAS_GELS_BATCHED_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void gelsBatched<float>(HIPBLAS_GELS_BATCHED_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
+template<>
+TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
+
+#else
+
+#ifdef CUDART_VERSION
 
 #define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)  \
   cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
@@ -298,7 +436,9 @@ TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_A
 template<>
 TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
 
-#endif // CUDART_VERSION
+#endif //CUDART_VERSION
+#endif //USE_ROCM
+
 
 } // namespace blas
 } // namespace cuda
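
Note: with these declarations, ROCm builds expose the same at::cuda::blas batched driver entry points as CUDA builds, but with hipblasHandle_t/hipblasOperation_t argument types (and gelsBatched gated on ROCM_VERSION >= 50400). A minimal usage sketch of a batched LU solve through two of the wrappers declared above; the scaffolding here (function name, pointer-array setup, and how the handle is obtained) is assumed for illustration and is not part of this commit:

#include <ATen/cuda/CUDABlas.h>

// Sketch: solve A_i * X_i = B_i for a batch of n-by-n systems on a ROCm build.
// dA_array/dB_array are device arrays of batchsize device pointers; d_ipiv and
// d_info are device buffers sized batchsize*n and batchsize respectively.
void batched_lu_solve_sketch(
    hipblasHandle_t handle,  // assumed: a valid handle for the current HIP stream
    float** dA_array, float** dB_array,
    int n, int nrhs, int lda, int ldb,
    int* d_ipiv, int* d_info, int batchsize) {
  // Factor every A_i in place (A_i = P_i * L_i * U_i) via the wrapper declared above.
  at::cuda::blas::getrfBatched<float>(n, dA_array, lda, d_ipiv, d_info, batchsize);

  // Back-substitute against the factorizations for every right-hand side.
  // getrsBatched takes the BLAS handle explicitly on the ROCm path.
  int host_info = 0;
  at::cuda::blas::getrsBatched<float>(
      handle, HIPBLAS_OP_N, n, nrhs, dA_array, lda, d_ipiv,
      dB_array, ldb, &host_info, batchsize);
}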

aten/src/ATen/cuda/Exceptions.h

Lines changed: 20 additions & 0 deletions
@@ -12,6 +12,9 @@
 #include <c10/util/Exception.h>
 #include <c10/cuda/CUDAException.h>
 
+#ifdef USE_ROCM
+#include <hipblas.h>
+#endif
 
 namespace c10 {
 
@@ -40,10 +43,16 @@ class CuDNNError : public c10::Error {
   }                                  \
   } while (0)
 
+
+
+
+
+
 namespace at { namespace cuda { namespace blas {
 C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
 }}} // namespace at::cuda::blas
 
+
 #define TORCH_CUDABLAS_CHECK(EXPR)                \
   do {                                            \
     cublasStatus_t __err = EXPR;                  \
@@ -53,6 +62,17 @@ C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
                 " when calling `" #EXPR "`");     \
   } while (0)
 
+#ifdef USE_ROCM
+#define TORCH_HIPBLAS_CHECK(EXPR)                 \
+  do {                                            \
+    hipblasStatus_t __err = EXPR;                 \
+    TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS,  \
+                "CUDA error: ",                   \
+                " when calling `" #EXPR "`");     \
+  } while (0)
+#endif
+
+
 const char *cusparseGetErrorString(cusparseStatus_t status);
 
 #define TORCH_CUDASPARSE_CHECK(EXPR) \
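
Note: TORCH_HIPBLAS_CHECK mirrors TORCH_CUDABLAS_CHECK: it evaluates the expression, compares the resulting hipblasStatus_t against HIPBLAS_STATUS_SUCCESS, and otherwise throws via TORCH_CHECK, quoting the failing expression. A minimal usage sketch on a ROCm build (the function and variable names are illustrative only, not from this commit):

#include <ATen/cuda/Exceptions.h>
#include <hipblas.h>

void hipblas_check_sketch() {
  hipblasHandle_t handle = nullptr;
  // Any call returning hipblasStatus_t can be wrapped; a non-success status
  // becomes a thrown c10::Error that quotes the wrapped expression text.
  TORCH_HIPBLAS_CHECK(hipblasCreate(&handle));
  TORCH_HIPBLAS_CHECK(hipblasDestroy(handle));
}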
