Skip to content

Commit 3686456

Browse files
nicholaiTukanovntukanov
andauthored
ggml : add NVPL BLAS support (#8329) (#8425)
* ggml : add NVPL BLAS support * ggml : replace `<BLASLIB>_ENABLE_CBLAS` with `GGML_BLAS_USE_<BLASLIB>` --------- Co-authored-by: ntukanov <[email protected]>
1 parent b078c61 commit 3686456

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,11 +547,17 @@ ifdef GGML_OPENBLAS64
547547
endif # GGML_OPENBLAS64
548548

549549
ifdef GGML_BLIS
550-
MK_CPPFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis
550+
MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis
551551
MK_LDFLAGS += -lblis -L/usr/local/lib
552552
OBJ_GGML += ggml/src/ggml-blas.o
553553
endif # GGML_BLIS
554554

555+
ifdef GGML_NVPL
556+
MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas
557+
MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp
558+
OBJ_GGML += ggml/src/ggml-blas.o
559+
endif # GGML_NVPL
560+
555561
ifndef GGML_NO_LLAMAFILE
556562
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
557563
OBJ_GGML += ggml/src/llamafile/sgemm.o

ggml/src/ggml-blas.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
# include <Accelerate/Accelerate.h>
99
#elif defined(GGML_BLAS_USE_MKL)
1010
# include <mkl.h>
11+
#elif defined(GGML_BLAS_USE_BLIS)
12+
# include <blis.h>
13+
#elif defined(GGML_BLAS_USE_NVPL)
14+
# include <nvpl_blas.h>
1115
#else
1216
# include <cblas.h>
13-
# ifdef BLIS_ENABLE_CBLAS
14-
# include <blis.h>
15-
# endif
1617
#endif
1718

1819
struct ggml_backend_blas_context {
@@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
140141
openblas_set_num_threads(ctx->n_threads);
141142
#endif
142143

143-
#if defined(BLIS_ENABLE_CBLAS)
144+
#if defined(GGML_BLAS_USE_BLIS)
144145
bli_thread_set_num_threads(ctx->n_threads);
145146
#endif
146147

148+
#if defined(GGML_BLAS_USE_NVPL)
149+
nvpl_blas_set_num_threads(ctx->n_threads);
150+
#endif
151+
147152
for (int64_t i13 = 0; i13 < ne13; i13++) {
148153
for (int64_t i12 = 0; i12 < ne12; i12++) {
149154
const int64_t i03 = i13/r3;

0 commit comments

Comments
 (0)