@@ -321,11 +321,11 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
321
321
// gesvd just knows how to handle m >= n, so in the other case we need to transpose A
322
322
const auto not_A_H = A.size (-2 ) >= A.size (-1 );
323
323
Tensor Vcopy = V; // Shallow copy
324
- #ifdef USE_ROCM
324
+ #ifdef ROCM_VERSION
325
325
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
326
326
// not guaranteed to be column major on ROCM, we have to create a copy to
327
327
// deal with this
328
- if (!not_A_H) {
328
+ if (compute_uv && !not_A_H) {
329
329
Vcopy = at::empty_like (V.mT (),
330
330
V.options ()
331
331
.device (V.device ())
@@ -340,8 +340,8 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
340
340
infos,
341
341
full_matrices, compute_uv, calculate_all_batches, batches);
342
342
});
343
- #ifdef USE_ROCM
344
- if (!not_A_H) {
343
+ #ifdef ROCM_VERSION
344
+ if (compute_uv && !not_A_H) {
345
345
V.copy_ (Vcopy);
346
346
}
347
347
#endif
@@ -515,8 +515,8 @@ static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const T
515
515
template <typename scalar_t >
516
516
static void apply_svd_cusolver_gesvdaStridedBatched (const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
517
517
const Tensor& infos, bool full_matrices, bool compute_uv) {
518
- #ifndef CUDART_VERSION
519
- TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend." )
518
+ #if defined( CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
519
+ TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0 ." )
520
520
#else
521
521
using value_t = typename c10::scalar_value_type<scalar_t >::type;
522
522
int m = cuda_int_cast (A.size (-2 ), " m" );
@@ -654,7 +654,7 @@ void svd_cusolver(const Tensor& A,
654
654
static const char * check_svd_doc = " Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html" ;
655
655
656
656
// The default heuristic is to use gesvdj driver
657
- #ifdef USE_ROCM
657
+ #if defined(ROCM_VERSION) && ROCM_VERSION < 60100
658
658
const auto driver_v = std::string_view (" gesvdj" );
659
659
#else
660
660
const auto driver_v = driver.value_or (" gesvdj" );
0 commit comments