@@ -328,11 +328,11 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
328
328
// gesvd only knows how to handle m >= n, so when m < n we need to transpose A
329
329
const auto not_A_H = A.size (-2 ) >= A.size (-1 );
330
330
Tensor Vcopy = V; // Shallow copy
331
- #ifdef USE_ROCM
331
+ #ifdef ROCM_VERSION
332
332
// Similar to the case in svd_magma(), experiments have shown that the Vh tensor is
333
333
// not guaranteed to be column-major on ROCM, so we have to create a copy to
334
334
// deal with this
335
- if (!not_A_H) {
335
+ if (compute_uv && !not_A_H) {
336
336
Vcopy = at::empty_like (V.mT (),
337
337
V.options ()
338
338
.device (V.device ())
@@ -347,8 +347,8 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
347
347
infos,
348
348
full_matrices, compute_uv, calculate_all_batches, batches);
349
349
});
350
- #ifdef USE_ROCM
351
- if (!not_A_H) {
350
+ #ifdef ROCM_VERSION
351
+ if (compute_uv && !not_A_H) {
352
352
V.copy_ (Vcopy);
353
353
}
354
354
#endif
@@ -522,8 +522,8 @@ inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U,
522
522
template <typename scalar_t >
523
523
inline static void apply_svd_cusolver_gesvdaStridedBatched (const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
524
524
const Tensor& infos, bool full_matrices, bool compute_uv) {
525
- #ifndef CUDART_VERSION
526
- TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend." )
525
+ #if defined( CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
526
+ TORCH_CHECK (false , " gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0 ." )
527
527
#else
528
528
using value_t = typename c10::scalar_value_type<scalar_t >::type;
529
529
int m = cuda_int_cast (A.size (-2 ), " m" );
@@ -648,7 +648,7 @@ std::string _format_non_converging_batches(const std::vector<int64_t>& batches)
648
648
void svd_cusolver (const Tensor& A,
649
649
const bool full_matrices,
650
650
const bool compute_uv,
651
- const std ::optional<c10::string_view>& driver,
651
+ const c10 ::optional<c10::string_view>& driver,
652
652
const Tensor& U,
653
653
const Tensor& S,
654
654
const Tensor& V,
@@ -661,7 +661,7 @@ void svd_cusolver(const Tensor& A,
661
661
static const char * check_svd_doc = " Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html" ;
662
662
663
663
// The default heuristic is to use the gesvdj driver
664
- #ifdef USE_ROCM
664
+ #if defined(ROCM_VERSION) && ROCM_VERSION < 60100
665
665
const auto driver_v = c10::string_view (" gesvdj" );
666
666
#else
667
667
const auto driver_v = driver.value_or (" gesvdj" );
0 commit comments