Skip to content

Commit 6a036d8

Browse files
authored
gguf: better type usage (#655)
Follow up #640 Ref comments: - #640 (review) by @julien-c suggests using a check `metadata["general.architecture"] === ...` to select the correct type - #640 (comment) by @coyotte508 suggests using less generic but more verbose code The type system introduce in this PR allows type-checking at both compile time & runtime: ```ts const model: GGUFMetadata<GGUFType.STRICT> = null as any; if (model["general.architecture"] === "whisper") { model["encoder.whisper.block_count"] = 0; // @ts-expect-error because it must be a number model["encoder.whisper.block_count"] = "abc"; } if (model["tokenizer.ggml.model"] === undefined) { // @ts-expect-error because it's undefined model["tokenizer.ggml.eos_token_id"] = 1; } if (model["tokenizer.ggml.model"] === "gpt2") { // @ts-expect-error because it must be a number model["tokenizer.ggml.eos_token_id"] = undefined; model["tokenizer.ggml.eos_token_id"] = 1; } if (model["general.architecture"] === "mamba") { model["mamba.ssm.conv_kernel"] = 0; // @ts-expect-error because it must be a number model["mamba.ssm.conv_kernel"] = "abc"; } if (model["general.architecture"] === "llama") { // @ts-expect-error llama does not have ssm.* keys model["mamba.ssm.conv_kernel"] = 0; } ``` Type checks can be disable with `GGUFMetadata<GGUFType.NON_STRICT>`
1 parent 99bbf1f commit 6a036d8

File tree

6 files changed

+227
-90
lines changed

6 files changed

+227
-90
lines changed

packages/gguf/scripts/generate-llm.ts

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,53 @@ import { writeFileSync } from "node:fs";
88
const SOURCE_CPP_URL = "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/llama.cpp";
99
const DEST_FILE_PATH = "./src/transformer-llm.ts";
1010
const DEST_COMMON_SOURCE = `
11-
type Attention<TArchitecture extends string> =
12-
& { [K in \`\${TArchitecture}.attention.head_count\`]: number }
13-
& { [K in \`\${TArchitecture}.attention.head_count_kv\`]: number }
14-
& { [K in \`\${TArchitecture}.attention.layer_norm_epsilon\`]: number }
15-
& { [K in \`\${TArchitecture}.attention.layer_norm_rms_epsilon\`]: number }
16-
& { [K in \`\${TArchitecture}.attention.alibi_bias_max\`]: number }
17-
& { [K in \`\${TArchitecture}.attention.clip_kqv\`]: number }
18-
& { [K in \`\${TArchitecture}.attention.use_norm\`]: number };
19-
20-
type Rope<TArchitecture extends LLMArchitecture> =
21-
& { [K in \`\${TArchitecture}.rope.dimension_count\`]: number }
22-
& { [K in \`\${TArchitecture}.rope.freq_base\`]: number }
23-
& { [K in \`\${TArchitecture}.rope.scale\`]: number }
24-
& { [K in \`\${TArchitecture}.rope.scale_linear\`]: number };
25-
26-
type MOE<TArchitecture extends LLMArchitecture> =
27-
& { [K in \`\${TArchitecture}.expert_count\`]: number }
28-
& { [K in \`\${TArchitecture}.expert_used_count\`]: number };
11+
/** This file is auto-generated by generate-llm.ts */
12+
13+
import type { ModelBase, GGUFGeneralInfo } from "./types";
14+
15+
type LLMBase<TArchitecture extends string> = Partial<Record<
16+
\`\${TArchitecture}.vocab_size\`
17+
| \`\${TArchitecture}.use_parallel_residual\`
18+
| \`\${TArchitecture}.tensor_data_layout\`,
19+
number
20+
>>;
21+
22+
type Attention<TArchitecture extends string> = Record<
23+
\`\${TArchitecture}.attention.head_count\`,
24+
number
25+
> & Partial<Record<
26+
\`\${TArchitecture}.attention.head_count_kv\`
27+
| \`\${TArchitecture}.attention.key_length\`
28+
| \`\${TArchitecture}.attention.value_length\`,
29+
number
30+
>>;
31+
32+
export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";
33+
type Rope<TArchitecture extends LLMArchitecture> = Partial<
34+
Record<
35+
\`\${TArchitecture}.rope.dimension_count\`
36+
| \`\${TArchitecture}.rope.freq_base\`
37+
| \`\${TArchitecture}.rope.scale_linear\`
38+
| \`\${TArchitecture}.rope.scaling.factor\`
39+
| \`\${TArchitecture}.rope.scaling.original_context_length\`,
40+
number
41+
>
42+
& Record<\`\${TArchitecture}.rope.scaling.type\`, TransformerLLMRopeScalingType>
43+
& Record<\`\${TArchitecture}.rope.finetuned\`, boolean>
44+
>;
45+
46+
type MOE<TArchitecture extends LLMArchitecture> = Partial<
47+
Record<
48+
\`\${TArchitecture}.expert_count\`
49+
| \`\${TArchitecture}.expert_used_count\`,
50+
number
51+
>
52+
>;
2953
3054
export type TransformerLLMArchitecture = LLMArchitecture; // type alias
31-
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture>
55+
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture>
56+
& LLMBase<TArchitecture>
57+
& ModelBase<TArchitecture>
3258
& MOE<TArchitecture>
3359
& Attention<TArchitecture>
3460
& Rope<TArchitecture>;
@@ -163,15 +189,11 @@ async function main() {
163189
/////////////////////////////////////
164190
// write result to file
165191
const content = [
166-
"/** This file is auto-generated by generate-llm.ts */",
167-
"",
168-
'import type { ModelBase } from "./types";',
169-
"",
192+
DEST_COMMON_SOURCE,
170193
"export const LLM_ARCHITECTURES = [",
171194
...archList.map((a) => `\t${JSON.stringify(a.name)},`),
172195
"] as const;",
173196
"type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
174-
DEST_COMMON_SOURCE,
175197
...archList.map((a) => {
176198
let code = `export type ${a.tsName} = TransformerLLMBase<${JSON.stringify(a.name)}>`;
177199
if (a.hparams.length) {

packages/gguf/src/gguf.spec.ts

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,25 @@ describe("gguf", () => {
3737
"llama.rope.dimension_count": 128,
3838
});
3939

40-
const tokens = metadata["tokenizer.ggml.tokens"];
41-
if (!Array.isArray(tokens)) {
42-
throw new Error();
40+
expect(metadata["tokenizer.ggml.model"]);
41+
if (metadata["tokenizer.ggml.model"]) {
42+
const tokens = metadata["tokenizer.ggml.tokens"];
43+
if (!Array.isArray(tokens)) {
44+
throw new Error();
45+
}
46+
expect(tokens.slice(0, 10)).toEqual([
47+
"<unk>",
48+
"<s>",
49+
"</s>",
50+
"<0x00>",
51+
"<0x01>",
52+
"<0x02>",
53+
"<0x03>",
54+
"<0x04>",
55+
"<0x05>",
56+
"<0x06>",
57+
]);
4358
}
44-
expect(tokens.slice(0, 10)).toEqual([
45-
"<unk>",
46-
"<s>",
47-
"</s>",
48-
"<0x00>",
49-
"<0x01>",
50-
"<0x02>",
51-
"<0x03>",
52-
"<0x04>",
53-
"<0x05>",
54-
"<0x06>",
55-
]);
5659

5760
/// Tensor infos
5861
/// By convention we test the first and last tensor.

packages/gguf/src/gguf.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ export async function gguf(
273273
offset += tensorCount.length;
274274
const numKv = readVersionedSize(r.view, offset, version, littleEndian);
275275
offset += numKv.length;
276-
const metadata: GGUFMetadata = {
276+
const metadata: GGUFMetadata<{ strict: false }> = {
277277
version,
278278
tensor_count: tensorCount.value,
279279
kv_count: numKv.value,

packages/gguf/src/transformer-llm.ts

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,56 @@
11
/** This file is auto-generated by generate-llm.ts */
22

3-
import type { ModelBase } from "./types";
3+
import type { ModelBase, GGUFGeneralInfo } from "./types";
4+
5+
type LLMBase<TArchitecture extends string> = Partial<
6+
Record<
7+
`${TArchitecture}.vocab_size` | `${TArchitecture}.use_parallel_residual` | `${TArchitecture}.tensor_data_layout`,
8+
number
9+
>
10+
>;
11+
12+
type Attention<TArchitecture extends string> = Record<`${TArchitecture}.attention.head_count`, number> &
13+
Partial<
14+
Record<
15+
| `${TArchitecture}.attention.head_count_kv`
16+
| `${TArchitecture}.attention.key_length`
17+
| `${TArchitecture}.attention.value_length`,
18+
number
19+
>
20+
>;
21+
22+
export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";
23+
type Rope<TArchitecture extends LLMArchitecture> = Partial<
24+
Record<
25+
| `${TArchitecture}.rope.dimension_count`
26+
| `${TArchitecture}.rope.freq_base`
27+
| `${TArchitecture}.rope.scale_linear`
28+
| `${TArchitecture}.rope.scaling.factor`
29+
| `${TArchitecture}.rope.scaling.original_context_length`,
30+
number
31+
> &
32+
Record<`${TArchitecture}.rope.scaling.type`, TransformerLLMRopeScalingType> &
33+
Record<`${TArchitecture}.rope.finetuned`, boolean>
34+
>;
35+
36+
type MOE<TArchitecture extends LLMArchitecture> = Partial<
37+
Record<`${TArchitecture}.expert_count` | `${TArchitecture}.expert_used_count`, number>
38+
>;
39+
40+
export type TransformerLLMArchitecture = LLMArchitecture; // type alias
41+
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture> &
42+
LLMBase<TArchitecture> &
43+
ModelBase<TArchitecture> &
44+
MOE<TArchitecture> &
45+
Attention<TArchitecture> &
46+
Rope<TArchitecture>;
47+
48+
export enum TransformerLLMPoolingType {
49+
UNSPECIFIED = -1,
50+
NONE = 0,
51+
MEAN = 1,
52+
CLS = 2,
53+
}
454

555
export const LLM_ARCHITECTURES = [
656
"llama",
@@ -37,36 +87,6 @@ export const LLM_ARCHITECTURES = [
3787
"olmo",
3888
] as const;
3989
type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];
40-
41-
type Attention<TArchitecture extends string> = { [K in `${TArchitecture}.attention.head_count`]: number } & {
42-
[K in `${TArchitecture}.attention.head_count_kv`]: number;
43-
} & { [K in `${TArchitecture}.attention.layer_norm_epsilon`]: number } & {
44-
[K in `${TArchitecture}.attention.layer_norm_rms_epsilon`]: number;
45-
} & { [K in `${TArchitecture}.attention.alibi_bias_max`]: number } & {
46-
[K in `${TArchitecture}.attention.clip_kqv`]: number;
47-
} & { [K in `${TArchitecture}.attention.use_norm`]: number };
48-
49-
type Rope<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.rope.dimension_count`]: number } & {
50-
[K in `${TArchitecture}.rope.freq_base`]: number;
51-
} & { [K in `${TArchitecture}.rope.scale`]: number } & { [K in `${TArchitecture}.rope.scale_linear`]: number };
52-
53-
type MOE<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.expert_count`]: number } & {
54-
[K in `${TArchitecture}.expert_used_count`]: number;
55-
};
56-
57-
export type TransformerLLMArchitecture = LLMArchitecture; // type alias
58-
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture> &
59-
MOE<TArchitecture> &
60-
Attention<TArchitecture> &
61-
Rope<TArchitecture>;
62-
63-
export enum TransformerLLMPoolingType {
64-
UNSPECIFIED = -1,
65-
NONE = 0,
66-
MEAN = 1,
67-
CLS = 2,
68-
}
69-
7090
export type ArchLlama = TransformerLLMBase<"llama"> & {
7191
"llama.attention.layer_norm_rms_epsilon": number;
7292
};

packages/gguf/src/types.spec.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import { describe, it } from "vitest";
2+
import type { gguf } from "./gguf";
3+
import type { GGUFMetadata, GGUFParseOutput } from "./types";
4+
5+
describe("gguf-types", () => {
6+
it("gguf() type can be casted between STRICT and NON_STRICT (at compile time)", async () => {
7+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
8+
const result: Awaited<ReturnType<typeof gguf>> = { metadata: {} } as any;
9+
const strictType = result as GGUFParseOutput<{ strict: true }>;
10+
// @ts-expect-error because the key "abc" does not exist
11+
strictType.metadata.abc = 123;
12+
const nonStrictType = result as GGUFParseOutput<{ strict: false }>;
13+
nonStrictType.metadata.abc = 123; // PASS, because it can be anything
14+
// @ts-expect-error because ArrayBuffer is not a MetadataValue
15+
nonStrictType.metadata.fff = ArrayBuffer;
16+
});
17+
18+
it("GGUFType.NON_STRICT should be correct (at compile time)", async () => {
19+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
20+
const model: GGUFMetadata<{ strict: false }> = {} as any;
21+
model.kv_count = 123n;
22+
model.abc = 456; // PASS, because it can be anything
23+
});
24+
25+
it("GGUFType.STRICT should be correct (at compile time)", async () => {
26+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
27+
const model: GGUFMetadata<{ strict: true }> = {} as any;
28+
29+
if (model["general.architecture"] === "whisper") {
30+
model["encoder.whisper.block_count"] = 0;
31+
// @ts-expect-error because it must be a number
32+
model["encoder.whisper.block_count"] = "abc";
33+
}
34+
35+
if (model["tokenizer.ggml.model"] === undefined) {
36+
// @ts-expect-error because it's undefined
37+
model["tokenizer.ggml.eos_token_id"] = 1;
38+
}
39+
if (model["tokenizer.ggml.model"] === "gpt2") {
40+
// @ts-expect-error because it must be a number
41+
model["tokenizer.ggml.eos_token_id"] = undefined;
42+
model["tokenizer.ggml.eos_token_id"] = 1;
43+
}
44+
45+
if (model["general.architecture"] === "mamba") {
46+
model["mamba.ssm.conv_kernel"] = 0;
47+
// @ts-expect-error because it must be a number
48+
model["mamba.ssm.conv_kernel"] = "abc";
49+
}
50+
if (model["general.architecture"] === "llama") {
51+
// @ts-expect-error llama does not have ssm.* keys
52+
model["mamba.ssm.conv_kernel"] = 0;
53+
}
54+
});
55+
});

packages/gguf/src/types.ts

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,32 @@ export enum GGUFValueType {
5050
const ARCHITECTURES = [...LLM_ARCHITECTURES, "rwkv", "whisper"] as const;
5151
export type Architecture = (typeof ARCHITECTURES)[number];
5252

53-
interface General {
54-
"general.architecture": Architecture;
55-
"general.name": string;
56-
"general.file_type": number;
57-
"general.quantization_version": number;
53+
export interface GGUFGeneralInfo<TArchitecture extends Architecture> {
54+
"general.architecture": TArchitecture;
55+
"general.name"?: string;
56+
"general.file_type"?: number;
57+
"general.quantization_version"?: number;
58+
}
59+
60+
type ModelMetadata = Whisper | RWKV | TransformerLLM;
61+
interface NoModelMetadata {
62+
"general.architecture"?: undefined;
5863
}
5964

6065
export type ModelBase<
6166
TArchitecture extends
6267
| Architecture
6368
| `encoder.${Extract<Architecture, "whisper">}`
6469
| `decoder.${Extract<Architecture, "whisper">}`,
65-
> = { [K in `${TArchitecture}.layer_count`]: number } & { [K in `${TArchitecture}.feed_forward_length`]: number } & {
66-
[K in `${TArchitecture}.context_length`]: number;
67-
} & { [K in `${TArchitecture}.embedding_length`]: number } & { [K in `${TArchitecture}.block_count`]: number };
70+
> = Record<
71+
| `${TArchitecture}.context_length`
72+
| `${TArchitecture}.block_count`
73+
| `${TArchitecture}.embedding_length`
74+
| `${TArchitecture}.feed_forward_length`,
75+
number
76+
>;
77+
78+
/// Tokenizer
6879

6980
type TokenizerModel = "no_vocab" | "llama" | "gpt2" | "bert";
7081
interface Tokenizer {
@@ -75,21 +86,47 @@ interface Tokenizer {
7586
"tokenizer.ggml.bos_token_id": number;
7687
"tokenizer.ggml.eos_token_id": number;
7788
"tokenizer.ggml.add_bos_token": boolean;
78-
"tokenizer.chat_template": string;
89+
"tokenizer.chat_template"?: string;
7990
}
91+
interface NoTokenizer {
92+
"tokenizer.ggml.model"?: undefined;
93+
}
94+
95+
/// Models outside of llama.cpp: "rwkv" and "whisper"
8096

81-
export type RWKV = ModelBase<"rwkv"> & { "rwkv.architecture_version": number };
82-
export type LLM = TransformerLLM | RWKV;
83-
export type Whisper = ModelBase<"encoder.whisper"> & ModelBase<"decoder.whisper">;
84-
export type Model = (LLM | Whisper) & Partial<Tokenizer>;
97+
export type RWKV = GGUFGeneralInfo<"rwkv"> &
98+
ModelBase<"rwkv"> & {
99+
"rwkv.architecture_version": number;
100+
};
85101

86-
export type GGUFMetadata = {
102+
// TODO: whisper.cpp doesn't yet support gguf. This maybe changed in the future.
103+
export type Whisper = GGUFGeneralInfo<"whisper"> &
104+
ModelBase<"encoder.whisper"> &
105+
ModelBase<"decoder.whisper"> & {
106+
"whisper.encoder.mels_count": number;
107+
"whisper.encoder.attention.head_count": number;
108+
"whisper.decoder.attention.head_count": number;
109+
};
110+
111+
/// Types for parse output
112+
113+
export interface GGUFMetadataOptions {
114+
/**
115+
* Enable strict type for known GGUF fields.
116+
*
117+
* @default true
118+
*/
119+
strict: boolean;
120+
}
121+
122+
export type GGUFMetadata<Options extends GGUFMetadataOptions = { strict: true }> = {
87123
version: Version;
88124
tensor_count: bigint;
89125
kv_count: bigint;
90-
} & Partial<General> &
91-
Partial<Model> &
92-
Record<string, MetadataValue>;
126+
} & GGUFModelKV &
127+
(Options extends { strict: true } ? unknown : Record<string, MetadataValue>);
128+
129+
export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer);
93130

94131
export interface GGUFTensorInfo {
95132
name: string;
@@ -99,7 +136,7 @@ export interface GGUFTensorInfo {
99136
offset: bigint;
100137
}
101138

102-
export interface GGUFParseOutput {
103-
metadata: GGUFMetadata;
139+
export interface GGUFParseOutput<Options extends GGUFMetadataOptions = { strict: true }> {
140+
metadata: GGUFMetadata<Options>;
104141
tensorInfos: GGUFTensorInfo[];
105142
}

0 commit comments

Comments
 (0)