Skip to content

Commit 1ac8d89

Browse files
authored
[ROCM] Properly disable Flash Attention/Efficient Attention with environment variables (#1542)
Now `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` can compile correctly. This is a cherry-picked version of pytorch#133866. Tested with `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py develop --user` and `python -c 'import torch'`.
1 parent 1a0f455 commit 1ac8d89

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,16 @@ cmake_dependent_option(
877877
Will be disabled if not supported by the platform" ON
878878
"USE_CUDA OR USE_ROCM" OFF)
879879

880+
#
881+
# Cannot be put into Dependencies.cmake due to a circular dependency:
882+
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
883+
#
884+
if(USE_ROCM)
885+
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
886+
include(cmake/External/aotriton.cmake)
887+
endif()
888+
endif()
889+
880890
if(DEBUG_CUDA)
881891
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
882892
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
#include <c10/util/string_view.h>
2626

2727
#if USE_ROCM
28+
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
2829
#include <aotriton/flash.h>
30+
#define USE_AOTRITON 1
31+
#endif
2932
#endif
3033

3134
/**
@@ -208,6 +211,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
208211
using sm80 = SMVersion<8, 0>;
209212
using sm90 = SMVersion<9, 0>;
210213
#if USE_ROCM
214+
#if USE_AOTRITON
211215
auto stream = at::cuda::getCurrentCUDAStream().stream();
212216
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
213217
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -217,6 +221,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
217221
}
218222
return false;
219223
}
224+
#else
225+
return false;
226+
#endif
220227
#else
221228
auto dprops = at::cuda::getCurrentDeviceProperties();
222229
if (!check_sm_version<sm80, sm90>(dprops)) {
@@ -239,6 +246,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
239246
using sm50 = SMVersion<5, 0>;
240247
using sm90 = SMVersion<9, 0>;
241248
#if USE_ROCM
249+
#if USE_AOTRITON
242250
auto stream = at::cuda::getCurrentCUDAStream().stream();
243251
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
244252
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -248,6 +256,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
248256
}
249257
return false;
250258
}
259+
#else
260+
return false;
261+
#endif
251262
#else
252263
auto dprops = at::cuda::getCurrentDeviceProperties();
253264
if (!check_sm_version<sm50, sm90>(dprops)) {

cmake/Dependencies.cmake

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,10 +1096,6 @@ if(USE_ROCM)
10961096
message(STATUS "Disabling Kernel Assert for ROCm")
10971097
endif()
10981098

1099-
include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
1100-
if(USE_CUDA)
1101-
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
1102-
endif()
11031099
else()
11041100
caffe2_update_option(USE_ROCM OFF)
11051101
endif()

0 commit comments

Comments
 (0)