Commit 3ea36a2

pruthvistony authored and dnikolaev-amd committed
CONSOLIDATED COMMITS: Fix lstsq related regressions
===================================================

Fix lstsq related regressions (part of SWDEV-392820)

Correcting usage of USE_ROCM

(cherry picked from commit e85cf5a)
1 parent c20efe9 commit 3ea36a2

1 file changed, 6 additions and 0 deletions

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp

Lines changed: 6 additions & 0 deletions
@@ -239,7 +239,13 @@ void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left
 
 template <typename scalar_t>
 inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
+#if defined(USE_ROCM) && (ROCM_VERSION >= 50400)
+  auto trans = HIPBLAS_OP_N;
+#elif defined(USE_ROCM) && (ROCM_VERSION < 50400)
+  auto trans = rocblas_operation_none;
+#else
   auto trans = CUBLAS_OP_N;
+#endif
   auto m = cuda_int_cast(A.size(-2), "m");
   auto n = cuda_int_cast(A.size(-1), "n");
 
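
The branch structure in this hunk can be tried in isolation with the sketch below. It is only an illustration, not code from the commit: `TransposeOp`, `select_no_transpose`, and `backend_name` are hypothetical stand-ins so the snippet compiles without the hipBLAS, rocBLAS, or cuBLAS headers. It also assumes that `ROCM_VERSION` encodes the release as major * 10000 + minor * 100 + patch, so that 50400 corresponds to ROCm 5.4.0.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical stand-ins for HIPBLAS_OP_N, rocblas_operation_none and
// CUBLAS_OP_N, so that no vendor header is required to compile this sketch.
enum class TransposeOp { HipblasNone, RocblasNone, CublasNone };

// Same three-way preprocessor selection as the patch. USE_ROCM and
// ROCM_VERSION are expected to come from the build system; when neither is
// defined, the #else branch (the CUDA path) is taken.
constexpr TransposeOp select_no_transpose() {
#if defined(USE_ROCM) && (ROCM_VERSION >= 50400)
  return TransposeOp::HipblasNone;  // newer ROCm: hipBLAS-style enum
#elif defined(USE_ROCM) && (ROCM_VERSION < 50400)
  return TransposeOp::RocblasNone;  // older ROCm: rocBLAS-style enum
#else
  return TransposeOp::CublasNone;   // non-ROCm build: cuBLAS enum
#endif
}

// Maps the selected stand-in value to a readable backend label.
inline const char* backend_name(TransposeOp op) {
  switch (op) {
    case TransposeOp::HipblasNone: return "hipblas";
    case TransposeOp::RocblasNone: return "rocblas";
    case TransposeOp::CublasNone:  return "cublas";
  }
  return "unknown";
}
```

Compiling with `-DUSE_ROCM -DROCM_VERSION=50300` should take the rocBLAS branch, and compiling with `-DUSE_ROCM -DROCM_VERSION=50400` the hipBLAS branch. Because exactly one branch survives preprocessing, `trans` in the patched function always has a single declaration.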
