
CUDA: add option to compile without FlashAttention #12025


Merged (1 commit) on Feb 22, 2025
12 changes: 12 additions & 0 deletions Makefile
@@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN
     MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN)
 endif # GGML_CUDA_CCBIN

+ifdef GGML_CUDA_NO_FA
+    MK_NVCCFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 ifdef GGML_CUDA_FA_ALL_QUANTS
     MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
 endif # GGML_CUDA_FA_ALL_QUANTS
@@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
     HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
 endif # GGML_CUDA_NO_PEER_COPY

+ifdef GGML_CUDA_NO_FA
+    HIPFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o
 OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
 OBJ_GGML_EXT += $(OBJ_CUDA_TMPL)
@@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY
     MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
 endif # GGML_CUDA_NO_PEER_COPY

+ifdef GGML_CUDA_NO_FA
+    MUSAFLAGS += -DGGML_CUDA_NO_FA
+endif # GGML_CUDA_NO_FA
+
 ifdef GGML_CUDA_FA_ALL_QUANTS
     MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
 endif # GGML_CUDA_FA_ALL_QUANTS
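With the Makefile build, the new switch behaves like the other GGML_CUDA_* variables: setting GGML_CUDA_NO_FA=1 adds -DGGML_CUDA_NO_FA to the NVCC, HIP, or MUSA flags as appropriate. A minimal usage sketch, assuming a working CUDA toolchain (the -j value is just an example):

```sh
# Build the CUDA backend without compiling the FlashAttention kernels.
make GGML_CUDA=1 GGML_CUDA_NO_FA=1 -j8
```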
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
@@ -151,6 +151,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                  "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY    "ggml: do not use peer to peer copies"            OFF)
 option(GGML_CUDA_NO_VMM          "ggml: do not try to use CUDA VMM"                OFF)
+option(GGML_CUDA_FA              "ggml: compile ggml FlashAttention CUDA kernels"  ON)
 option(GGML_CUDA_FA_ALL_QUANTS   "ggml: compile all quants for FlashAttention"     OFF)
 option(GGML_CUDA_GRAPHS          "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT})
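The CMake option is expressed positively (GGML_CUDA_FA, default ON) and is translated into the GGML_CUDA_NO_FA compile definition by the per-backend CMakeLists changes below. A configuration sketch, assuming an out-of-source build:

```sh
# Configure and build the CUDA backend with FlashAttention kernels disabled.
cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_FA=OFF
cmake --build build --config Release
```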
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/CMakeLists.txt
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()

+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/common.cuh
@@ -204,9 +204,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)

 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
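FLASH_ATTN_AVAILABLE is now cleared by GGML_CUDA_NO_FA in addition to the existing MUSA QY1 exclusion. The kernel files below all apply the same restructuring: rather than an early `NO_DEVICE_CODE; return;` guard at the top of the kernel, the entire body is wrapped in the availability check, so an excluded variant compiles down to a stub and none of the FlashAttention code needs to compile at all. A minimal sketch of the pattern (kernel name and body are illustrative, not from the PR):

```cuda
// The body compiles only when FLASH_ATTN_AVAILABLE is defined; otherwise
// the kernel reduces to NO_DEVICE_CODE, ggml's stub for device code paths
// that must never be reached with the current build/architecture.
static __global__ void example_guarded_kernel(const float * src, float * dst) {
#ifdef FLASH_ATTN_AVAILABLE
    dst[threadIdx.x] = src[threadIdx.x]; // stand-in for the real kernel body
#else
    GGML_UNUSED(src); GGML_UNUSED(dst);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
```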
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef NEW_MMA_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // NEW_MMA_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)

     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }

 template <int D, int ncols1, int ncols2>
9 changes: 2 additions & 7 deletions ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)

     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }

 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE

     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }

 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
9 changes: 2 additions & 7 deletions ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)

     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }

 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE

     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }

 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 }

 constexpr int get_max_power_of_2(int x) {
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
-#endif
+#endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
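Since ggml_backend_cuda_device_supports_op reports GGML_OP_FLASH_ATTN_EXT as unsupported whenever FLASH_ATTN_AVAILABLE is undefined, a no-FA build degrades gracefully: the scheduler assigns the op to a backend that does support it instead of crashing. A hedged runtime sketch (model path and prompt are placeholders):

```sh
# -fa requests FlashAttention; in a -DGGML_CUDA_FA=OFF build the CUDA
# backend declines GGML_OP_FLASH_ATTN_EXT and the op should fall back
# to another backend rather than fail.
./build/bin/llama-cli -m model.gguf -ngl 99 -fa -p "Hello"
```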
4 changes: 4 additions & 0 deletions ggml/src/ggml-hip/CMakeLists.txt
@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()

+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)
4 changes: 4 additions & 0 deletions ggml/src/ggml-musa/CMakeLists.txt
@@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()

+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
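The HIP and MUSA backends reuse the same GGML_CUDA_FA option rather than adding per-backend variants, so a single flag controls FlashAttention across all three CUDA-derived backends. A configuration sketch (the GGML_HIP and GGML_MUSA option names are assumed from the llama.cpp build options of this period):

```sh
# HIP (ROCm) build without FlashAttention kernels:
cmake -B build -DGGML_HIP=ON -DGGML_CUDA_FA=OFF
# MUSA build without FlashAttention kernels:
cmake -B build -DGGML_MUSA=ON -DGGML_CUDA_FA=OFF
```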