Skip to content

Commit 77e3ce2

Browse files
committed
better escape functions
1 parent 5bc694b commit 77e3ce2

File tree

5 files changed

+103
-88
lines changed

5 files changed

+103
-88
lines changed

packages/tasks/src/snippets/common.ts

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";
2+
3+
export interface StringifyMessagesOptions {
4+
sep: string;
5+
start: string;
6+
end: string;
7+
attributeKeyQuotes?: boolean;
8+
customContentEscaper?: (str: string) => string;
9+
}
10+
11+
export function stringifyMessages(messages: ChatCompletionInputMessage[], opts: StringifyMessagesOptions): string {
12+
const keyRole = opts.attributeKeyQuotes ? `"role"` : "role";
13+
const keyContent = opts.attributeKeyQuotes ? `"role"` : "role";
14+
15+
const messagesStringified = messages.map(({ role, content }) => {
16+
if (typeof content === "string") {
17+
content = JSON.stringify(content).slice(1, -1);
18+
if (opts.customContentEscaper) {
19+
content = opts.customContentEscaper(content);
20+
}
21+
return `{ ${keyRole}: "${role}", ${keyContent}: "${content}" }`;
22+
} else {
23+
2;
24+
content = content.map(({ image_url, text, type }) => ({
25+
type,
26+
image_url,
27+
...(text ? { text: JSON.stringify(text).slice(1, -1) } : undefined),
28+
}));
29+
content = JSON.stringify(content).slice(1, -1);
30+
if (opts.customContentEscaper) {
31+
content = opts.customContentEscaper(content);
32+
}
33+
return `{ ${keyRole}: "${role}", ${keyContent}: ${content} }`;
34+
}
35+
});
36+
37+
return opts.start + messagesStringified.join(opts.sep) + opts.end;
38+
}
39+
40+
type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;
41+
42+
export interface StringifyGenerationConfigOptions {
43+
sep: string;
44+
start: string;
45+
end: string;
46+
attributeValueConnector: string;
47+
attributeKeyQuotes?: boolean;
48+
}
49+
50+
export function stringifyGenerationConfig(
51+
config: PartialGenerationParameters,
52+
opts: StringifyGenerationConfigOptions
53+
): string {
54+
const quote = opts.attributeKeyQuotes ? `"` : "";
55+
56+
return (
57+
opts.start +
58+
Object.entries(config)
59+
.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${quote}${val}${quote}`)
60+
.join(opts.sep) +
61+
opts.end
62+
);
63+
}

packages/tasks/src/snippets/curl.ts

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
import type { PipelineType } from "../pipelines.js";
22
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
3+
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
34
import { getModelInputSnippet } from "./inputs.js";
4-
import type {
5-
GenerationConfigFormatter,
6-
GenerationMessagesFormatter,
7-
InferenceSnippet,
8-
ModelDataMinimal,
9-
} from "./types.js";
5+
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
106

117
export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
128
content: `curl https://api-inference.huggingface.co/models/${model.id} \\
@@ -16,25 +12,6 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): Infe
1612
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`,
1713
});
1814

19-
const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
20-
start +
21-
messages
22-
.map(({ role }) => {
23-
// escape single quotes since single quotes is used to define http post body inside curl requests
24-
// TODO: handle the case below
25-
// content = content?.replace(/'/g, "'\\''");
26-
return `{ "role": "${role}", "content": "test msg" }`;
27-
})
28-
.join(sep) +
29-
end;
30-
31-
const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
32-
start +
33-
Object.entries(config)
34-
.map(([key, val]) => `"${key}": ${val}`)
35-
.join(sep) +
36-
end;
37-
3815
export const snippetTextGeneration = (
3916
model: ModelDataMinimal,
4017
accessToken: string,
@@ -64,8 +41,20 @@ export const snippetTextGeneration = (
6441
-H 'Content-Type: application/json' \\
6542
--data '{
6643
"model": "${model.id}",
67-
"messages": ${formatGenerationMessages({ messages, sep: ",\n ", start: `[\n `, end: `\n]` })},
68-
${formatGenerationConfig({ config, sep: ",\n ", start: "", end: "" })},
44+
"messages": ${stringifyMessages(messages, {
45+
sep: ",\n ",
46+
start: `[\n `,
47+
end: `\n]`,
48+
attributeKeyQuotes: true,
49+
customContentEscaper: (str) => str.replace(/'/g, "'\\''"),
50+
})},
51+
${stringifyGenerationConfig(config, {
52+
sep: ",\n ",
53+
start: "",
54+
end: "",
55+
attributeKeyQuotes: true,
56+
attributeValueConnector: ": ",
57+
})},
6958
"stream": ${!!streaming}
7059
}'`,
7160
};

packages/tasks/src/snippets/js.ts

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
import type { PipelineType } from "../pipelines.js";
22
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
3+
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
34
import { getModelInputSnippet } from "./inputs.js";
4-
import type {
5-
GenerationConfigFormatter,
6-
GenerationMessagesFormatter,
7-
InferenceSnippet,
8-
ModelDataMinimal,
9-
} from "./types.js";
5+
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
106

117
export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
128
content: `async function query(data) {
@@ -30,16 +26,6 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
3026
});`,
3127
});
3228

33-
const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
34-
start + messages.map(({ role, content }) => `{ role: "${role}", content: "${content}" }`).join(sep) + end;
35-
36-
const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
37-
start +
38-
Object.entries(config)
39-
.map(([key, val]) => `${key}: ${val}`)
40-
.join(sep) +
41-
end;
42-
4329
export const snippetTextGeneration = (
4430
model: ModelDataMinimal,
4531
accessToken: string,
@@ -57,14 +43,19 @@ export const snippetTextGeneration = (
5743
const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
5844
{ role: "user", content: "What is the capital of France?" },
5945
];
60-
const messagesStr = formatGenerationMessages({ messages, sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
46+
const messagesStr = stringifyMessages(messages, { sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
6147

6248
const config = {
6349
...(opts?.temperature ? { temperature: opts.temperature } : undefined),
6450
max_tokens: opts?.max_tokens ?? 500,
6551
...(opts?.top_p ? { top_p: opts.top_p } : undefined),
6652
};
67-
const configStr = formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "" });
53+
const configStr = stringifyGenerationConfig(config, {
54+
sep: ",\n\t",
55+
start: "",
56+
end: "",
57+
attributeValueConnector: ": ",
58+
});
6859

6960
if (streaming) {
7061
return [

packages/tasks/src/snippets/python.ts

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,8 @@
11
import type { PipelineType } from "../pipelines.js";
22
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
3+
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
34
import { getModelInputSnippet } from "./inputs.js";
4-
import type {
5-
GenerationConfigFormatter,
6-
GenerationMessagesFormatter,
7-
InferenceSnippet,
8-
ModelDataMinimal,
9-
} from "./types.js";
10-
11-
const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
12-
start + messages.map(({ role, content }) => `{ "role": "${role}", "content": "${content}" }`).join(sep) + end;
13-
14-
const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end, connector }) =>
15-
start +
16-
Object.entries(config)
17-
.map(([key, val]) => `${key}${connector}${val}`)
18-
.join(sep) +
19-
end;
5+
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
206

217
export const snippetConversational = (
228
model: ModelDataMinimal,
@@ -33,14 +19,25 @@ export const snippetConversational = (
3319
const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
3420
{ role: "user", content: "What is the capital of France?" },
3521
];
36-
const messagesStr = formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` });
22+
const messagesStr = stringifyMessages(messages, {
23+
sep: ",\n\t",
24+
start: `[\n\t`,
25+
end: `\n]`,
26+
attributeKeyQuotes: true,
27+
});
3728

3829
const config = {
3930
...(opts?.temperature ? { temperature: opts.temperature } : undefined),
4031
max_tokens: opts?.max_tokens ?? 500,
4132
...(opts?.top_p ? { top_p: opts.top_p } : undefined),
4233
};
43-
const configStr = formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" });
34+
const configStr = stringifyGenerationConfig(config, {
35+
sep: ",\n\t",
36+
start: "",
37+
end: "",
38+
attributeValueConnector: "=",
39+
attributeKeyQuotes: true,
40+
});
4441

4542
if (streaming) {
4643
return [

packages/tasks/src/snippets/types.ts

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import type { ModelData } from "../model-data";
2-
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";
32

43
/**
54
* Minimal model data required for snippets.
@@ -15,27 +14,3 @@ export interface InferenceSnippet {
1514
content: string;
1615
client?: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets
1716
}
18-
19-
interface GenerationSnippetDelimiter {
20-
sep: string;
21-
start: string;
22-
end: string;
23-
connector?: string;
24-
}
25-
26-
type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;
27-
28-
export type GenerationMessagesFormatter = ({
29-
messages,
30-
sep,
31-
start,
32-
end,
33-
}: GenerationSnippetDelimiter & { messages: ChatCompletionInputMessage[] }) => string;
34-
35-
export type GenerationConfigFormatter = ({
36-
config,
37-
sep,
38-
start,
39-
end,
40-
connector,
41-
}: GenerationSnippetDelimiter & { config: PartialGenerationParameters }) => string;

0 commit comments

Comments
 (0)