Skip to content

Commit ce281b9

Browse files
committed
llama : disable FA for AMD
1 parent 8937ec5 commit ce281b9

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,8 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
399399

400400
#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
401401
defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
402-
#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
403-
defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA
402+
403+
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
404404

405405
// TODO: move to ggml-common.h
406406
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

ggml-cuda/fattn.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
#include "fattn.cuh"
33

44
#include <cstdint>
5+
6+
#if FP16_MMA_AVAILABLE
57
#include <mma.h>
8+
#endif
69

710
#define FATTN_KQ_STRIDE 256
811
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.

llama.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15357,6 +15357,13 @@ struct llama_context * llama_new_context_with_model(
1535715357
cparams.flash_attn = false;
1535815358
}
1535915359

15360+
#ifdef GGML_USE_HIPBLAS
15361+
if (cparams.flash_attn) {
15362+
LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
15363+
cparams.flash_attn = false;
15364+
}
15365+
#endif
15366+
1536015367
if (params.seed == LLAMA_DEFAULT_SEED) {
1536115368
params.seed = time(NULL);
1536215369
}

0 commit comments

Comments
 (0)