Skip to content

Commit 9fb82af

Browse files
ikawrakowKawrakow
andauthored
Fix bug in MMVQ kernel (#446)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 6b12c2e commit 9fb82af

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,13 @@ static __device__ void mul_mat_vec_q(
7272

7373
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
7474

75+
//int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ?
76+
// ncols_y < 4 ? 1 : 2 : 1;
77+
7578
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
7679
constexpr int rows_per_cuda_block = 1;
7780
#else
78-
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
81+
constexpr int rows_per_cuda_block = ncols_y < 4 ? 1 : 2;
7982
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
8083

8184
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;

0 commit comments

Comments
 (0)