Commit a28e0d5

CUDA: add option to compile without FlashAttention (#12025)
1 parent 36c258e commit a28e0d5

13 files changed (+46, -31 lines)
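In short: CMake gains a GGML_CUDA_FA option (default ON) and the Makefile a GGML_CUDA_NO_FA switch; both end up defining GGML_CUDA_NO_FA for the CUDA, HIP and MUSA builds. That define suppresses FLASH_ATTN_AVAILABLE in common.cuh, so the FlashAttention kernels compile down to NO_DEVICE_CODE stubs and the backend reports GGML_OP_FLASH_ATTN_EXT as unsupported. To try it, something like cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_FA=OFF (or GGML_CUDA_NO_FA=1 with the Makefile) should work; treat the surrounding GGML_CUDA switches as the usual llama.cpp build flags rather than part of this diff.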

Makefile

Lines changed: 12 additions & 0 deletions
@@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN
     MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN)
 endif # GGML_CUDA_CCBIN
 
+ifdef GGML_CUDA_NO_FA
+    MK_NVCCFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 ifdef GGML_CUDA_FA_ALL_QUANTS
     MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
 endif # GGML_CUDA_FA_ALL_QUANTS
@@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
     HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
 endif # GGML_CUDA_NO_PEER_COPY
 
+ifdef GGML_CUDA_NO_FA
+    HIPFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o
 OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
 OBJ_GGML_EXT += $(OBJ_CUDA_TMPL)
@@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
     MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
 endif # GGML_CUDA_NO_PEER_COPY
 
+ifdef GGML_CUDA_NO_FA
+    MUSAFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 ifdef GGML_CUDA_FA_ALL_QUANTS
     MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
 endif # GGML_CUDA_FA_ALL_QUANTS

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
     "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY    "ggml: do not use peer to peer copies"            OFF)
 option(GGML_CUDA_NO_VMM          "ggml: do not try to use CUDA VMM"                OFF)
+option(GGML_CUDA_FA              "ggml: compile ggml FlashAttention CUDA kernels"  ON)
 option(GGML_CUDA_FA_ALL_QUANTS   "ggml: compile all quants for FlashAttention"     OFF)
 option(GGML_CUDA_GRAPHS          "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT})

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
@@ -204,9 +204,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
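
For orientation, here is a minimal sketch (not part of this commit; the kernel name and parameters are hypothetical) of the guard pattern the fattn-* kernels below adopt now that FLASH_ATTN_AVAILABLE can be compiled out: the whole body sits behind the macro, and otherwise the function collapses to a NO_DEVICE_CODE stub, mirroring the #else branches added in the files that follow.

    // Hypothetical example only; NO_DEVICE_CODE and GGML_UNUSED come from the ggml headers.
    #include "common.cuh"

    static __global__ void example_fa_guarded_kernel(const float * src, float * dst, const int n) {
    #ifdef FLASH_ATTN_AVAILABLE
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i < n) {
            dst[i] = src[i]; // real FlashAttention work would go here
        }
    #else
        GGML_UNUSED(src); GGML_UNUSED(dst); GGML_UNUSED(n);
        NO_DEVICE_CODE; // GGML_CUDA_NO_FA build: only this stub is compiled
    #endif // FLASH_ATTN_AVAILABLE
    }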

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 4 additions & 4 deletions
@@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef NEW_MMA_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // NEW_MMA_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
 
 template <int D, int ncols1, int ncols2>

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 2 additions & 7 deletions
@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 4 additions & 4 deletions
@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 2 additions & 7 deletions
@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 4 additions & 4 deletions
@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 }
 
 constexpr int get_max_power_of_2(int x) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
-#endif
+#endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }

ggml/src/ggml-hip/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
