Skip to content

Commit ae4cbb8

Browse files
xinyazhang authored and
jithunnair-amd committed
Enable gesvda for ROCM >= 6.1 (#1339)
This also fixes a problem in the gesvd driver when U and V are not needed (i.e. compute_uv is false).
1 parent 07efcff commit ae4cbb8

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -328,11 +328,11 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
328328
// gesvd just knows how to handle m >= n, so in the other case we need to transpose A
329329
const auto not_A_H = A.size(-2) >= A.size(-1);
330330
Tensor Vcopy = V; // Shallow copy
331-
#ifdef USE_ROCM
331+
#ifdef ROCM_VERSION
332332
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
333333
// not guaranteed to be column major on ROCM, we have to create a copy to
334334
// deal with this
335-
if (!not_A_H) {
335+
if (compute_uv && !not_A_H) {
336336
Vcopy = at::empty_like(V.mT(),
337337
V.options()
338338
.device(V.device())
@@ -347,8 +347,8 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
347347
infos,
348348
full_matrices, compute_uv, calculate_all_batches, batches);
349349
});
350-
#ifdef USE_ROCM
351-
if (!not_A_H) {
350+
#ifdef ROCM_VERSION
351+
if (compute_uv && !not_A_H) {
352352
V.copy_(Vcopy);
353353
}
354354
#endif
@@ -522,8 +522,8 @@ inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U,
522522
template<typename scalar_t>
523523
inline static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
524524
const Tensor& infos, bool full_matrices, bool compute_uv) {
525-
#ifndef CUDART_VERSION
526-
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.")
525+
#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
526+
TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0.")
527527
#else
528528
using value_t = typename c10::scalar_value_type<scalar_t>::type;
529529
int m = cuda_int_cast(A.size(-2), "m");
@@ -648,7 +648,7 @@ std::string _format_non_converging_batches(const std::vector<int64_t>& batches)
648648
void svd_cusolver(const Tensor& A,
649649
const bool full_matrices,
650650
const bool compute_uv,
651-
const std::optional<c10::string_view>& driver,
651+
const c10::optional<c10::string_view>& driver,
652652
const Tensor& U,
653653
const Tensor& S,
654654
const Tensor& V,
@@ -661,7 +661,7 @@ void svd_cusolver(const Tensor& A,
661661
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
662662

663663
// The default heuristic is to use gesvdj driver
664-
#ifdef USE_ROCM
664+
#if defined(ROCM_VERSION) && ROCM_VERSION < 60100
665665
const auto driver_v = c10::string_view("gesvdj");
666666
#else
667667
const auto driver_v = driver.value_or("gesvdj");

aten/src/ATen/native/cuda/linalg/CUDASolver.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -470,8 +470,8 @@ void gesvdjBatched<c10::complex<double>>(
470470
}
471471

472472

473-
// ROCM does not implement gesdva yet
474-
#ifdef CUDART_VERSION
473+
// ROCM does not implement gesvda correctly before 6.1
474+
#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 60100
475475
template<>
476476
void gesvdaStridedBatched_buffersize<float>(
477477
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,

0 commit comments

Comments
 (0)