@@ -139,6 +139,19 @@ if (GGML_METAL)
139
139
)
140
140
endif ()
141
141
142
+ if (GGML_MUSA )
143
+ set (CMAKE_C_COMPILER clang )
144
+ set (CMAKE_C_EXTENSIONS OFF )
145
+ set (CMAKE_CXX_COMPILER clang++ )
146
+ set (CMAKE_CXX_EXTENSIONS OFF )
147
+
148
+ set (GGML_CUDA ON )
149
+
150
+ list (APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA )
151
+
152
+ add_compile_definitions (GGML_USE_MUSA )
153
+ endif ()
154
+
142
155
if (GGML_OPENMP )
143
156
find_package (OpenMP )
144
157
if (OpenMP_FOUND )
@@ -147,6 +160,11 @@ if (GGML_OPENMP)
147
160
add_compile_definitions (GGML_USE_OPENMP )
148
161
149
162
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX )
163
+
164
+ if (GGML_MUSA )
165
+ set (GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp" )
166
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so" )
167
+ endif ()
150
168
else ()
151
169
message (WARNING "OpenMP not found" )
152
170
endif ()
@@ -249,7 +267,13 @@ endif()
249
267
if (GGML_CUDA )
250
268
cmake_minimum_required (VERSION 3.18 ) # for CMAKE_CUDA_ARCHITECTURES
251
269
252
- find_package (CUDAToolkit )
270
+ if (GGML_MUSA )
271
+ list (APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/" )
272
+ find_package (MUSAToolkit )
273
+ set (CUDAToolkit_FOUND ${MUSAToolkit_FOUND} )
274
+ else ()
275
+ find_package (CUDAToolkit )
276
+ endif ()
253
277
254
278
if (CUDAToolkit_FOUND )
255
279
message (STATUS "CUDA found" )
@@ -268,7 +292,11 @@ if (GGML_CUDA)
268
292
endif ()
269
293
message (STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES} " )
270
294
271
- enable_language (CUDA )
295
+ if (GGML_MUSA )
296
+ set (CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE} )
297
+ else ()
298
+ enable_language (CUDA )
299
+ endif ()
272
300
273
301
file (GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh" )
274
302
list (APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h" )
@@ -332,21 +360,40 @@ if (GGML_CUDA)
332
360
add_compile_definitions (GGML_CUDA_NO_PEER_COPY )
333
361
endif ()
334
362
363
+ if (GGML_MUSA )
364
+ set_source_files_properties (${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX )
365
+ foreach (SOURCE ${GGML_SOURCES_CUDA} )
366
+ set_property (SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22" )
367
+ endforeach ()
368
+ endif ()
369
+
335
370
if (GGML_STATIC )
336
371
if (WIN32 )
337
372
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
338
373
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt )
339
374
else ()
340
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static )
375
+ if (GGML_MUSA )
376
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static )
377
+ else ()
378
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static )
379
+ endif ()
341
380
endif ()
342
381
else ()
343
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt )
382
+ if (GGML_MUSA )
383
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas )
384
+ else ()
385
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt )
386
+ endif ()
344
387
endif ()
345
388
346
389
if (GGML_CUDA_NO_VMM )
347
390
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
348
391
else ()
349
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver ) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
392
+ if (GGML_MUSA )
393
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver ) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
394
+ else ()
395
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver ) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
396
+ endif ()
350
397
endif ()
351
398
else ()
352
399
message (WARNING "CUDA not found" )
@@ -757,8 +804,10 @@ function(get_flags CCID CCVER)
757
804
set(C_FLAGS -Wdouble-promotion)
758
805
set(CXX_FLAGS -Wno-array-bounds)
759
806
760
- if (CCVER VERSION_GREATER_EQUAL 7.1.0)
761
- list(APPEND CXX_FLAGS -Wno-format-truncation)
807
+ if (NOT GGML_MUSA)
808
+ if (CCVER VERSION_GREATER_EQUAL 7.1.0)
809
+ list(APPEND CXX_FLAGS -Wno-format-truncation)
810
+ endif()
762
811
endif()
763
812
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
764
813
list(APPEND CXX_FLAGS -Wextra-semi)
@@ -1059,7 +1108,9 @@ if (GGML_CUDA)
1059
1108
list (JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED ) # pass host compiler flags as a single argument
1060
1109
1061
1110
if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "" )
1062
- list (APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED} )
1111
+ # list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
1112
+ # XXX: Removed flags: -Xcompiler
1113
+ list (APPEND CUDA_FLAGS ${CUDA_CXX_FLAGS_JOINED} )
1063
1114
endif ()
1064
1115
1065
1116
add_compile_options ("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS} >" )
@@ -1163,6 +1214,7 @@ endif()
1163
1214
target_compile_definitions (ggml PUBLIC ${GGML_CDEF_PUBLIC} )
1164
1215
target_include_directories (ggml PUBLIC ../include )
1165
1216
target_include_directories (ggml PRIVATE . ${GGML_EXTRA_INCLUDES} )
1217
+ target_link_directories (ggml PRIVATE ${GGML_EXTRA_LIBDIRS} )
1166
1218
target_compile_features (ggml PRIVATE c_std_11 ) # don't bump
1167
1219
1168
1220
target_link_libraries (ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS} )
0 commit comments