|
1 | 1 | import type { PipelineType } from "../pipelines";
|
| 2 | +import type { ChatCompletionInputMessage } from "../tasks"; |
2 | 3 | import type { ModelDataMinimal } from "./types";
|
3 | 4 |
|
4 | 5 | const inputsZeroShotClassification = () =>
|
@@ -40,7 +41,30 @@ const inputsTextClassification = () => `"I like you. I love you"`;
|
40 | 41 |
|
41 | 42 | const inputsTokenClassification = () => `"My name is Sarah Jessica Parker but you can call me Jessica"`;
|
42 | 43 |
|
43 |
| -const inputsTextGeneration = () => `"Can you please let us know more details about your "`; |
| 44 | +const inputsTextGeneration = (model: ModelDataMinimal): string | ChatCompletionInputMessage[] => { |
| 45 | + if (model.tags.includes("conversational")) { |
| 46 | + return model.pipeline_tag === "text-generation" |
| 47 | + ? [{ role: "user", content: "What is the capital of France?" }] |
| 48 | + : [ |
| 49 | + { |
| 50 | + role: "user", |
| 51 | + content: [ |
| 52 | + { |
| 53 | + type: "text", |
| 54 | + text: "Describe this image in one sentence.", |
| 55 | + }, |
| 56 | + { |
| 57 | + type: "image_url", |
| 58 | + image_url: { |
| 59 | + url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", |
| 60 | + }, |
| 61 | + }, |
| 62 | + ], |
| 63 | + }, |
| 64 | + ]; |
| 65 | + } |
| 66 | + return `"Can you please let us know more details about your "`; |
| 67 | +}; |
44 | 68 |
|
45 | 69 | const inputsText2TextGeneration = () => `"The answer to the universe is"`;
|
46 | 70 |
|
@@ -84,7 +108,7 @@ const inputsTabularPrediction = () =>
|
84 | 108 | const inputsZeroShotImageClassification = () => `"cats.jpg"`;
|
85 | 109 |
|
86 | 110 | const modelInputSnippets: {
|
87 |
| - [key in PipelineType]?: (model: ModelDataMinimal) => string; |
| 111 | + [key in PipelineType]?: (model: ModelDataMinimal) => string | ChatCompletionInputMessage[]; |
88 | 112 | } = {
|
89 | 113 | "audio-to-audio": inputsAudioToAudio,
|
90 | 114 | "audio-classification": inputsAudioClassification,
|
@@ -116,18 +140,24 @@ const modelInputSnippets: {
|
116 | 140 |
|
117 | 141 | // Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)
|
118 | 142 | // Use noQuotes to strip quotes from start & end (example: "abc" -> abc)
|
119 |
| -export function getModelInputSnippet(model: ModelDataMinimal, noWrap = false, noQuotes = false): string { |
| 143 | +export function getModelInputSnippet( |
| 144 | + model: ModelDataMinimal, |
| 145 | + noWrap = false, |
| 146 | + noQuotes = false |
| 147 | +): string | ChatCompletionInputMessage[] { |
120 | 148 | if (model.pipeline_tag) {
|
121 | 149 | const inputs = modelInputSnippets[model.pipeline_tag];
|
122 | 150 | if (inputs) {
|
123 | 151 | let result = inputs(model);
|
124 |
| - if (noWrap) { |
125 |
| - result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " "); |
126 |
| - } |
127 |
| - if (noQuotes) { |
128 |
| - const REGEX_QUOTES = /^"(.+)"$/s; |
129 |
| - const match = result.match(REGEX_QUOTES); |
130 |
| - result = match ? match[1] : result; |
| 152 | + if (typeof result === "string") { |
| 153 | + if (noWrap) { |
| 154 | + result = result.replace(/(?:(?:\r?\n|\r)\t*)|\t+/g, " "); |
| 155 | + } |
| 156 | + if (noQuotes) { |
| 157 | + const REGEX_QUOTES = /^"(.+)"$/s; |
| 158 | + const match = result.match(REGEX_QUOTES); |
| 159 | + result = match ? match[1] : result; |
| 160 | + } |
131 | 161 | }
|
132 | 162 | return result;
|
133 | 163 | }
|
|
0 commit comments