 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/OpMathType.h>
 
+#ifdef USE_ROCM
+#include <hipblas.h>
+#include <hipsolver.h>
+#endif
+
+
 namespace at {
 namespace cuda {
 namespace blas {
@@ -221,8 +227,30 @@ void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
 template <>
 void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
 
-// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
-#ifdef CUDART_VERSION
+#ifdef USE_ROCM
+
+
+#define HIPBLAS_GETRS_ARGTYPES(Dtype)  \
+  hipblasHandle_t handle, hipblasOperation_t trans, \
+  int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
+  Dtype** dB_array, int ldb, int* info_array, int batchsize
+
+template <class Dtype>
+void getrsBatched(HIPBLAS_GETRS_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ",
+    typeid(Dtype).name());
+}
+template <>
+TORCH_CUDA_CU_API void getrsBatched<float>(HIPBLAS_GETRS_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void getrsBatched<double>(HIPBLAS_GETRS_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(HIPBLAS_GETRS_ARGTYPES(c10::complex<float>));
+template <>
+TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(HIPBLAS_GETRS_ARGTYPES(c10::complex<double>));
+
+
+#else
 
 #define CUDABLAS_GETRS_ARGTYPES(Dtype)  \
   cublasHandle_t handle, cublasOperation_t trans, \
@@ -243,6 +271,31 @@ TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES
 template <>
 TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
 
+#endif
+
+#ifdef USE_ROCM
+#define HIPBLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                   \
+  hipblasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
+  Dtype **tau_array, int *info, int batchsize
+
+template <class Dtype>
+void geqrfBatched(HIPBLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::blas::geqrfBatched: not implemented for ",
+      typeid(Dtype).name());
+}
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<float>(HIPBLAS_GEQRF_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<double>(HIPBLAS_GEQRF_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
+    HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
+    HIPBLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
+#else
 #define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                   \
   cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
   Dtype **tau_array, int *info, int batchsize
@@ -264,22 +317,107 @@ TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
 template <>
 TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
     CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
+#endif
+
+#ifdef USE_ROCM
+#define HIPBLAS_GETRF_BATCHED_ARGTYPES(Dtype) \
+  int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
+template <class Dtype>
+void getrfBatched(HIPBLAS_GETRF_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
+}
+template <>
+TORCH_CUDA_CU_API void getrfBatched<float>(HIPBLAS_GETRF_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void getrfBatched<double>(HIPBLAS_GETRF_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex<double>));
+template <>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(HIPBLAS_GETRF_BATCHED_ARGTYPES(c10::complex<float>));
+
+#else
 
-#define CUDABLAS_GETRF_ARGTYPES(Dtype)  \
+#define CUDABLAS_GETRF_BATCHED_ARGTYPES(Dtype)  \
   int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
 
 template <class Dtype>
-void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
+void getrfBatched(CUDABLAS_GETRF_BATCHED_ARGTYPES(Dtype)) {
   TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
 }
 template <>
-TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
+TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex<double>));
+template <>
+TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_BATCHED_ARGTYPES(c10::complex<float>));
+#endif
+
+
+#ifdef USE_ROCM
+#define HIPBLAS_GETRI_BATCHED_ARGTYPES(Dtype)  \
+  int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize
+
+template <class Dtype>
+void getriBatched(HIPBLAS_GETRI_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_CHECK(false, "at::cuda::blas::getriBatched: not implemented for ", typeid(Dtype).name());
+}
+template <>
+TORCH_CUDA_CU_API void getriBatched<float>(HIPBLAS_GETRI_BATCHED_ARGTYPES(float));
 template <>
-TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
+TORCH_CUDA_CU_API void getriBatched<double>(HIPBLAS_GETRI_BATCHED_ARGTYPES(double));
 template <>
-TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
+TORCH_CUDA_CU_API void getriBatched<c10::complex<double>>(HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex<double>));
 template <>
-TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
+TORCH_CUDA_CU_API void getriBatched<c10::complex<float>>(HIPBLAS_GETRI_BATCHED_ARGTYPES(c10::complex<float>));
+
+
+#else
+
+
+#define CUDABLAS_GETRI_BATCHED_ARGTYPES(Dtype)  \
+  int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize
+
+template <class Dtype>
+void getriBatched(CUDABLAS_GETRI_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_CHECK(false, "at::cuda::blas::getriBatched: not implemented for ", typeid(Dtype).name());
+}
+template <>
+TORCH_CUDA_CU_API void getriBatched<float>(CUDABLAS_GETRI_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void getriBatched<double>(CUDABLAS_GETRI_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void getriBatched<c10::complex<double>>(CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex<double>));
+template <>
+TORCH_CUDA_CU_API void getriBatched<c10::complex<float>>(CUDABLAS_GETRI_BATCHED_ARGTYPES(c10::complex<float>));
+
+#endif
+
+
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 50400)
+
+#define HIPBLAS_GELS_BATCHED_ARGTYPES(Dtype)  \
+  hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int* devInfoArray, int batchSize
+
+template <class Dtype>
+void gelsBatched(HIPBLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+TORCH_CUDA_CU_API void gelsBatched<double>(HIPBLAS_GELS_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void gelsBatched<float>(HIPBLAS_GELS_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
+template <>
+TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(HIPBLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
+
+#else
+
+#ifdef CUDART_VERSION
 
 #define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)  \
   cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int* devInfoArray, int batchSize
@@ -298,7 +436,9 @@ TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_A
 template <>
 TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
 
-#endif // CUDART_VERSION
+#endif // CUDART_VERSION
+#endif // USE_ROCM
+
 
 } // namespace blas
 } // namespace cuda
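
Note: the diff above only declares the hipBLAS-backed specializations; their definitions would live in the corresponding CUDABlas.cpp. As a rough illustration of the intended pattern (a sketch only, not taken from this patch), the float specialization of getrsBatched could forward to hipBLAS's hipblasSgetrsBatched, assuming the existing TORCH_CUDABLAS_CHECK macro accepts hipBLAS status codes under ROCm builds:

// Hypothetical sketch -- not part of this diff.
template <>
void getrsBatched<float>(HIPBLAS_GETRS_ARGTYPES(float)) {
  // dA_array and dB_array are device arrays of per-batch matrix pointers;
  // ipiv_array holds the pivot indices produced by a prior getrfBatched call.
  TORCH_CUDABLAS_CHECK(hipblasSgetrsBatched(
      handle, trans, n, nrhs, dA_array, lda, ipiv_array,
      dB_array, ldb, info_array, batchsize));
}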