Skip to content

Commit f93f6a4

Browse files
authored
Refactor conversational input to getModelInputSnippet (#989)
Refactor conversational input to getModelInputSnippet. Follow up to comment #985 (comment) @Wauplin
1 parent e14e8c1 commit f93f6a4

File tree

4 files changed

+43
-61
lines changed

4 files changed

+43
-61
lines changed

packages/tasks/src/snippets/curl.ts

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,7 @@ export const snippetTextGeneration = (
2626
if (model.tags.includes("conversational")) {
2727
// Conversational model detected, so we display a code snippet that features the Messages API
2828
const streaming = opts?.streaming ?? true;
29-
const exampleMessages: ChatCompletionInputMessage[] =
30-
model.pipeline_tag === "text-generation"
31-
? [{ role: "user", content: "What is the capital of France?" }]
32-
: [
33-
{
34-
role: "user",
35-
content: [
36-
{
37-
type: "image_url",
38-
image_url: {
39-
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
40-
},
41-
},
42-
{ type: "text", text: "Describe this image in one sentence." },
43-
],
44-
},
45-
];
29+
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
4630
const messages = opts?.messages ?? exampleMessages;
4731

4832
const config = {

packages/tasks/src/snippets/inputs.ts

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { PipelineType } from "../pipelines";
2+
import type { ChatCompletionInputMessage } from "../tasks";
23
import type { ModelDataMinimal } from "./types";
34

45
const inputsZeroShotClassification = () =>
@@ -40,7 +41,30 @@ const inputsTextClassification = () => `"I like you. I love you"`;
4041

4142
const inputsTokenClassification = () => `"My name is Sarah Jessica Parker but you can call me Jessica"`;
4243

43-
const inputsTextGeneration = () => `"Can you please let us know more details about your "`;
44+
const inputsTextGeneration = (model: ModelDataMinimal): string | ChatCompletionInputMessage[] => {
45+
if (model.tags.includes("conversational")) {
46+
return model.pipeline_tag === "text-generation"
47+
? [{ role: "user", content: "What is the capital of France?" }]
48+
: [
49+
{
50+
role: "user",
51+
content: [
52+
{
53+
type: "text",
54+
text: "Describe this image in one sentence.",
55+
},
56+
{
57+
type: "image_url",
58+
image_url: {
59+
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
60+
},
61+
},
62+
],
63+
},
64+
];
65+
}
66+
return `"Can you please let us know more details about your "`;
67+
};
4468

4569
const inputsText2TextGeneration = () => `"The answer to the universe is"`;
4670

@@ -84,7 +108,7 @@ const inputsTabularPrediction = () =>
84108
const inputsZeroShotImageClassification = () => `"cats.jpg"`;
85109

86110
const modelInputSnippets: {
87-
[key in PipelineType]?: (model: ModelDataMinimal) => string;
111+
[key in PipelineType]?: (model: ModelDataMinimal) => string | ChatCompletionInputMessage[];
88112
} = {
89113
"audio-to-audio": inputsAudioToAudio,
90114
"audio-classification": inputsAudioClassification,
@@ -116,18 +140,24 @@ const modelInputSnippets: {
116140

117141
// Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)
118142
// Use noQuotes to strip quotes from start & end (example: "abc" -> abc)
119-
export function getModelInputSnippet(model: ModelDataMinimal, noWrap = false, noQuotes = false): string {
143+
export function getModelInputSnippet(
144+
model: ModelDataMinimal,
145+
noWrap = false,
146+
noQuotes = false
147+
): string | ChatCompletionInputMessage[] {
120148
if (model.pipeline_tag) {
121149
const inputs = modelInputSnippets[model.pipeline_tag];
122150
if (inputs) {
123151
let result = inputs(model);
124-
if (noWrap) {
125-
result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ");
126-
}
127-
if (noQuotes) {
128-
const REGEX_QUOTES = /^"(.+)"$/s;
129-
const match = result.match(REGEX_QUOTES);
130-
result = match ? match[1] : result;
152+
if (typeof result === "string") {
153+
if (noWrap) {
154+
result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " ");
155+
}
156+
if (noQuotes) {
157+
const REGEX_QUOTES = /^"(.+)"$/s;
158+
const match = result.match(REGEX_QUOTES);
159+
result = match ? match[1] : result;
160+
}
131161
}
132162
return result;
133163
}

packages/tasks/src/snippets/js.ts

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,7 @@ export const snippetTextGeneration = (
4040
if (model.tags.includes("conversational")) {
4141
// Conversational model detected, so we display a code snippet that features the Messages API
4242
const streaming = opts?.streaming ?? true;
43-
const exampleMessages: ChatCompletionInputMessage[] =
44-
model.pipeline_tag === "text-generation"
45-
? [{ role: "user", content: "What is the capital of France?" }]
46-
: [
47-
{
48-
role: "user",
49-
content: [
50-
{
51-
type: "image_url",
52-
image_url: {
53-
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
54-
},
55-
},
56-
{ type: "text", text: "Describe this image in one sentence." },
57-
],
58-
},
59-
];
43+
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
6044
const messages = opts?.messages ?? exampleMessages;
6145
const messagesStr = stringifyMessages(messages, { sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
6246

packages/tasks/src/snippets/python.ts

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,7 @@ export const snippetConversational = (
1616
}
1717
): InferenceSnippet[] => {
1818
const streaming = opts?.streaming ?? true;
19-
const exampleMessages: ChatCompletionInputMessage[] =
20-
model.pipeline_tag === "text-generation"
21-
? [{ role: "user", content: "What is the capital of France?" }]
22-
: [
23-
{
24-
role: "user",
25-
content: [
26-
{
27-
type: "image_url",
28-
image_url: {
29-
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
30-
},
31-
},
32-
{ type: "text", text: "Describe this image in one sentence." },
33-
],
34-
},
35-
];
19+
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
3620
const messages = opts?.messages ?? exampleMessages;
3721
const messagesStr = stringifyMessages(messages, {
3822
sep: ",\n\t",

0 commit comments

Comments
 (0)