Skip to content

Commit 6ff2f0c

Browse files
dnikolaev-amdpruthvistony
authored andcommitted
Update gesvda USE_ROCM guards
1 parent 7efca48 commit 6ff2f0c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
328328
// gesvd just knows how to handle m >= n, so in the other case we need to transpose A
329329
const auto not_A_H = A.size(-2) >= A.size(-1);
330330
Tensor Vcopy = V; // Shallow copy
331-
#ifdef ROCM_VERSION
331+
#ifdef USE_ROCM
332332
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
333333
// not guaranteed to be column major on ROCM, we have to create a copy to
334334
// deal with this
@@ -347,7 +347,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
347347
infos,
348348
full_matrices, compute_uv, calculate_all_batches, batches);
349349
});
350-
#ifdef ROCM_VERSION
350+
#ifdef USE_ROCM
351351
if (compute_uv && !not_A_H) {
352352
V.copy_(Vcopy);
353353
}
@@ -661,7 +661,7 @@ void svd_cusolver(const Tensor& A,
661661
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
662662

663663
// The default heuristic is to use gesvdj driver
664-
#if defined(ROCM_VERSION) && ROCM_VERSION < 60100
664+
#if defined(USE_ROCM) && ROCM_VERSION < 60100
665665
const auto driver_v = c10::string_view("gesvdj");
666666
#else
667667
const auto driver_v = driver.value_or("gesvdj");

0 commit comments

Comments
 (0)