Skip to content

Commit 30cd674

Browse files
nopperl authored and pull[bot] committed
Implement the OLMo architecture (#6741)
* implement olmo architecture * remove unused variable * remove unused moe branch * remove check for weight * remove superfluous moe, bias and rope tensors * clarified comment * fix clamp_kqv setting * remove obsolete parameter name filter
1 parent 86db56c commit 30cd674

File tree

4 files changed

+271
-0
lines changed

4 files changed

+271
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ Typically finetunes of the base models below are supported as well.
122122
- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
123123
- [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
124124
- [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
125+
- [x] [OLMo](https://allenai.org/olmo)
125126

126127
(instructions for supporting more models: [HOWTO-add-model.md](./docs/HOWTO-add-model.md))
127128

convert-hf-to-gguf.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2636,6 +2636,66 @@ def set_gguf_parameters(self):
26362636
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
26372637

26382638

2639+
@Model.register("OlmoForCausalLM")
@Model.register("OLMoForCausalLM")
class OlmoModel(Model):
    """Converter for OLMo checkpoints; registered under both class-name spellings."""

    model_arch = gguf.MODEL_ARCH.OLMO

    def set_gguf_parameters(self):
        """Write the base hyperparameters, a fixed layer-norm epsilon of 1e-5 and,
        when configured, the QKV clamp value."""
        super().set_gguf_parameters()
        self.gguf_writer.add_layer_norm_eps(1e-5)
        # clip_qkv may be missing from the config or present but null; only emit
        # the clamp when it carries an actual value.
        clip_qkv = self.hparams.get("clip_qkv")
        if clip_qkv is not None:
            self.gguf_writer.add_clamp_kqv(clip_qkv)

    # Same as super class, but permuting q_proj, k_proj
    # Copied from: LlamaModel
    def write_tensors(self):
        """Convert every tensor yielded by ``get_tensors()`` and add it to the GGUF writer.

        For each tensor: cast anything that is not float16/float32 to float32,
        permute the q_proj/k_proj weights, squeeze, map the name to its GGUF
        name, then apply the dtype rules selected by ``self.ftype``.
        Exits the process if a tensor name cannot be mapped.
        """
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        n_head = self.hparams.get("num_attention_heads")
        n_kv_head = self.hparams.get("num_key_value_heads")
        for name, data_torch in self.get_tensors():
            # remembered only for the progress message below
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.numpy()

            if name.endswith("q_proj.weight"):
                data = permute(data, n_head, n_head)
            if name.endswith("k_proj.weight"):
                data = permute(data, n_head, n_kv_head)

            data = data.squeeze()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            # captured before any conversion so the checks below all see the
            # dtype the tensor had on entry to this section
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # 1d tensors need to be converted to float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)
2697+
2698+
26392699
###### CONVERSION LOGIC ######
26402700

26412701

gguf-py/gguf/constants.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum):
135135
XVERSE = auto()
136136
COMMAND_R = auto()
137137
DBRX = auto()
138+
OLMO = auto()
138139

139140

140141
class MODEL_TENSOR(IntEnum):
@@ -210,6 +211,7 @@ class MODEL_TENSOR(IntEnum):
210211
MODEL_ARCH.XVERSE: "xverse",
211212
MODEL_ARCH.COMMAND_R: "command-r",
212213
MODEL_ARCH.DBRX: "dbrx",
214+
MODEL_ARCH.OLMO: "olmo",
213215
}
214216

215217
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -695,6 +697,17 @@ class MODEL_TENSOR(IntEnum):
695697
MODEL_TENSOR.FFN_DOWN_EXP,
696698
MODEL_TENSOR.FFN_UP_EXP,
697699
],
700+
MODEL_ARCH.OLMO: [
701+
MODEL_TENSOR.TOKEN_EMBD,
702+
MODEL_TENSOR.OUTPUT,
703+
MODEL_TENSOR.ATTN_Q,
704+
MODEL_TENSOR.ATTN_K,
705+
MODEL_TENSOR.ATTN_V,
706+
MODEL_TENSOR.ATTN_OUT,
707+
MODEL_TENSOR.FFN_GATE,
708+
MODEL_TENSOR.FFN_DOWN,
709+
MODEL_TENSOR.FFN_UP,
710+
],
698711
# TODO
699712
}
700713

llama.cpp

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ enum llm_arch {
222222
LLM_ARCH_XVERSE,
223223
LLM_ARCH_COMMAND_R,
224224
LLM_ARCH_DBRX,
225+
LLM_ARCH_OLMO,
225226
LLM_ARCH_UNKNOWN,
226227
};
227228

@@ -256,6 +257,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
256257
{ LLM_ARCH_XVERSE, "xverse" },
257258
{ LLM_ARCH_COMMAND_R, "command-r" },
258259
{ LLM_ARCH_DBRX, "dbrx" },
260+
{ LLM_ARCH_OLMO, "olmo" },
259261
{ LLM_ARCH_UNKNOWN, "(unknown)" },
260262
};
261263

@@ -990,6 +992,20 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
990992
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
991993
},
992994
},
995+
{
996+
LLM_ARCH_OLMO,
997+
{
998+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
999+
{ LLM_TENSOR_OUTPUT, "output" },
1000+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1001+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1002+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1003+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1004+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1005+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1006+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1007+
},
1008+
},
9931009
{
9941010
LLM_ARCH_UNKNOWN,
9951011
{
@@ -4070,6 +4086,18 @@ static void llm_load_hparams(
40704086
default: model.type = e_model::MODEL_UNKNOWN;
40714087
}
40724088
} break;
4089+
case LLM_ARCH_OLMO:
4090+
{
4091+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
4092+
ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false);
4093+
4094+
switch (hparams.n_layer) {
4095+
case 22: model.type = e_model::MODEL_1B; break;
4096+
case 32: model.type = e_model::MODEL_7B; break;
4097+
case 80: model.type = e_model::MODEL_70B; break;
4098+
default: model.type = e_model::MODEL_UNKNOWN;
4099+
}
4100+
} break;
40734101
default: (void)0;
40744102
}
40754103

@@ -5666,6 +5694,37 @@ static bool llm_load_tensors(
56665694
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
56675695
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
56685696

5697+
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
5698+
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
5699+
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
5700+
}
5701+
} break;
5702+
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
5703+
{
5704+
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
5705+
5706+
// output
5707+
{
5708+
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
5709+
// if output is NULL, init from the input tok embed
5710+
if (model.output == NULL) {
5711+
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
5712+
ml.n_created--; // artificial tensor
5713+
ml.size_data += ggml_nbytes(model.output);
5714+
}
5715+
}
5716+
5717+
for (int i = 0; i < n_layer; ++i) {
5718+
ggml_context * ctx_split = ctx_for_layer_split(i);
5719+
5720+
auto & layer = model.layers[i];
5721+
5722+
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
5723+
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
5724+
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
5725+
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
5726+
5727+
56695728
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
56705729
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
56715730
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
@@ -10096,6 +10155,139 @@ struct llm_build_context {
1009610155
return gf;
1009710156

1009810157
}
10158+
10159+
// ref: https://allenai.org/olmo
10160+
// based on the original build_llama() function, changes:
10161+
// * non-parametric layer norm
10162+
// * clamp qkv
10163+
// * removed bias
10164+
// * removed MoE
10165+
struct ggml_cgraph * build_olmo() {
10166+
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
10167+
10168+
// mutable variable, needed during the last layer of the computation to skip unused tokens
10169+
int32_t n_tokens = this->n_tokens;
10170+
10171+
const int64_t n_embd_head = hparams.n_embd_head_v;
10172+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
10173+
GGML_ASSERT(n_embd_head == hparams.n_rot);
10174+
10175+
struct ggml_tensor * cur;
10176+
struct ggml_tensor * inpL;
10177+
10178+
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
10179+
10180+
// inp_pos - contains the positions
10181+
struct ggml_tensor * inp_pos = build_inp_pos();
10182+
10183+
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
10184+
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
10185+
10186+
for (int il = 0; il < n_layer; ++il) {
10187+
struct ggml_tensor * inpSA = inpL;
10188+
10189+
// norm
10190+
cur = llm_build_norm(ctx0, inpL, hparams,
10191+
NULL, NULL,
10192+
LLM_NORM, cb, il);
10193+
cb(cur, "attn_norm", il);
10194+
10195+
// self-attention
10196+
{
10197+
// compute Q and K and RoPE them
10198+
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
10199+
cb(Qcur, "Qcur", il);
10200+
if (hparams.f_clamp_kqv > 0.0f) {
10201+
Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
10202+
cb(Qcur, "Qcur", il);
10203+
}
10204+
10205+
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
10206+
cb(Kcur, "Kcur", il);
10207+
if (hparams.f_clamp_kqv > 0.0f) {
10208+
Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
10209+
cb(Kcur, "Kcur", il);
10210+
}
10211+
10212+
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
10213+
cb(Vcur, "Vcur", il);
10214+
if (hparams.f_clamp_kqv > 0.0f) {
10215+
Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
10216+
cb(Vcur, "Vcur", il);
10217+
}
10218+
10219+
Qcur = ggml_rope_custom(
10220+
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
10221+
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
10222+
ext_factor, attn_factor, beta_fast, beta_slow
10223+
);
10224+
cb(Qcur, "Qcur", il);
10225+
10226+
Kcur = ggml_rope_custom(
10227+
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
10228+
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
10229+
ext_factor, attn_factor, beta_fast, beta_slow
10230+
);
10231+
cb(Kcur, "Kcur", il);
10232+
10233+
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
10234+
model.layers[il].wo, nullptr,
10235+
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
10236+
}
10237+
10238+
if (il == n_layer - 1) {
10239+
// skip computing output for unused tokens
10240+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
10241+
n_tokens = n_outputs;
10242+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10243+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
10244+
}
10245+
10246+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
10247+
cb(ffn_inp, "ffn_inp", il);
10248+
10249+
// feed-forward network
10250+
cur = llm_build_norm(ctx0, ffn_inp, hparams,
10251+
NULL, NULL,
10252+
LLM_NORM, cb, il);
10253+
cb(cur, "ffn_norm", il);
10254+
10255+
cur = llm_build_ffn(ctx0, cur,
10256+
model.layers[il].ffn_up, NULL,
10257+
model.layers[il].ffn_gate, NULL,
10258+
model.layers[il].ffn_down, NULL,
10259+
NULL,
10260+
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
10261+
cb(cur, "ffn_out", il);
10262+
10263+
cur = ggml_add(ctx0, cur, ffn_inp);
10264+
cb(cur, "ffn_out", il);
10265+
10266+
ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
10267+
if (layer_dir != nullptr) {
10268+
cur = ggml_add(ctx0, cur, layer_dir);
10269+
}
10270+
cb(cur, "l_out", il);
10271+
10272+
// input for next layer
10273+
inpL = cur;
10274+
}
10275+
10276+
cur = inpL;
10277+
10278+
cur = llm_build_norm(ctx0, cur, hparams,
10279+
NULL, NULL,
10280+
LLM_NORM, cb, -1);
10281+
cb(cur, "result_norm", -1);
10282+
10283+
// lm_head
10284+
cur = ggml_mul_mat(ctx0, model.output, cur);
10285+
cb(cur, "result_output", -1);
10286+
10287+
ggml_build_forward_expand(gf, cur);
10288+
10289+
return gf;
10290+
}
1009910291
};
1010010292

1010110293
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -10301,6 +10493,10 @@ static struct ggml_cgraph * llama_build_graph(
1030110493
{
1030210494
result = llm.build_dbrx();
1030310495
} break;
10496+
case LLM_ARCH_OLMO:
10497+
{
10498+
result = llm.build_olmo();
10499+
} break;
1030410500
default:
1030510501
GGML_ASSERT(false);
1030610502
}
@@ -15154,6 +15350,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
1515415350
case LLM_ARCH_MINICPM:
1515515351
case LLM_ARCH_XVERSE:
1515615352
case LLM_ARCH_COMMAND_R:
15353+
case LLM_ARCH_OLMO:
1515715354
return LLAMA_ROPE_TYPE_NORM;
1515815355

1515915356
// the pairs of head values are offset by n_rot/2

0 commit comments

Comments
 (0)