@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
4
4
import { getModelInputSnippet } from "./inputs.js" ;
5
5
import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
6
6
7
+ const snippetImportInferenceClient = ( model : ModelDataMinimal , accessToken : string ) : string =>
8
+ `from huggingface_hub import InferenceClient
9
+
10
+ client = InferenceClient(${ model . id } , token="${ accessToken || "{API_TOKEN}" } ")` ;
11
+
7
12
export const snippetConversational = (
8
13
model : ModelDataMinimal ,
9
14
accessToken : string ,
@@ -184,18 +189,31 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
184
189
output = query(${ getModelInputSnippet ( model ) } )` ,
185
190
} ) ;
186
191
187
- export const snippetTextToImage = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
188
- content : `def query(payload):
192
+ export const snippetTextToImage = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => {
193
+ return [
194
+ {
195
+ client : "requests" ,
196
+ content : `def query(payload):
189
197
response = requests.post(API_URL, headers=headers, json=payload)
190
198
return response.content
199
+
191
200
image_bytes = query({
192
201
"inputs": ${ getModelInputSnippet ( model ) } ,
193
202
})
194
203
# You can access the image with PIL.Image for example
195
204
import io
196
205
from PIL import Image
197
206
image = Image.open(io.BytesIO(image_bytes))` ,
198
- } ) ;
207
+ } ,
208
+ {
209
+ client : "huggingface_hub" ,
210
+ content : `${ snippetImportInferenceClient ( model , accessToken ) }
211
+
212
+ # output is a PIL.Image object
213
+ image = client.text_to_image(${ getModelInputSnippet ( model ) } )` ,
214
+ } ,
215
+ ] ;
216
+ } ;
199
217
200
218
export const snippetTabular = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
201
219
content : `def query(payload):
@@ -300,6 +318,9 @@ export function getPythonInferenceSnippet(
300
318
if ( model . tags . includes ( "conversational" ) ) {
301
319
// Conversational model detected, so we display a code snippet that features the Messages API
302
320
return snippetConversational ( model , accessToken , opts ) ;
321
+ } else if ( model . pipeline_tag == "text-to-image" ) {
322
+ // TODO: factorize this logic
323
+ return snippetTextToImage ( model , accessToken ) ;
303
324
} else {
304
325
let snippets =
305
326
model . pipeline_tag && model . pipeline_tag in pythonSnippets
0 commit comments