Commit e76d630

llama : grouped-query attention + LLaMAv2 70B support (#2276)
* CUDA: GQA implementation
* llama : support for GQA and LLaMAv2 70B
  ggml-ci
* py : fix hparams parsing (if-else blocks)
  ggml-ci
* py : oh boy ..
  ggml-ci
* help : fix gqa value for 70B
  ggml-ci
---------
Co-authored-by: JohannesGaessler <[email protected]>
1 parent 1d0824b commit e76d630

7 files changed: +215 −108 lines


convert.py

Lines changed: 45 additions & 21 deletions
@@ -142,9 +142,9 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 @dataclass
 class Params:
     n_vocab: int
-    n_embd: int
-    n_mult: int
-    n_head: int
+    n_embd:  int
+    n_mult:  int
+    n_head:  int
     n_layer: int

     @staticmethod
@@ -167,40 +167,65 @@ def guessed(model: 'LazyModel') -> 'Params':
         n_head=n_embd // 128 # guessed

         return Params(
-            n_vocab=n_vocab,
-            n_embd=n_embd,
-            n_mult=256,
-            n_head=n_head,
-            n_layer=n_layer,
+            n_vocab = n_vocab,
+            n_embd  = n_embd,
+            n_mult  = 256,
+            n_head  = n_head,
+            n_layer = n_layer,
         )

     @staticmethod
     def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))

         n_vocab = config["vocab_size"];
-        n_embd = config["hidden_size"];
-        n_head = config["num_attention_heads"];
+        n_embd  = config["hidden_size"];
+        n_head  = config["num_attention_heads"];
         n_layer = config["num_hidden_layers"];
-        n_ff = config["intermediate_size"];
+        n_ff    = config["intermediate_size"];

         n_mult = find_n_mult(n_ff, n_embd);

         return Params(
-            n_vocab=n_vocab,
-            n_embd=n_embd,
-            n_mult=n_mult,
-            n_head=n_head,
-            n_layer=n_layer,
+            n_vocab = n_vocab,
+            n_embd  = n_embd,
+            n_mult  = n_mult,
+            n_head  = n_head,
+            n_layer = n_layer,
+        )
+
+    # LLaMA v2 70B params.json
+    # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
+    @staticmethod
+    def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+        config = json.load(open(config_path))
+
+        n_vocab = config["vocab_size"];
+        n_embd  = config["dim"];
+        n_head  = config["n_heads"];
+        n_layer = config["n_layers"];
+        n_mult  = config["multiple_of"];
+
+        if n_vocab == -1:
+            n_vocab = model["tok_embeddings.weight"].shape[0]
+
+        return Params(
+            n_vocab = n_vocab,
+            n_embd  = n_embd,
+            n_mult  = n_mult,
+            n_head  = n_head,
+            n_layer = n_layer,
         )

     @staticmethod
     def load(model_plus: 'ModelPlus') -> 'Params':
+        hf_config_path   = model_plus.paths[0].parent / "config.json"
         orig_config_path = model_plus.paths[0].parent / "params.json"
-        hf_transformer_config_path = model_plus.paths[0].parent / "config.json"

-        if hf_transformer_config_path.exists():
-            params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+        if hf_config_path.exists():
+            params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
+        elif orig_config_path.exists():
+            params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
         else:
             params = Params.guessed(model_plus.model)

@@ -1036,8 +1061,7 @@ def write_vocab(self, vocab: Vocab) -> None:
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
-                        n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)

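Note that the new loadOriginalParamsJson reads dim, n_heads, n_layers, multiple_of and vocab_size, but not the n_kv_heads value visible in the 70B params.json comment above, so the converted file does not carry the grouped-query attention factor; it has to be supplied at run time with -gqa. A minimal, hypothetical sketch of how that factor relates to the same params.json keys (the helper name is illustrative and not part of this commit):

import json
from pathlib import Path

def gqa_factor(params_path: Path) -> int:
    # Illustrative helper only: this commit's convert.py does not read
    # n_kv_heads, so the factor is passed on the command line instead.
    config = json.load(open(params_path))
    n_head    = config["n_heads"]                  # 64 for LLaMA v2 70B
    n_kv_head = config.get("n_kv_heads", n_head)   # 8 for 70B (see comment above)
    assert n_head % n_kv_head == 0
    return n_head // n_kv_head                     # 64 // 8 = 8  ->  -gqa 8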
examples/common.cpp

Lines changed: 10 additions & 2 deletions
@@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_gqa = std::stoi(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " -f FNAME, --file FNAME\n");
     fprintf(stdout, " prompt file to start generation.\n");
     fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+    fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
     fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -505,15 +514,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
     fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
     fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stdout, " --no-penalize-nl do not penalize newline token\n");
     fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n");
     fprintf(stdout, " --temp N temperature (default: %.1f)\n", (double)params.temp);
-    fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, " --perplexity compute perplexity over each ctx window of the prompt\n");
     fprintf(stdout, " --perplexity-lines compute perplexity over each line of the prompt\n");
     fprintf(stdout, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
@@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

     lparams.n_ctx = params.n_ctx;
     lparams.n_batch = params.n_batch;
+    lparams.n_gqa = params.n_gqa;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu = params.main_gpu;
     lparams.tensor_split = params.tensor_split;

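The -gqa value parsed above is copied into llama_context_params.n_gqa, i.e. the number of query heads that share a single key/value head. A rough conceptual sketch of that grouping (plain Python for illustration only; the actual attention is built in the ggml/CUDA graph, not in a helper like this):

def kv_head_for_query_head(q_head: int, n_gqa: int) -> int:
    # Conceptual mapping only: with a GQA factor of n_gqa, consecutive
    # blocks of n_gqa query heads read the same key/value head.
    return q_head // n_gqa

n_head    = 64                 # LLaMA v2 70B query heads
n_gqa     = 8                  # value passed with -gqa 8
n_kv_head = n_head // n_gqa    # 8 key/value heads

# Query heads 0..7 share KV head 0, heads 8..15 share KV head 1, and so on.
groups = {kv: [q for q in range(n_head) if kv_head_for_query_head(q, n_gqa) == kv]
          for kv in range(n_kv_head)}

# With the default n_gqa = 1 every query head keeps its own KV head,
# which is ordinary multi-head attention.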
examples/common.h

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@ struct gpt_params {
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 512; // context size
     int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0; // number of layers to store in VRAM
@@ -47,7 +48,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float frequency_penalty = 0.00f; // 0.0 = disabled
     float presence_penalty = 0.00f; // 0.0 = disabled
-    int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate

examples/main/main.cpp

Lines changed: 2 additions & 2 deletions
@@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
     }

     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
-                " you are on your own\n", __func__, params.n_ctx);
+        // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
