
Commit ba46057

Merge remote-tracking branch 'upstream/master' into cancel-model-load
2 parents: ca122dc + 328b83d

32 files changed: +1955 −881 lines

.editorconfig

Lines changed: 3 additions & 0 deletions
@@ -23,3 +23,6 @@ insert_final_newline = unset
 
 [examples/server/public/*]
 indent_size = 2
+
+[examples/llama.swiftui/llama.swiftui.xcodeproj/*]
+indent_style = tab

CMakeLists.txt

Lines changed: 6 additions & 1 deletion
@@ -291,7 +291,12 @@ if (LLAMA_CUBLAS)
         add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
 
         if (LLAMA_STATIC)
-            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+            if (WIN32)
+                # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
+                set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+            else ()
+                set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+            endif()
         else()
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
         endif()

Makefile

Lines changed: 9 additions & 3 deletions
@@ -441,9 +441,15 @@ ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
 endif # LLAMA_CLBLAST
 
 ifdef LLAMA_HIPBLAS
-    ROCM_PATH   ?= /opt/rocm
-    HIPCC       ?= $(ROCM_PATH)/bin/hipcc
-    GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
+
+    ifeq ($(wildcard /opt/rocm),)
+        ROCM_PATH   ?= /usr
+        GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
+    else
+        ROCM_PATH   ?= /opt/rocm
+        GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
+    endif
+    HIPCC ?= $(ROCM_PATH)/bin/hipcc
     LLAMA_CUDA_DMMV_X       ?= 32
     LLAMA_CUDA_MMV_Y        ?= 1
     LLAMA_CUDA_KQUANTS_ITER ?= 2
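
In plain terms, the HIP build now prefers /opt/rocm when it exists and otherwise assumes a distro-packaged ROCm under /usr, locating amdgpu-arch via PATH. A small Python sketch of the equivalent detection, purely to spell out the fallback (tool names and paths taken from the Makefile; everything else is an assumption, not part of the build):

# Rough sketch of the ROCm path / GPU target detection the Makefile performs.
import os
import shutil
import subprocess

if os.path.exists("/opt/rocm"):
    rocm_path = "/opt/rocm"
    amdgpu_arch = os.path.join(rocm_path, "llvm", "bin", "amdgpu-arch")
else:
    rocm_path = "/usr"                          # distro packages install ROCm under /usr
    amdgpu_arch = shutil.which("amdgpu-arch")   # fall back to whatever is on PATH

hipcc = os.path.join(rocm_path, "bin", "hipcc")
gpu_targets = []
if amdgpu_arch and os.path.exists(amdgpu_arch):
    out = subprocess.run([amdgpu_arch], capture_output=True, text=True)
    gpu_targets = out.stdout.split()            # e.g. ['gfx1030'] on a supported system

print(f"HIPCC={hipcc} GPU_TARGETS={' '.join(gpu_targets)}")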

README.md

Lines changed: 3 additions & 3 deletions
@@ -10,11 +10,11 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 ### Hot topics
 
+- Collecting Apple Silicon performance stats:
+  - M-series: https://github.com/ggerganov/llama.cpp/discussions/4167
+  - A-series: https://github.com/ggerganov/llama.cpp/discussions/4508
 - Added Mixtral support: https://github.com/ggerganov/llama.cpp/pull/4406
-- **llama.h API change for handling KV cache offloading and data type: https://github.com/ggerganov/llama.cpp/pull/4309**
-- Using `llama.cpp` with AWS instances: https://github.com/ggerganov/llama.cpp/discussions/4225
 - Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216
-- Collecting Apple Silicon performance stats: https://github.com/ggerganov/llama.cpp/discussions/4167
 
 ----
 

common/train.cpp

Lines changed: 10 additions & 8 deletions
@@ -71,7 +71,7 @@ void free_random_uniform_distribution(struct random_uniform_distribution * rnd)
 
 struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
     float scale = 1.0f; // xavier
-    switch (tensor->n_dims) {
+    switch (ggml_n_dims(tensor)) {
         case 1:
             scale /= sqrtf((float) tensor->ne[0]);
             for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
@@ -119,7 +119,7 @@ struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct
 }
 
 struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
-    switch (tensor->n_dims) {
+    switch (ggml_n_dims(tensor)) {
         case 1:
             for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                 float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
@@ -183,25 +183,27 @@ float fclamp(const float v, const float min, const float max) {
 }
 
 void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
-    GGML_ASSERT(tensor->n_dims == 1);
     GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == 1);
+    GGML_ASSERT(tensor->ne[2] == 1);
+    GGML_ASSERT(tensor->ne[3] == 1);
 }
 
 void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
-    GGML_ASSERT(tensor->n_dims == 2);
     GGML_ASSERT(tensor->ne[0] == ne0);
     GGML_ASSERT(tensor->ne[1] == ne1);
+    GGML_ASSERT(tensor->ne[2] == 1);
+    GGML_ASSERT(tensor->ne[3] == 1);
 }
 
 void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
-    GGML_ASSERT(tensor->n_dims == 3);
     GGML_ASSERT(tensor->ne[0] == ne0);
     GGML_ASSERT(tensor->ne[1] == ne1);
     GGML_ASSERT(tensor->ne[2] == ne2);
+    GGML_ASSERT(tensor->ne[3] == 1);
 }
 
 void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
-    GGML_ASSERT(tensor->n_dims == 4);
     GGML_ASSERT(tensor->ne[0] == ne0);
     GGML_ASSERT(tensor->ne[1] == ne1);
     GGML_ASSERT(tensor->ne[2] == ne2);
@@ -225,8 +227,8 @@ int64_t get_example_targets_batch(
     bool sample_random_offsets
 ) {
     GGML_ASSERT(samples_count > 0);
-    GGML_ASSERT(tokens_input->n_dims == 2);
-    GGML_ASSERT(target_probs->n_dims == 3);
+    GGML_ASSERT(ggml_is_matrix(tokens_input));
+    GGML_ASSERT(ggml_is_3d(target_probs));
     int64_t n_vocab  = target_probs->ne[0];
     int64_t n_tokens = tokens_input->ne[0];
     int64_t n_batch  = tokens_input->ne[1];
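
These hunks track the removal of the per-tensor n_dims field: dimensionality is now derived from the ne[] extents via ggml_n_dims / ggml_is_matrix / ggml_is_vector / ggml_is_3d, which is also why the assert_shape_* helpers now pin the trailing extents to 1. A minimal Python sketch of the assumed semantics (an illustration over the four extents, not the ggml C implementation):

# Sketch of the assumed semantics of the ggml shape helpers used above,
# modeled on a 4-entry list of extents ne = [ne0, ne1, ne2, ne3].
GGML_MAX_DIMS = 4

def n_dims(ne):
    # Rank is inferred from the highest axis whose extent exceeds 1.
    for i in range(GGML_MAX_DIMS - 1, 0, -1):
        if ne[i] > 1:
            return i + 1
    return 1

def is_vector(ne): return ne[1] == 1 and ne[2] == 1 and ne[3] == 1
def is_matrix(ne): return ne[2] == 1 and ne[3] == 1
def is_3d(ne):     return ne[3] == 1

assert n_dims([8, 1, 1, 1]) == 1 and is_vector([8, 1, 1, 1])
assert n_dims([8, 4, 1, 1]) == 2 and is_matrix([8, 4, 1, 1])
assert n_dims([8, 4, 2, 1]) == 3 and is_3d([8, 4, 2, 1])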

convert-hf-to-gguf.py

Lines changed: 22 additions & 0 deletions
@@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
             return QwenModel
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
+        if model_architecture == "PhiForCausalLM":
+            return Phi2Model
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -221,6 +223,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.QWEN
         if arch == "MixtralForCausalLM":
             return gguf.MODEL_ARCH.LLAMA
+        if arch == "PhiForCausalLM":
+            return gguf.MODEL_ARCH.PHI2
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -980,6 +984,24 @@ def write_tensors(self):
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
+
+class Phi2Model(Model):
+    def set_gguf_parameters(self):
+        block_count = self.hparams["n_layer"]
+
+        self.gguf_writer.add_name("Phi2")
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_add_bos_token(False)
+
+
 ###### CONVERSION LOGIC ######
 
 
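
For context, the new Phi2Model maps the Hugging Face Phi hyperparameters (n_layer, n_positions, n_embd, n_head, layer_norm_epsilon, rotary_dim) onto GGUF metadata, deriving the feed-forward width as 4 * n_embd since the checkpoint config does not store it. Below is a standalone sketch of that same mapping, assuming the repo's gguf-py package is importable and that GGUFWriter takes a path and an architecture name; the hparams values are illustrative placeholders, not a real config.json:

# Standalone sketch of the hparams -> GGUF metadata mapping done by Phi2Model.
# The hparams below are made-up placeholder values for illustration only.
import gguf

hparams = {
    "n_layer": 32,
    "n_positions": 2048,
    "n_embd": 2560,
    "n_head": 32,
    "layer_norm_epsilon": 1e-5,
    "rotary_dim": 32,
}

w = gguf.GGUFWriter("phi2-metadata-sketch.gguf", gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.PHI2])
w.add_name("Phi2")
w.add_context_length(hparams["n_positions"])
w.add_embedding_length(hparams["n_embd"])
w.add_feed_forward_length(4 * hparams["n_embd"])   # MLP width is not in the config, so it is derived
w.add_block_count(hparams["n_layer"])
w.add_head_count(hparams["n_head"])
w.add_head_count_kv(hparams["n_head"])             # plain multi-head attention: kv heads == query heads
w.add_layer_norm_eps(hparams["layer_norm_epsilon"])
w.add_rope_dimension_count(hparams["rotary_dim"])  # partial RoPE: only rotary_dim dimensions are rotated
w.add_add_bos_token(False)                         # Phi-2 prompts are not prefixed with a BOS token
# (tensor data and the actual file write are omitted in this sketch)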

convert-lora-to-ggml.py

Lines changed: 44 additions & 40 deletions
@@ -3,49 +3,20 @@
 
 import json
 import os
-import re
 import struct
 import sys
 from typing import Any, BinaryIO, Sequence
 
 import numpy as np
 import torch
 
-NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
-
+from pathlib import Path
+if 'NO_LOCAL_GGUF' not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+import gguf
 
-HF_SUBLAYER_TO_GGML = {
-    "self_attn.q_proj": "attn_q",
-    "self_attn.k_proj": "attn_k",
-    "self_attn.v_proj": "attn_v",
-    "self_attn.o_proj": "attn_output",
-    "mlp.gate_proj": "ffn_gate",
-    "mlp.down_proj": "ffn_down",
-    "mlp.up_proj": "ffn_up",
-    "input_layernorm": "attn_norm",
-    "post_attention_layernorm": "ffn_norm",
-}
-
-
-def translate_tensor_name(t: str) -> str:
-    match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
-    if match:
-        nn = match.group(1)
-        sub_layer = match.group(2)
-        lora_type = match.group(3)
-
-        sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
-        if sub_layer_renamed is None:
-            print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
-            sys.exit(1)
 
-        output_string = (
-            f"blk.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
-        )
-        return output_string
-    else:
-        print(f"Error: unrecognized tensor {t}")
-        sys.exit(1)
+NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
 
 
 def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
@@ -61,9 +32,7 @@ def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
     fout.write(struct.pack("i", int(params["lora_alpha"])))
 
 
-def write_tensor_header(
-    self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
-) -> None:
+def write_tensor_header(fout: BinaryIO, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
     sname = name.encode("utf-8")
     fout.write(
         struct.pack(
@@ -78,18 +47,27 @@ def write_tensor_header(
     fout.seek((fout.tell() + 31) & -32)
 
 
-if len(sys.argv) != 2:
-    print(f"Usage: python {sys.argv[0]} <path>")
+if len(sys.argv) < 2:
+    print(f"Usage: python {sys.argv[0]} <path> [arch]")
     print(
         "Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
    )
+    print(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
    sys.exit(1)
 
 input_json = os.path.join(sys.argv[1], "adapter_config.json")
 input_model = os.path.join(sys.argv[1], "adapter_model.bin")
 output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
 
 model = torch.load(input_model, map_location="cpu")
+arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"
+
+if arch_name not in gguf.MODEL_ARCH_NAMES.values():
+    print(f"Error: unsupported architecture {arch_name}")
+    sys.exit(1)
+
+arch = list(gguf.MODEL_ARCH_NAMES.keys())[list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)]
+name_map = gguf.TensorNameMap(arch, 200) # 200 layers ought to be enough for anyone
 
 with open(input_json, "r") as f:
     params = json.load(f)
@@ -117,6 +95,7 @@ def write_tensor_header(
 
 write_file_header(fout, params)
 for k, v in model.items():
+    orig_k = k
     if k.endswith(".default.weight"):
         k = k.replace(".default.weight", ".weight")
     if k in ["llama_proj.weight", "llama_proj.bias"]:
@@ -129,7 +108,32 @@ def write_tensor_header(
         v = v.float()
 
     t = v.detach().numpy()
-    tname = translate_tensor_name(k)
+
+    prefix = "base_model.model."
+    if k.startswith(prefix):
+        k = k[len(prefix) :]
+
+    lora_suffixes = (".lora_A.weight", ".lora_B.weight")
+    if k.endswith(lora_suffixes):
+        suffix = k[-len(lora_suffixes[0]):]
+        k = k[: -len(lora_suffixes[0])]
+    else:
+        print(f"Error: unrecognized tensor name {orig_k}")
+        sys.exit(1)
+
+    tname = name_map.get_name(k)
+    if tname is None:
+        print(f"Error: could not map tensor name {orig_k}")
+        print(" Note: the arch parameter must be specified if the model is not llama")
+        sys.exit(1)
+
+    if suffix == ".lora_A.weight":
+        tname += ".weight.loraA"
+    elif suffix == ".lora_B.weight":
+        tname += ".weight.loraB"
+    else:
+        assert False
+
     print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
     write_tensor_header(fout, tname, t.shape, t.dtype)
     t.tofile(fout)
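
The rewritten loop replaces the hard-coded HF_SUBLAYER_TO_GGML table with gguf.TensorNameMap, so the script can translate adapters for any architecture known to gguf-py rather than llama only. A rough sketch of the translation for a single PEFT key, mirroring the logic above; the key is a made-up example and the results in the comments are what the mapping is expected to produce for a llama-architecture adapter:

# Sketch: translating one PEFT LoRA tensor name the way the new loop does.
import gguf

name_map = gguf.TensorNameMap(gguf.MODEL_ARCH.LLAMA, 200)

k = "base_model.model.model.layers.7.self_attn.q_proj.lora_A.weight"  # hypothetical PEFT key

prefix = "base_model.model."
if k.startswith(prefix):
    k = k[len(prefix):]                      # model.layers.7.self_attn.q_proj.lora_A.weight

lora_suffixes = (".lora_A.weight", ".lora_B.weight")
assert k.endswith(lora_suffixes)
suffix = k[-len(lora_suffixes[0]):]          # .lora_A.weight
k = k[:-len(lora_suffixes[0])]               # model.layers.7.self_attn.q_proj

tname = name_map.get_name(k)                 # expected: blk.7.attn_q (None if unmapped)
assert tname is not None
tname += ".weight.loraA" if suffix == ".lora_A.weight" else ".weight.loraB"
print(tname)                                 # expected: blk.7.attn_q.weight.loraA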

examples/baby-llama/baby-llama.cpp

Lines changed: 9 additions & 9 deletions
@@ -1258,9 +1258,9 @@ static struct ggml_tensor * forward_lora(
 }
 
 static void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) {
-    assert(logits->n_dims == 2);
-    assert(probs->n_dims == 2);
-    assert(best_samples->n_dims == 1);
+    assert(ggml_is_matrix(logits));
+    assert(ggml_is_matrix(probs));
+    assert(ggml_is_vector(best_samples));
     assert(logits->ne[1] == best_samples->ne[0]);
     assert(logits->ne[0] == probs->ne[0]);
     assert(logits->ne[1] == probs->ne[1]);
@@ -1292,9 +1292,9 @@ static void sample_softmax_batch(
     struct ggml_context * ctx, struct ggml_tensor * logits, struct ggml_tensor * probs,
     struct ggml_tensor * best_samples
 ) {
-    GGML_ASSERT(best_samples->n_dims == 2);
-    GGML_ASSERT(logits->n_dims == 3);
-    GGML_ASSERT(probs->n_dims == 3);
+    GGML_ASSERT(ggml_is_matrix(best_samples));
+    GGML_ASSERT(ggml_is_3d(logits));
+    GGML_ASSERT(ggml_is_3d(probs));
     int n_tokens = best_samples->ne[0];
     int n_batch  = best_samples->ne[1];
     int n_vocab  = logits->ne[0];
@@ -1334,7 +1334,7 @@ static void print_row(struct ggml_tensor * probs, int i) {
 }
 
 static void print_matrix(struct ggml_tensor * probs) {
-    assert(probs->n_dims == 2);
+    assert(ggml_is_matrix(probs));
     for (int i = 0; i < probs->ne[1]; ++i) {
         for (int k = 0; k < probs->ne[0]; ++k) {
             float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
@@ -1386,8 +1386,8 @@ static void get_example_targets(int example_id, struct ggml_tensor * tokens_inpu
 static void get_example_targets_batch(
     struct ggml_context * ctx, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets
 ) {
-    GGML_ASSERT(tokens_input->n_dims == 2);
-    GGML_ASSERT( targets->n_dims == 3);
+    GGML_ASSERT(ggml_is_matrix(tokens_input));
+    GGML_ASSERT(ggml_is_3d(targets));
     int n_tokens = tokens_input->ne[0];
     int n_batch  = tokens_input->ne[1];
     GGML_ASSERT(n_tokens == targets->ne[1]);

examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

Lines changed: 2 additions & 2 deletions
@@ -427,7 +427,7 @@ static void print_row(struct ggml_tensor * probs, int i) {
 }
 
 static void print_matrix(struct ggml_tensor * probs) {
-    assert(probs->n_dims == 2);
+    assert(ggml_is_matrix(probs));
     for (int i = 0; i < probs->ne[1]; ++i) {
         for (int k = 0; k < probs->ne[0]; ++k) {
             float p = get_f32_2d(probs, k, i);
@@ -639,7 +639,7 @@ static void load_vocab(const char *filename, Config *config, struct llama_vocab
 
 static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) {
     int ct;
-    switch (gg_weights->n_dims){
+    switch (ggml_n_dims(gg_weights)) {
     case 1:
         ct = 0;
         for (int i0 = 0; i0 < gg_weights->ne[0]; i0++){
