/**
 * Script for generating src/transformer-llm.ts
 * The source data is taken from llama.cpp
 */
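// Note: assumes Node.js 18+ (or another runtime that provides a global fetch).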

import { writeFileSync } from "node:fs";

const SOURCE_CPP_URL = "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/llama.cpp";
const DEST_FILE_PATH = "./src/transformer-llm.ts";
const DEST_COMMON_SOURCE = `
type Attention<TArchitecture extends string> =
  & { [K in \`\${TArchitecture}.attention.head_count\`]: number }
  & { [K in \`\${TArchitecture}.attention.head_count_kv\`]: number }
  & { [K in \`\${TArchitecture}.attention.layer_norm_epsilon\`]: number }
  & { [K in \`\${TArchitecture}.attention.layer_norm_rms_epsilon\`]: number }
  & { [K in \`\${TArchitecture}.attention.alibi_bias_max\`]: number }
  & { [K in \`\${TArchitecture}.attention.clip_kqv\`]: number }
  & { [K in \`\${TArchitecture}.attention.use_norm\`]: number };

type Rope<TArchitecture extends LLMArchitecture> =
  & { [K in \`\${TArchitecture}.rope.dimension_count\`]: number }
  & { [K in \`\${TArchitecture}.rope.freq_base\`]: number }
  & { [K in \`\${TArchitecture}.rope.scale\`]: number }
  & { [K in \`\${TArchitecture}.rope.scale_linear\`]: number };

type MOE<TArchitecture extends LLMArchitecture> =
  & { [K in \`\${TArchitecture}.expert_count\`]: number }
  & { [K in \`\${TArchitecture}.expert_used_count\`]: number };

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture>
  & MOE<TArchitecture>
  & Attention<TArchitecture>
  & Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
  UNSPECIFIED = -1,
  NONE = 0,
  MEAN = 1,
  CLS = 2,
}
`;

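// Maps a llama.cpp LLM_KV_* constant to the TypeScript type of the generated metadata field.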
const KV_TYPE: Record<string, string> = {
  LLM_KV_ATTENTION_LAYERNORM_RMS_EPS: "number",
  LLM_KV_ATTENTION_LAYERNORM_EPS: "number",
  LLM_KV_ATTENTION_CAUSAL: "boolean",
  LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT: "number",
  LLM_KV_POOLING_TYPE: "TransformerLLMPoolingType",
  LLM_KV_ATTENTION_CLAMP_KQV: "number",
  LLM_KV_ATTENTION_MAX_ALIBI_BIAS: "number",
  LLM_KV_SSM_CONV_KERNEL: "number",
  LLM_KV_SSM_INNER_SIZE: "number",
  LLM_KV_SSM_STATE_SIZE: "number",
  LLM_KV_SSM_TIME_STEP_RANK: "number",
  LLM_KV_LOGIT_SCALE: "number",
};

interface Arch {
  cppConst: string; // for example: "LLM_ARCH_LLAMA"
  name: string; // for example: "llama"
  tsName: string; // for example: "ArchLlama"
  tensorNames: string[]; // for example: "token_embd"
  hparams: string[]; // LLM_KV_* constants read for this architecture
}

async function main() {
  const res = await fetch(SOURCE_CPP_URL);
  const cppSource = await res.text();

  /////////////////////////////////////
  // extract list of all architectures
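  // Matching entries inside LLM_ARCH_NAMES look like this (illustrative excerpt, not verbatim):
  //   { LLM_ARCH_LLAMA,  "llama"  },
  //   { LLM_ARCH_FALCON, "falcon" },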
  const archList: Arch[] = [];
  const RE_ARCH_NAME = /LLM_ARCH_[A-Z0-9_]+/;
  const matchedArchList = cppSource.match(/LLM_ARCH_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
  if (!matchedArchList?.length) {
    throw new Error("LLM_ARCH_NAMES is empty");
  }
  for (const line of matchedArchList) {
    const matched = line.match(/(?<cppConst>LLM_ARCH_[A-Z0-9_]+),\s+"(?<name>.+?)"/);
    if (matched?.groups && !matched.groups.name.match(/unknown/)) {
      archList.push({
        cppConst: matched.groups.cppConst,
        name: matched.groups.name,
        tsName: snakeToPascal(matched.groups.cppConst.replace("LLM_", "")),
        tensorNames: [],
        hparams: [],
      });
    }
  }

  /////////////////////////////////////
  // extract map constant name to kv name
  // for example: LLM_KV_ATTENTION_LAYERNORM_RMS_EPS ==> "%s.attention.layer_norm_rms_epsilon"
  const constToKVName: { [cppConst: string]: string } = {};
  const matchedKVList = cppSource.match(/LLM_KV_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
  if (!matchedKVList?.length) {
    throw new Error("LLM_KV_NAMES is empty");
  }
  for (const line of matchedKVList) {
    const matched = line.match(/(?<cppConst>LLM_KV_[A-Z0-9_]+)[,\s]+"(?<name>.+?)"/);
    if (matched?.groups) {
      constToKVName[matched.groups.cppConst] = matched.groups.name;
    }
  }

  /////////////////////////////////////
  // extract list of tensor names based on architecture
  // TODO: unused for now
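  // LLM_TENSOR_NAMES in llama.cpp is a nested map; the scanned lines look roughly like:
  //   { LLM_ARCH_LLAMA, {
  //       { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
  //       { LLM_TENSOR_OUTPUT,     "output" },
  // i.e. an LLM_ARCH_* line introduces an architecture, and the LLM_TENSOR_* lines after it belong to it.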
  const matchedTensorList = cppSource.match(/LLM_TENSOR_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
  if (!matchedTensorList?.length) {
    throw new Error("LLM_TENSOR_NAMES is empty");
  }
  let currCppConst = "";
  for (const line of matchedTensorList) {
    // check if current line has LLM_ARCH_*
    const cppConst = line.match(RE_ARCH_NAME)?.[0];
    if (cppConst) {
      currCppConst = cppConst;
      continue;
    }
    // check if current line has LLM_TENSOR_*
    const tensorMatched = line.match(/LLM_TENSOR_[A-Z0-9_]+[,\s]+"(?<name>.+?)"/);
    if (tensorMatched?.groups) {
      const arch = archList.find((a) => a.cppConst === currCppConst);
      if (arch) arch.tensorNames.push(tensorMatched.groups.name);
    }
  }

  /////////////////////////////////////
  // extract list of hyperparameters based on architecture
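  // Inside llm_load_hparams() the scanned code looks roughly like:
  //   switch (model.arch) {
  //       case LLM_ARCH_LLAMA:
  //           {
  //               ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);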
  let insideLoadHParamsFn = false;
  currCppConst = "";
  // matches e.g. "case LLM_ARCH_LLAMA:" (hoisted out of the loop so it is built only once)
  const RE_CASE = new RegExp(`case (${RE_ARCH_NAME.source})`);
  for (const line of cppSource.split("\n")) {
    // check if current line is function llm_load_hparams()
    if (line.startsWith("static void llm_load_hparams")) {
      insideLoadHParamsFn = true;
    }
    if (!insideLoadHParamsFn) {
      continue;
    }
    // check if current line has LLM_ARCH_*
    const cppConst = line.match(RE_CASE)?.[1];
    if (cppConst) {
      currCppConst = cppConst;
      continue;
    }
    // check if current line has get_key(...)
    const keyConst = line.match(/LLM_KV_[A-Z0-9_]+/)?.[0];
    if (keyConst) {
      const arch = archList.find((a) => a.cppConst === currCppConst);
      // skip duplicates; a key read twice would emit a duplicate property in the generated type
      if (arch && !arch.hparams.includes(keyConst)) {
        arch.hparams.push(keyConst);
      }
    }
    // check if current line is end-of-function
    if (line === "}") {
      break;
    }
  }

  /////////////////////////////////////
  // write result to file
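  // The emitted file contains one type per architecture, for example:
  //   export type ArchLlama = TransformerLLMBase<"llama"> & {
  //       "llama.attention.layer_norm_rms_epsilon": number,
  //   };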
  const content = [
    "/** This file is auto-generated by generate-llm.ts */",
    "",
    'import type { ModelBase } from "./types";',
    "",
    "export const LLM_ARCHITECTURES = [",
    ...archList.map((a) => `\t${JSON.stringify(a.name)},`),
    "] as const;",
    "type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
    DEST_COMMON_SOURCE,
    ...archList.map((a) => {
      let code = `export type ${a.tsName} = TransformerLLMBase<${JSON.stringify(a.name)}>`;
      if (a.hparams.length) {
        code += [
          " & {",
          ...a.hparams.map((k) => {
            const kvName = constToKVName[k];
            const tsType = KV_TYPE[k];
            // fail loudly if llama.cpp introduces a key that is missing from KV_TYPE,
            // instead of silently emitting "undefined" as the property type
            if (!kvName || !tsType) {
              throw new Error(`unknown KV key or missing type mapping: ${k}`);
            }
            return `\t${JSON.stringify(kvName.replace("%s", a.name))}: ${tsType},`;
          }),
          "};",
        ].join("\n");
      } else {
        code += ";";
      }
      return code;
    }),
    "",
    `export type TransformerLLM = ${archList.map((a) => a.tsName).join(" | ")};`,
  ].join("\n");

  writeFileSync(DEST_FILE_PATH, content);
}

function snakeToPascal(str: string) {
  return str
    .split("_")
    .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
    .join("");
}
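// for example: snakeToPascal("ARCH_LLAMA") === "ArchLlama"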

main();