Skip to content

Commit e3f5863

Browse files
xinyazhang authored and dnikolaev-amd committed
Enable gesvda for ROCM >= 6.1 (#1339)
This also fixes a problem in gesvd driver when UV is not needed. (cherry picked from commit 4ce57ec)
1 parent 3a4ee6b commit e3f5863

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,11 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
321321
// gesvd just knows how to handle m >= n, so in the other case we need to transpose A
322322
const auto not_A_H = A.size(-2) >= A.size(-1);
323323
Tensor Vcopy = V; // Shallow copy
324-
#ifdef USE_ROCM
324+
#ifdef ROCM_VERSION
325325
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
326326
// not guaranteed to be column major on ROCM, we have to create a copy to
327327
// deal with this
328-
if (!not_A_H) {
328+
if (compute_uv && !not_A_H) {
329329
Vcopy = at::empty_like(V.mT(),
330330
V.options()
331331
.device(V.device())
@@ -340,8 +340,8 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
340340
infos,
341341
full_matrices, compute_uv, calculate_all_batches, batches);
342342
});
343-
#ifdef USE_ROCM
344-
if (!not_A_H) {
343+
#ifdef ROCM_VERSION
344+
if (compute_uv && !not_A_H) {
345345
V.copy_(Vcopy);
346346
}
347347
#endif
@@ -515,8 +515,8 @@ static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const T
515515
template<typename scalar_t>
516516
static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
517517
const Tensor& infos, bool full_matrices, bool compute_uv) {
518-
#ifndef CUDART_VERSION
519-
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.")
518+
#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
519+
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0.")
520520
#else
521521
using value_t = typename c10::scalar_value_type<scalar_t>::type;
522522
int m = cuda_int_cast(A.size(-2), "m");
@@ -654,7 +654,7 @@ void svd_cusolver(const Tensor& A,
654654
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
655655

656656
// The default heuristic is to use gesvdj driver
657-
#ifdef USE_ROCM
657+
#if defined(ROCM_VERSION) && ROCM_VERSION < 60100
658658
const auto driver_v = std::string_view("gesvdj");
659659
#else
660660
const auto driver_v = driver.value_or("gesvdj");

aten/src/ATen/native/cuda/linalg/CUDASolver.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,8 @@ void gesvdjBatched<c10::complex<double>>(
470470
}
471471

472472

473-
// ROCM does not implement gesdva yet
474-
#ifdef CUDART_VERSION
473+
// ROCM does not implement gesvda correctly before 6.1
474+
#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 60100
475475
template<>
476476
void gesvdaStridedBatched_buffersize<float>(
477477
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,

0 commit comments

Comments
 (0)