Skip to content

Commit 7d6d91b

Browse files
authored
HIP: disable rocwmma on gfx12 by default until rocm 7.0 (#14202)
1 parent d3e64b9 commit 7d6d91b

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ option(GGML_HIP "ggml: use HIP"
172172
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
173173
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
174174
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
175+
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
175176
option(GGML_VULKAN "ggml: use Vulkan" OFF)
176177
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
177178
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ typedef float2 dfloat2;
207207
#define FP16_MMA_AVAILABLE
208208
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
209209

210-
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
210+
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
211211
#define FP16_MMA_AVAILABLE
212-
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
212+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
213213

214214
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
215215
#define NEW_MMA_AVAILABLE

ggml/src/ggml-hip/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
113113
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
114114
endif()
115115

116+
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
117+
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
118+
endif()
119+
116120
if (NOT GGML_CUDA_FA)
117121
add_compile_definitions(GGML_CUDA_NO_FA)
118122
endif()

0 commit comments

Comments
 (0)