
Commit 31a6579

gguf: Add types for LLM architectures (#640)
Resolves #566.

This PR introduces a generator script that pulls the `llama.cpp` source file, then uses regex to extract the needed information. It also addresses some smaller issues:

- Some places (for example `Attention`, `Rope`, `MOE`, ...) use a type intersection where it should be a union
- `TokenizerModel` should be limited to 4 models, see `llm_load_vocab()`: https://github.com/ggerganov/llama.cpp/blob/928e0b7013c862cf10701957b3d654aa70f11bd8/llama.cpp#L4198

While it's fun to have the script & these type definitions, I think it would be nice to have some clearer directions on how to use them in the future.

Demo: ![image](https://github.com/huggingface/huggingface.js/assets/7702203/d4d893e4-afe4-40ae-ab29-f6593b020ec6)
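As a rough illustration of what the generated types give you, here is a minimal sketch. It assumes the generated file exposes an `ArchLlama` type (matching the `tsName` examples inside `generate-llm.ts` below) and uses `Partial` to skip unrelated required keys:

import type { ArchLlama } from "./src/transformer-llm";

// Architecture-namespaced metadata keys are now typechecked:
const metadata: Partial<ArchLlama> = {
	"llama.attention.head_count": 32,
	"llama.rope.dimension_count": 128,
};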
1 parent 1cce626 commit 31a6579

File tree

4 files changed: +423 −46 lines

packages/gguf/package.json

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
 		"format:check": "prettier --check .",
 		"prepublishOnly": "pnpm run build",
 		"build": "tsup src/index.ts --format cjs,esm --clean --dts",
+		"build:llm": "tsx scripts/generate-llm.ts && pnpm run format",
 		"test": "vitest run",
 		"check": "tsc"
 	},
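With this script entry in place, regenerating the types is presumably a matter of running `pnpm run build:llm`, which runs the generator script described below (fetching `llama.cpp`, rewriting `src/transformer-llm.ts`) and then re-runs the formatter.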

packages/gguf/scripts/generate-llm.ts

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
/**
 * Script for generating llm.ts
 * The source data is taken from llama.cpp
 */

import { writeFileSync } from "node:fs";

const SOURCE_CPP_URL = "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/llama.cpp";
const DEST_FILE_PATH = "./src/transformer-llm.ts";
const DEST_COMMON_SOURCE = `
type Attention<TArchitecture extends string> =
	& { [K in \`\${TArchitecture}.attention.head_count\`]: number }
	& { [K in \`\${TArchitecture}.attention.head_count_kv\`]: number }
	& { [K in \`\${TArchitecture}.attention.layer_norm_epsilon\`]: number }
	& { [K in \`\${TArchitecture}.attention.layer_norm_rms_epsilon\`]: number }
	& { [K in \`\${TArchitecture}.attention.alibi_bias_max\`]: number }
	& { [K in \`\${TArchitecture}.attention.clip_kqv\`]: number }
	& { [K in \`\${TArchitecture}.attention.use_norm\`]: number };

type Rope<TArchitecture extends LLMArchitecture> =
	& { [K in \`\${TArchitecture}.rope.dimension_count\`]: number }
	& { [K in \`\${TArchitecture}.rope.freq_base\`]: number }
	& { [K in \`\${TArchitecture}.rope.scale\`]: number }
	& { [K in \`\${TArchitecture}.rope.scale_linear\`]: number };

type MOE<TArchitecture extends LLMArchitecture> =
	& { [K in \`\${TArchitecture}.expert_count\`]: number }
	& { [K in \`\${TArchitecture}.expert_used_count\`]: number };

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture>
	& MOE<TArchitecture>
	& Attention<TArchitecture>
	& Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
	UNSPECIFIED = -1,
	NONE = 0,
	MEAN = 1,
	CLS = 2,
};
`;
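// For reference: once transformer-llm.ts is generated, a mapped type such as
// Attention<"llama"> expands to
//   { "llama.attention.head_count": number } & { "llama.attention.head_count_kv": number } & ...
// i.e. one object type whose keys are namespaced by the architecture name.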

// Maps a llama.cpp LLM_KV_* constant to the TypeScript type of its value.
// The index signature lets KV_TYPE[k] below typecheck for an arbitrary string key.
const KV_TYPE: Record<string, string> = {
	LLM_KV_ATTENTION_LAYERNORM_RMS_EPS: "number",
	LLM_KV_ATTENTION_LAYERNORM_EPS: "number",
	LLM_KV_ATTENTION_CAUSAL: "boolean",
	LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT: "number",
	LLM_KV_POOLING_TYPE: "TransformerLLMPoolingType",
	LLM_KV_ATTENTION_CLAMP_KQV: "number",
	LLM_KV_ATTENTION_MAX_ALIBI_BIAS: "number",
	LLM_KV_SSM_CONV_KERNEL: "number",
	LLM_KV_SSM_INNER_SIZE: "number",
	LLM_KV_SSM_STATE_SIZE: "number",
	LLM_KV_SSM_TIME_STEP_RANK: "number",
	LLM_KV_LOGIT_SCALE: "number",
};

interface Arch {
	cppConst: string; // for example: "LLM_ARCH_LLAMA"
	name: string; // for example: "llama"
	tsName: string; // for example: "ArchLlama"
	tensorNames: string[]; // for example: "token_embd"
	hparams: string[]; // list of LLM_KV_* constants used by this architecture
}

async function main() {
	const res = await fetch(SOURCE_CPP_URL);
	const cppSource = await res.text();

	/////////////////////////////////////
	// extract list of all architectures
	const archList: Arch[] = [];
	const RE_ARCH_NAME = /LLM_ARCH_[A-Z0-9_]+/;
	const matchedArchList = cppSource.match(/LLM_ARCH_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
	if (!matchedArchList?.length) {
		throw new Error("LLM_ARCH_NAMES is empty");
	}
	for (const line of matchedArchList) {
		const matched = line.match(/(?<cppConst>LLM_ARCH_[A-Z0-9_]+),\s+"(?<name>.+?)"/);
		if (matched?.groups && !matched.groups.name.match(/unknown/)) {
			archList.push({
				cppConst: matched.groups.cppConst,
				name: matched.groups.name,
				tsName: snakeToPascal(matched.groups.cppConst.replace("LLM_", "")),
				tensorNames: [],
				hparams: [],
			});
		}
	}

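	// At this point, each element of archList looks like:
	//   { cppConst: "LLM_ARCH_LLAMA", name: "llama", tsName: "ArchLlama", tensorNames: [], hparams: [] }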
	/////////////////////////////////////
	// extract map constant name to kv name
	// for example: LLM_KV_ATTENTION_LAYERNORM_RMS_EPS ==> "%s.attention.layer_norm_rms_epsilon"
	const constToKVName: { [cppConst: string]: string } = {};
	const matchedKVList = cppSource.match(/LLM_KV_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
	if (!matchedKVList?.length) {
		throw new Error("LLM_KV_NAMES is empty");
	}
	for (const line of matchedKVList) {
		const matched = line.match(/(?<cppConst>LLM_KV_[A-Z0-9_]+)[,\s]+"(?<name>.+?)"/);
		if (matched?.groups) {
			constToKVName[matched.groups.cppConst] = matched.groups.name;
		}
	}

	/////////////////////////////////////
	// extract list of tensor names based on architecture
	// TODO: unused for now
	const matchedTensorList = cppSource.match(/LLM_TENSOR_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
	if (!matchedTensorList?.length) {
		throw new Error("LLM_TENSOR_NAMES is empty");
	}
	let currCppConst = "";
	for (const line of matchedTensorList) {
		// check if current line has LLM_ARCH_*
		const cppConst = line.match(RE_ARCH_NAME)?.[0];
		if (cppConst) {
			currCppConst = cppConst;
			continue;
		}
		// check if current line has LLM_TENSOR_*
		const tensorMatched = line.match(/LLM_TENSOR_[A-Z0-9_]+[,\s]+"(?<name>.+?)"/);
		if (tensorMatched?.groups) {
			const arch = archList.find((a) => a.cppConst === currCppConst);
			if (arch) arch.tensorNames.push(tensorMatched.groups.name);
		}
	}

	/////////////////////////////////////
	// extract list of hyper params based on architecture
	let insideLoadHParamsFn = false;
	currCppConst = "";
	for (const line of cppSource.split("\n")) {
		// check if current line is function llm_load_hparams()
		if (line.startsWith("static void llm_load_hparams")) {
			insideLoadHParamsFn = true;
		}
		if (!insideLoadHParamsFn) {
			continue;
		}
		// check if current line has LLM_ARCH_*
		const RE_CASE = new RegExp(`case (${RE_ARCH_NAME.source})`);
		const cppConst = line.match(RE_CASE)?.[1];
		if (cppConst) {
			currCppConst = cppConst;
			continue;
		}
		// check if current line has get_key(...)
		const keyConst = line.match(/LLM_KV_[A-Z0-9_]+/)?.[0];
		if (keyConst) {
			const arch = archList.find((a) => a.cppConst === currCppConst);
			if (arch) {
				arch.hparams.push(keyConst);
			}
		}
		// check if current line is end-of-function
		if (line === "}") {
			break;
		}
	}

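	// Note: the end-of-function check relies on llama.cpp's formatting, where the
	// first un-indented "}" after "static void llm_load_hparams" closes that function.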
	/////////////////////////////////////
	// write result to file
	const content = [
		"/** This file is auto-generated by generate-llm.ts */",
		"",
		'import type { ModelBase } from "./types";',
		"",
		"export const LLM_ARCHITECTURES = [",
		...archList.map((a) => `\t${JSON.stringify(a.name)},`),
		"] as const;",
		"type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
		DEST_COMMON_SOURCE,
		...archList.map((a) => {
			let code = `export type ${a.tsName} = TransformerLLMBase<${JSON.stringify(a.name)}>`;
			if (a.hparams.length) {
				code += [
					" & {",
					...a.hparams.map((k) => `\t${JSON.stringify(constToKVName[k].replace("%s", a.name))}: ${KV_TYPE[k]},`),
					"};",
				].join("\n");
			} else {
				code += ";";
			}
			return code;
		}),
		"",
		`export type TransformerLLM = ${archList.map((a) => a.tsName).join(" | ")};`,
	].join("\n");

	writeFileSync(DEST_FILE_PATH, content);
}

function snakeToPascal(str: string) {
	return str
		.split("_")
		.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
		.join("");
}

main();
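For intuition, here is a minimal, self-contained sketch of the extraction technique used above, run against a hypothetical fragment of llama.cpp (the real script fetches the full file from SOURCE_CPP_URL; the fragment and its exact layout are assumptions for illustration):

// Hypothetical fragment of llama.cpp; the real table is much larger.
const sample = `
static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_LLAMA,  "llama"  },
    { LLM_ARCH_FALCON, "falcon" },
};
`;

// Same two-step approach as main(): grab everything up to the first ";",
// then match each entry line individually.
const lines = sample.match(/LLM_ARCH_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n") ?? [];
for (const line of lines) {
	const matched = line.match(/(?<cppConst>LLM_ARCH_[A-Z0-9_]+),\s+"(?<name>.+?)"/);
	if (matched?.groups) {
		// Prints: LLM_ARCH_LLAMA => llama, then LLM_ARCH_FALCON => falcon
		console.log(matched.groups.cppConst, "=>", matched.groups.name);
	}
}

The emitted type name then comes from snakeToPascal, e.g. snakeToPascal("ARCH_LLAMA") === "ArchLlama".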

0 commit comments