Skip to content

Commit 7a226e2

Browse files
xinyazhangpruthvistony
authored andcommitted
Enable gesvda for ROCM >= 6.1 (#1339)
This also fixes a problem in gesvd driver when UV is not needed.
1 parent a7df7d3 commit 7a226e2

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
332332
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
333333
// not guaranteed to be column major on ROCM, we have to create a copy to
334334
// deal with this
335-
if (!not_A_H) {
335+
if (compute_uv && !not_A_H) {
336336
Vcopy = at::empty_like(V.mT(),
337337
V.options()
338338
.device(V.device())
@@ -348,7 +348,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
348348
full_matrices, compute_uv, calculate_all_batches, batches);
349349
});
350350
#ifdef ROCM_VERSION
351-
if (!not_A_H) {
351+
if (compute_uv && !not_A_H) {
352352
V.copy_(Vcopy);
353353
}
354354
#endif
@@ -522,8 +522,8 @@ inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U,
522522
template<typename scalar_t>
523523
inline static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
524524
const Tensor& infos, bool full_matrices, bool compute_uv) {
525-
#ifndef CUDART_VERSION
526-
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.")
525+
#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
526+
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0.")
527527
#else
528528
using value_t = typename c10::scalar_value_type<scalar_t>::type;
529529
int m = cuda_int_cast(A.size(-2), "m");
@@ -661,7 +661,7 @@ void svd_cusolver(const Tensor& A,
661661
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
662662

663663
// The default heuristic is to use gesvdj driver
664-
#ifdef ROCM_VERSION
664+
#if defined(ROCM_VERSION) && ROCM_VERSION < 60100
665665
const auto driver_v = c10::string_view("gesvdj");
666666
#else
667667
const auto driver_v = driver.value_or("gesvdj");

aten/src/ATen/native/cuda/linalg/CUDASolver.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,8 @@ void gesvdjBatched<c10::complex<double>>(
470470
}
471471

472472

473-
// ROCM does not implement gesdva yet
474-
#ifdef CUDART_VERSION
473+
// ROCM does not implement gesdva correctly before 6.1
474+
#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 60100
475475
template<>
476476
void gesvdaStridedBatched_buffersize<float>(
477477
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,

0 commit comments

Comments
 (0)