
Commit 28c8e93

Merge remote-tracking branch 'upstream/master' into ntkv2

2 parents 50879df + 0728c5a · commit 28c8e93
21 files changed: +3051 −657 lines

CMakeLists.txt

Lines changed: 15 additions, 3 deletions

@@ -67,7 +67,9 @@ endif()
 option(LLAMA_ACCELERATE      "llama: enable Accelerate framework"           ON)
 option(LLAMA_BLAS            "llama: use BLAS"                              OFF)
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
-option(LLAMA_CUBLAS          "llama: use cuBLAS"                            OFF)
+option(LLAMA_CUBLAS          "llama: use CUDA"                              OFF)
+#option(LLAMA_CUDA_CUBLAS    "llama: use cuBLAS for prompt processing"      OFF)
+set(LLAMA_CUDA_MMQ_Y    "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
 option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
 set(LLAMA_CUDA_DMMV_X   "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y     "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
@@ -251,6 +253,10 @@ if (LLAMA_CUBLAS)
     set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
 
     add_compile_definitions(GGML_USE_CUBLAS)
+#    if (LLAMA_CUDA_CUBLAS)
+#        add_compile_definitions(GGML_CUDA_CUBLAS)
+#    endif()
+    add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
     if (LLAMA_CUDA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
@@ -271,10 +277,14 @@ if (LLAMA_CUBLAS)
     endif()
 
     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        # 52 == lowest CUDA 12 standard
+        # 60 == f16 CUDA intrinsics
+        # 61 == integer CUDA intrinsics
+        # 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
        if (LLAMA_CUDA_DMMV_F16)
-            set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
        else()
-            set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
+            set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
        endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -497,6 +507,8 @@ endif()
 add_library(ggml OBJECT
             ggml.c
             ggml.h
+            ggml-alloc.c
+            ggml-alloc.h
             ${GGML_SOURCES_CUDA}
             ${GGML_SOURCES_OPENCL}
             ${GGML_SOURCES_METAL}

Makefile

Lines changed: 19 additions, 3 deletions

@@ -194,7 +194,7 @@ ifdef LLAMA_CUBLAS
     CXXFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
     LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
     OBJS      += ggml-cuda.o
-    NVCCFLAGS = --forward-unknown-to-host-compiler
+    NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
 ifdef LLAMA_CUDA_NVCC
     NVCC = $(LLAMA_CUDA_NVCC)
 else
@@ -220,14 +220,25 @@ else ifdef LLAMA_CUDA_DMMV_Y
 else
     NVCCFLAGS += -DGGML_CUDA_MMV_Y=1
 endif # LLAMA_CUDA_MMV_Y
+ifdef LLAMA_CUDA_F16
+    NVCCFLAGS += -DGGML_CUDA_F16
+endif # LLAMA_CUDA_F16
 ifdef LLAMA_CUDA_DMMV_F16
-    NVCCFLAGS += -DGGML_CUDA_DMMV_F16
+    NVCCFLAGS += -DGGML_CUDA_F16
 endif # LLAMA_CUDA_DMMV_F16
 ifdef LLAMA_CUDA_KQUANTS_ITER
     NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
 else
     NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
 endif
+ifdef LLAMA_CUDA_MMQ_Y
+    NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y)
+else
+    NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
+endif # LLAMA_CUDA_MMQ_Y
+#ifdef LLAMA_CUDA_CUBLAS
+#    NVCCFLAGS += -DGGML_CUDA_CUBLAS
+#endif # LLAMA_CUDA_CUBLAS
 ifdef LLAMA_CUDA_CCBIN
     NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
@@ -318,7 +329,12 @@ $(info )
 ggml.o: ggml.c ggml.h ggml-cuda.h
     $(CC)  $(CFLAGS)   -c $< -o $@
 
-llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
+ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
+    $(CC)  $(CFLAGS)   -c $< -o $@
+
+OBJS += ggml-alloc.o
+
+llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
     $(CXX) $(CXXFLAGS) -c $< -o $@
 
 common.o: examples/common.cpp examples/common.h

README.md

Lines changed: 20 additions, 2 deletions

@@ -77,6 +77,7 @@ as the main playground for developing new features for the [ggml](https://github
 **Supported models:**
 
 - [X] LLaMA 🦙
+- [x] LLaMA 2 🦙🦙
 - [X] [Alpaca](https://github.com/ggerganov/llama.cpp#instruction-mode-with-alpaca)
 - [X] [GPT4All](https://github.com/ggerganov/llama.cpp#using-gpt4all)
 - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
@@ -399,12 +400,16 @@ Building the program with BLAS support may lead to some performance improvements
 
 The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:
 
+<!---
+| LLAMA_CUDA_CUBLAS       | Boolean                | false   | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
+--->
 | Option                  | Legal values           | Default | Description |
 |-------------------------|------------------------|---------|-------------|
+| LLAMA_CUDA_MMQ_Y        | Positive integer >= 32 | 64      | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
 | LLAMA_CUDA_FORCE_DMMV   | Boolean                | false   | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
 | LLAMA_CUDA_DMMV_X       | Positive integer >= 32 | 32      | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
-| LLAMA_CUDA_MMV_Y        | Positive integer       | 1       | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
-| LLAMA_CUDA_DMMV_F16     | Boolean                | false   | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
+| LLAMA_CUDA_MMV_Y        | Positive integer       | 1       | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
+| LLAMA_CUDA_F16          | Boolean                | false   | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
 | LLAMA_CUDA_KQUANTS_ITER | 1 or 2                 | 2       | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
 
 - #### CLBlast
@@ -650,6 +655,19 @@ python3 convert.py pygmalion-7b/ --outtype q4_1
 - The LLaMA models are officially distributed by Facebook and will **never** be provided through this repository.
 - Refer to [Facebook's LLaMA repository](https://github.com/facebookresearch/llama/pull/73/files) if you need to request access to the model data.
 
+### Obtaining and using the Facebook LLaMA 2 model
+
+- Refer to [Facebook's LLaMA download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) if you want to access the model data.
+- Alternatively, if you want to save time and space, you can download already converted and quantized models from [TheBloke](https://huggingface.co/TheBloke), including:
+  - [LLaMA 2 7B base](https://huggingface.co/TheBloke/Llama-2-7B-GGML)
+  - [LLaMA 2 13B base](https://huggingface.co/TheBloke/Llama-2-13B-GGML)
+  - [LLaMA 2 70B base](https://huggingface.co/TheBloke/Llama-2-70B-GGML)
+  - [LLaMA 2 7B chat](https://huggingface.co/TheBloke/Llama-2-7B-chat-GGML)
+  - [LLaMA 2 13B chat](https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML)
+  - [LLaMA 2 70B chat](https://huggingface.co/TheBloke/Llama-2-70B-chat-GGML)
+- Specify `-eps 1e-5` for best generation quality
+- Specify `-gqa 8` for 70B models to work
+
 ### Verifying the model files
 
 Please verify the [sha256 checksums](SHA256SUMS) of all downloaded model files to confirm that you have the correct model data files before creating an issue relating to your model files.

convert.py

File mode changed from 100755 to 100644
Lines changed: 52 additions, 44 deletions

@@ -133,19 +133,20 @@ def make_tensors_list() -> List[str]:
 
 def find_n_mult(n_ff: int, n_embd: int) -> int:
     # hardcoded magic range
-    for n_mult in range(256, 1, -1):
+    for n_mult in range(8192, 1, -1):
         calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
         if calc_ff == n_ff:
             return n_mult
     raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
 
 @dataclass
 class Params:
-    n_vocab: int
-    n_embd:  int
-    n_mult:  int
-    n_head:  int
-    n_layer: int
+    n_vocab:   int
+    n_embd:    int
+    n_mult:    int
+    n_head:    int
+    n_layer:   int
+    n_kv_head: Optional[int]  # This parameter is only used for Llama 2
 
     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
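
Note on the widened `find_n_mult` search: the bound moves from 256 to 8192 because the LLaMA 2 70B feed-forward width cannot be reproduced by any multiplier up to 256. A minimal standalone sketch of the round-up formula (the 70B shapes `n_embd = 8192`, `n_ff = 28672` below are the commonly published values, used here only for illustration):

```python
# Same round-up formula as in convert.py: n_ff is (8/3)*n_embd rounded up to a multiple of n_mult.
def find_n_mult(n_ff: int, n_embd: int) -> int:
    for n_mult in range(8192, 1, -1):  # previous upper bound was 256
        calc_ff = (((8 * n_embd) // 3 + n_mult - 1) // n_mult) * n_mult
        if calc_ff == n_ff:
            return n_mult
    raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")

# (8 * 8192) // 3 = 21845, and rounding 21845 up to a multiple of 7168 gives exactly 28672,
# so the 70B geometry is only matched by a multiplier far above the old 256 cutoff.
print(find_n_mult(n_ff=28672, n_embd=8192))  # -> 7168
```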
@@ -167,11 +168,12 @@ def guessed(model: 'LazyModel') -> 'Params':
         n_head=n_embd // 128 # guessed
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = 256,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = 256,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = None,
         )
 
     @staticmethod
@@ -183,15 +185,17 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         n_head = config["num_attention_heads"];
         n_layer = config["num_hidden_layers"];
         n_ff = config["intermediate_size"];
+        n_kv_head = config.get("num_key_value_heads")
 
         n_mult = find_n_mult(n_ff, n_embd);
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = n_mult,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = n_kv_head,
         )
 
     # LLaMA v2 70B params.json
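
For reference, `n_kv_head` is populated from the Hugging Face config key `num_key_value_heads`; checkpoints without grouped-query attention simply omit the key, so `config.get(...)` returns `None` and the converter keeps the old multi-head path. A small sketch under assumed, illustrative config values (not read from this commit):

```python
# Toy HF-style config dicts; only num_key_value_heads matters for the new n_kv_head field.
config_mha = {"num_attention_heads": 32, "num_hidden_layers": 32, "intermediate_size": 11008}
config_gqa = {"num_attention_heads": 64, "num_hidden_layers": 80, "intermediate_size": 28672,
              "num_key_value_heads": 8}

print(config_mha.get("num_key_value_heads"))  # None -> Params.n_kv_head stays None (plain MHA)
print(config_gqa.get("num_key_value_heads"))  # 8    -> enables the GQA-aware permute of wk.weight
```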
@@ -200,21 +204,22 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
     def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))
 
-        n_vocab = config["vocab_size"];
-        n_embd  = config["dim"];
-        n_head  = config["n_heads"];
-        n_layer = config["n_layers"];
-        n_mult  = config["multiple_of"];
+        n_vocab   = config["vocab_size"];
+        n_embd    = config["dim"];
+        n_head    = config["n_heads"];
+        n_layer   = config["n_layers"];
+        n_mult    = config["multiple_of"];
 
         if n_vocab == -1:
             n_vocab = model["tok_embeddings.weight"].shape[0]
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = n_mult,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = None,
         )
 
     @staticmethod
@@ -317,10 +322,12 @@ def __repr__(self) -> str:
 Vocab = Union[SentencePieceVocab, GGMLVocab]
 
 
-def permute(weights: NDArray, n_head: int) -> NDArray:
+def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
+    if n_kv_head is not None and n_head != n_kv_head:
+        n_head //= n_kv_head
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-            .swapaxes(1, 2)
-            .reshape(weights.shape))
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
 
 
 def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray:
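
The `permute()` change above is what makes grouped-query attention checkpoints convertible: under GQA the `k_proj` weight has only `n_kv_head * head_dim` rows instead of `n_embd`, so the head-wise reshape has to use a reduced group count. A self-contained numpy sketch with scaled-down toy dimensions (the 64-query-head / 8-KV-head ratio mirrors the 70B layout; the sizes themselves are assumptions for illustration):

```python
from typing import Optional
import numpy as np

def permute(weights: np.ndarray, n_head: int, n_kv_head: Optional[int] = None) -> np.ndarray:
    # Same logic as the convert.py change: shrink the group count when fewer KV heads exist.
    if n_kv_head is not None and n_head != n_kv_head:
        n_head //= n_kv_head  # e.g. 64 query heads over 8 KV heads -> 8 groups
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

n_head, n_kv_head, head_dim = 64, 8, 4             # toy sizes, 70B-like head ratio
n_embd = n_head * head_dim                         # 256
wq = np.random.rand(n_head * head_dim, n_embd)     # full set of query heads: (256, 256)
wk = np.random.rand(n_kv_head * head_dim, n_embd)  # fewer key heads under GQA: (32, 256)

assert permute(wq, n_head).shape == wq.shape
assert permute(wk, n_head, n_kv_head).shape == wk.shape  # would raise without the n_kv_head branch
```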
@@ -368,7 +375,7 @@ class Tensor(metaclass=ABCMeta):
     @abstractmethod
     def astype(self, data_type: DataType) -> 'Tensor': ...
     @abstractmethod
-    def permute(self, n_head: int) -> 'Tensor': ...
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
     @abstractmethod
     def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
     @abstractmethod
@@ -406,8 +413,8 @@ def part(self, n_part: int) -> 'UnquantizedTensor':
         r = self.ndarray.shape[0] // 3
         return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
 
-    def permute(self, n_head: int) -> 'UnquantizedTensor':
-        return UnquantizedTensor(permute(self.ndarray, n_head))
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
+        return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
 
 
 def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
@@ -455,26 +462,27 @@ def astype(self, data_type: DataType) -> Tensor:
     def to_ggml(self) -> 'GGMLQuantizedTensor':
         return self
 
-    def permute(self, n_head: int) -> 'GGMLQuantizedTensor':
-        return GGMLQuantizedTensor(permute(self.ndarray, n_head), self.shape, self.data_type)
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'GGMLQuantizedTensor':
+        return GGMLQuantizedTensor(permute(self.ndarray, n_head, n_kv_head), self.shape, self.data_type)
 
 
 GGMLCompatibleTensor = Union[UnquantizedTensor, GGMLQuantizedTensor]
 
 
 class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int) -> None:
+    def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
         self.base = base
         self.n_head = n_head
+        self.n_kv_head = n_kv_head
         self.data_type = self.base.data_type
 
     def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head)
+        return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
 
     def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head)
+        return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
 
-    def permute(self, n_head: int) -> Tensor:
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
         raise Exception("shouldn't permute twice")
 
 
@@ -566,8 +574,8 @@ def regroup(self, new_groupsize: int = 32) -> 'GPTQForLLaMaQuantizedTensor':
         ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False)
         return ret
 
-    def permute(self, n_head: int) -> Tensor:
-        return DeferredPermutedTensor(self, n_head)
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
+        return DeferredPermutedTensor(self, n_head, n_kv_head)
 
     def to_ggml(self) -> GGMLQuantizedTensor:
         # The output format looks like this:
@@ -698,10 +706,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
     return ModelPlus(model, paths, format, vocab)
 
 
-def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
+def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
     def load() -> Tensor:
-        return lazy_tensor.load().permute(n_head)
-    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+        return lazy_tensor.load().permute(n_head, n_kv_head)
+    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
 
 def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
     def load() -> Tensor:
@@ -726,7 +734,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
-            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
             out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
