@@ -332,7 +332,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
332
332
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
333
333
// not guaranteed to be column major on ROCM, we have to create a copy to
334
334
// deal with this
335
- if (!not_A_H) {
335
+ if (compute_uv && !not_A_H) {
336
336
Vcopy = at::empty_like (V.mT (),
337
337
V.options ()
338
338
.device (V.device ())
@@ -348,7 +348,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
348
348
full_matrices, compute_uv, calculate_all_batches, batches);
349
349
});
350
350
#ifdef ROCM_VERSION
351
- if (!not_A_H) {
351
+ if (compute_uv && !not_A_H) {
352
352
V.copy_ (Vcopy);
353
353
}
354
354
#endif
@@ -522,8 +522,8 @@ inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U,
522
522
template <typename scalar_t >
523
523
inline static void apply_svd_cusolver_gesvdaStridedBatched (const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
524
524
const Tensor& infos, bool full_matrices, bool compute_uv) {
525
- #ifndef CUDART_VERSION
526
- TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend." )
525
+ #if defined( CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
526
+ TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0 ." )
527
527
#else
528
528
using value_t = typename c10::scalar_value_type<scalar_t >::type;
529
529
int m = cuda_int_cast (A.size (-2 ), " m" );
@@ -661,7 +661,7 @@ void svd_cusolver(const Tensor& A,
661
661
static const char * check_svd_doc = " Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html" ;
662
662
663
663
// The default heuristic is to use gesvdj driver
664
- #ifdef ROCM_VERSION
664
+ #if defined( ROCM_VERSION) && ROCM_VERSION < 60100
665
665
const auto driver_v = c10::string_view (" gesvdj" );
666
666
#else
667
667
const auto driver_v = driver.value_or (" gesvdj" );
0 commit comments