@@ -139,6 +139,20 @@ if (GGML_METAL)
139
139
)
140
140
endif ()
141
141
142
+ if (GGML_MUSA )
143
+ set (CMAKE_C_COMPILER clang )
144
+ set (CMAKE_C_EXTENSIONS OFF )
145
+ set (CMAKE_CXX_COMPILER clang++ )
146
+ set (CMAKE_CXX_EXTENSIONS OFF )
147
+
148
+ set (GGML_CUDA ON )
149
+ set (GGML_OPENMP OFF )
150
+
151
+ list (APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA )
152
+
153
+ add_compile_definitions (GGML_USE_MUSA )
154
+ endif ()
155
+
142
156
if (GGML_OPENMP )
143
157
find_package (OpenMP )
144
158
if (OpenMP_FOUND )
@@ -249,7 +263,13 @@ endif()
249
263
if (GGML_CUDA )
250
264
cmake_minimum_required (VERSION 3.18 ) # for CMAKE_CUDA_ARCHITECTURES
251
265
252
- find_package (CUDAToolkit )
266
+ if (GGML_MUSA )
267
+ list (APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/" )
268
+ find_package (MUSAToolkit )
269
+ set (CUDAToolkit_FOUND ${MUSAToolkit_FOUND} )
270
+ else ()
271
+ find_package (CUDAToolkit )
272
+ endif ()
253
273
254
274
if (CUDAToolkit_FOUND )
255
275
message (STATUS "CUDA found" )
@@ -268,7 +288,11 @@ if (GGML_CUDA)
268
288
endif ()
269
289
message (STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES} " )
270
290
271
- enable_language (CUDA )
291
+ if (GGML_MUSA )
292
+ set (CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE} )
293
+ else ()
294
+ enable_language (CUDA )
295
+ endif ()
272
296
273
297
file (GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh" )
274
298
list (APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h" )
@@ -332,21 +356,40 @@ if (GGML_CUDA)
332
356
add_compile_definitions (GGML_CUDA_NO_PEER_COPY )
333
357
endif ()
334
358
359
+ if (GGML_MUSA )
360
+ set_source_files_properties (${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX )
361
+ foreach (SOURCE ${GGML_SOURCES_CUDA} )
362
+ set_property (SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22" )
363
+ endforeach ()
364
+ endif ()
365
+
335
366
if (GGML_STATIC )
336
367
if (WIN32 )
337
368
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
338
369
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt )
339
370
else ()
340
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static )
371
+ if (GGML_MUSA )
372
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static )
373
+ else ()
374
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static )
375
+ endif ()
341
376
endif ()
342
377
else ()
343
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt )
378
+ if (GGML_MUSA )
379
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas )
380
+ else ()
381
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt )
382
+ endif ()
344
383
endif ()
345
384
346
385
if (GGML_CUDA_NO_VMM )
347
386
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
348
387
else ()
349
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver ) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
388
+ if (GGML_MUSA )
389
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver ) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
390
+ else ()
391
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver ) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
392
+ endif ()
350
393
endif ()
351
394
else ()
352
395
message (WARNING "CUDA not found" )
@@ -757,8 +800,10 @@ function(get_flags CCID CCVER)
757
800
set(C_FLAGS -Wdouble-promotion)
758
801
set(CXX_FLAGS -Wno-array-bounds)
759
802
760
- if (CCVER VERSION_GREATER_EQUAL 7.1.0)
761
- list(APPEND CXX_FLAGS -Wno-format-truncation)
803
+ if (NOT GGML_MUSA)
804
+ if (CCVER VERSION_GREATER_EQUAL 7.1.0)
805
+ list(APPEND CXX_FLAGS -Wno-format-truncation)
806
+ endif()
762
807
endif()
763
808
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
764
809
list(APPEND CXX_FLAGS -Wextra-semi)
@@ -1059,7 +1104,9 @@ if (GGML_CUDA)
1059
1104
list (JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED ) # pass host compiler flags as a single argument
1060
1105
1061
1106
if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "" )
1062
- list (APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED} )
1107
+ # list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
1108
+ # XXX: Removed flags: -Xcompiler
1109
+ list (APPEND CUDA_FLAGS ${CUDA_CXX_FLAGS_JOINED} )
1063
1110
endif ()
1064
1111
1065
1112
add_compile_options ("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS} >" )
0 commit comments