Skip to content

Commit f1c2b27

Browse files
apolinariomultimodalartmishig25Vaibhavs10
authored
Improve prompting for diffusers default snippets (#909)
Addressing comments left on #907, specially for LoRAs --------- Co-authored-by: multimodalart <[email protected]> Co-authored-by: Mishig <[email protected]> Co-authored-by: vb <[email protected]>
1 parent ed90897 commit f1c2b27

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

packages/tasks/src/model-data.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ export interface ModelData {
107107
parameters?: Record<string, unknown>;
108108
};
109109
base_model?: string | string[];
110+
instance_prompt?: string;
110111
};
111112
/**
112113
* Library name

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { ModelData } from "./model-data";
2+
import type { WidgetExampleTextInput } from "./widget-example";
23
import { LIBRARY_TASK_MAPPING } from "./library-to-tasks";
34

45
const TAG_CUSTOM_CODE = "custom_code";
@@ -8,6 +9,8 @@ function nameWithoutNamespace(modelId: string): string {
89
return splitted.length === 1 ? splitted[0] : splitted[1];
910
}
1011

12+
const escapeStringForJson = (str: string): string => JSON.stringify(str);
13+
1114
//#region snippets
1215

1316
export const adapters = (model: ModelData): string[] => [
@@ -70,6 +73,13 @@ function get_base_diffusers_model(model: ModelData): string {
7073
return model.cardData?.base_model?.toString() ?? "fill-in-base-model";
7174
}
7275

76+
function get_prompt_from_diffusers_model(model: ModelData): string | undefined {
77+
const prompt = (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt;
78+
if (prompt) {
79+
return escapeStringForJson(prompt);
80+
}
81+
}
82+
7383
export const bertopic = (model: ModelData): string[] => [
7484
`from bertopic import BERTopic
7585
@@ -129,12 +139,14 @@ depth = model.infer_image(raw_img) # HxW raw depth map in numpy
129139
];
130140
};
131141

142+
const diffusersDefaultPrompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k";
143+
132144
const diffusers_default = (model: ModelData) => [
133145
`from diffusers import DiffusionPipeline
134146
135147
pipe = DiffusionPipeline.from_pretrained("${model.id}")
136148
137-
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
149+
prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
138150
image = pipe(prompt).images[0]`,
139151
];
140152

@@ -153,7 +165,7 @@ const diffusers_lora = (model: ModelData) => [
153165
pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}")
154166
pipe.load_lora_weights("${model.id}")
155167
156-
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
168+
prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
157169
image = pipe(prompt).images[0]`,
158170
];
159171

0 commit comments

Comments
 (0)