@@ -526,10 +526,21 @@ ifndef GGML_NO_ACCELERATE
526
526
endif
527
527
endif # GGML_NO_ACCELERATE
528
528
529
# MUSA build switch. Setting GGML_MUSA turns on the CUDA backend
# (GGML_CUDA := 1), adds the GGML_USE_MUSA preprocessor define, and forces
# the C/C++ compilers to clang/clang++.
ifdef GGML_MUSA
  GGML_CUDA   := 1
  MK_CPPFLAGS += -DGGML_USE_MUSA
  CXX         := clang++
  CC          := clang
endif # GGML_MUSA
535
+
529
536
# OpenMP support: enabled unless GGML_NO_OPENMP is set. Adds the
# GGML_USE_OPENMP define and -fopenmp for both the C and C++ compilers.
ifndef GGML_NO_OPENMP
  MK_CPPFLAGS += -DGGML_USE_OPENMP
  MK_CFLAGS   += -fopenmp
  MK_CXXFLAGS += -fopenmp
  ifdef GGML_MUSA
    # With GGML_MUSA the LLVM OpenMP include and library directories are put
    # on the search paths. The LLVM prefix can be overridden, e.g.
    #   make GGML_MUSA=1 GGML_MUSA_LLVM_PATH=/path/to/llvm
    # and defaults to the LLVM 10 location that used to be hard-coded here.
    GGML_MUSA_LLVM_PATH ?= /usr/lib/llvm-10
    MK_CPPFLAGS += -I$(GGML_MUSA_LLVM_PATH)/include/openmp
    MK_LDFLAGS  += -L$(GGML_MUSA_LLVM_PATH)/lib
  endif # GGML_MUSA
endif # GGML_NO_OPENMP
534
545
535
546
ifdef GGML_OPENBLAS
@@ -580,15 +591,27 @@ else
580
591
endif # GGML_CUDA_FA_ALL_QUANTS
581
592
582
593
ifdef GGML_CUDA
583
# Toolkit root and backend-specific flags. CUDA_PATH holds the toolkit root in
# both cases: with GGML_MUSA it points at a MUSA install, otherwise at a CUDA
# install. Either way a user-supplied CUDA_PATH wins (?=).
ifdef GGML_MUSA
  ifneq ('', '$(wildcard /opt/musa)')
    CUDA_PATH ?= /opt/musa
  else
    CUDA_PATH ?= /usr/local/musa
  endif

  # GPU architecture passed via --cuda-gpu-arch. Overridable from the command
  # line or environment; defaults to the previously hard-coded mp_22.
  GGML_MUSA_ARCH ?= mp_22

  MK_CPPFLAGS  += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
  MK_LDFLAGS   += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
  MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=$(GGML_MUSA_ARCH)
else
  ifneq ('', '$(wildcard /opt/cuda)')
    CUDA_PATH ?= /opt/cuda
  else
    CUDA_PATH ?= /usr/local/cuda
  endif

  MK_CPPFLAGS  += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
  MK_LDFLAGS   += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
  MK_NVCCFLAGS += -use_fast_math
endif # GGML_MUSA
592
615
593
616
# CUDA backend objects: ggml-cuda.o itself, plus one .o for every .cu file
# that currently exists under ggml/src/ggml-cuda/.
OBJ_GGML += ggml/src/ggml-cuda.o
OBJ_GGML += $(addsuffix .o,$(basename $(wildcard ggml/src/ggml-cuda/*.cu)))
@@ -598,9 +621,11 @@ ifdef LLAMA_FATAL_WARNINGS
598
621
MK_NVCCFLAGS += -Werror all-warnings
599
622
endif # LLAMA_FATAL_WARNINGS
600
623
624
# --forward-unknown-to-host-compiler is appended to the device-compiler flags
# only when neither JETSON_EOL_MODULE_DETECT nor GGML_MUSA is set.
ifndef JETSON_EOL_MODULE_DETECT
  ifndef GGML_MUSA
    MK_NVCCFLAGS += --forward-unknown-to-host-compiler
  endif # GGML_MUSA
endif # JETSON_EOL_MODULE_DETECT
604
629
605
630
ifdef LLAMA_DEBUG
606
631
MK_NVCCFLAGS += -lineinfo
@@ -613,8 +638,12 @@ endif # GGML_CUDA_DEBUG
613
638
# Device compiler command, always prefixed with $(CCACHE).
# Precedence: an explicit GGML_CUDA_NVCC, then mcc when GGML_MUSA is set,
# otherwise plain nvcc.
ifdef GGML_CUDA_NVCC
  NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
else ifdef GGML_MUSA
  NVCC = $(CCACHE) mcc
else
  NVCC = $(CCACHE) nvcc
endif # GGML_CUDA_NVCC / GGML_MUSA
618
647
619
648
ifdef CUDA_DOCKER_ARCH
620
649
MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
@@ -685,9 +714,15 @@ define NVCC_COMPILE
685
714
$(NVCC ) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS ) $(CPPFLAGS ) -Xcompiler "$(CUDA_CXXFLAGS ) " -c $< -o $@
686
715
endef # NVCC_COMPILE
687
716
else
717
# NVCC_COMPILE: command that compiles one source ($<) into one object ($@)
# with the device compiler. Variable references are written without embedded
# whitespace: GNU make treats `$(NVCC )` as a lookup of a different, undefined
# variable name, which would expand to nothing.
ifdef GGML_MUSA
# MUSA variant: only NVCCFLAGS and CPPFLAGS are passed; no -Xcompiler host
# flags.
define NVCC_COMPILE
	$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@
endef # NVCC_COMPILE
else
# Default variant: CUDA_CXXFLAGS are passed to the host compiler through
# -Xcompiler.
define NVCC_COMPILE
	$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
endef # NVCC_COMPILE
endif # GGML_MUSA
691
726
endif # JETSON_EOL_MODULE_DETECT
692
727
693
728
ggml/src/ggml-cuda/% .o : \
@@ -913,6 +948,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
913
948
ifdef GGML_CUDA
914
949
$(info I NVCC : $(shell $(NVCC ) --version | tail -n 1) )
915
950
CUDA_VERSION := $(shell $(NVCC ) --version | grep -oP 'release (\K[0-9]+\.[0-9]) ')
951
+ ifndef GGML_MUSA
916
952
ifeq ($(shell awk -v "v=$(CUDA_VERSION ) " 'BEGIN { print (v < 11.7) }'),1)
917
953
918
954
ifndef CUDA_DOCKER_ARCH
@@ -922,6 +958,7 @@ endif # CUDA_POWER_ARCH
922
958
endif # CUDA_DOCKER_ARCH
923
959
924
960
endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
961
+ endif # GGML_MUSA
925
962
endif # GGML_CUDA
926
963
$(info )
927
964
0 commit comments