Skip to content

Commit 3881d12

Browse files
martin-gornerpcuencaWauplinjulien-cVaibhavs10
authored
initial keras-hub support (#986)
Initial support for Keras-hub: - generic snippet for loading models - library definition with metadata This PR also adds "keras-hub" as one of the "filtered" i.e. top libraries and removes "tf-keras" from the list. This affects the default library filters in "Models" and the library completions in the Model details edition form. --------- Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Lucain <[email protected]> Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: vb <[email protected]>
1 parent b0225c4 commit 3881d12

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,9 @@ model = GLiNER.from_pretrained("${model.id}")`,
364364
];
365365

366366
export const keras = (model: ModelData): string[] => [
367-
`# Available backend options are: "jax", "tensorflow", "torch".
367+
`# Available backend options are: "jax", "torch", "tensorflow".
368368
import os
369-
os.environ["KERAS_BACKEND"] = "tensorflow"
369+
os.environ["KERAS_BACKEND"] = "jax"
370370
371371
import keras
372372
@@ -375,9 +375,9 @@ model = keras.saving.load_model("hf://${model.id}")
375375
];
376376

377377
export const keras_nlp = (model: ModelData): string[] => [
378-
`# Available backend options are: "jax", "tensorflow", "torch".
378+
`# Available backend options are: "jax", "torch", "tensorflow".
379379
import os
380-
os.environ["KERAS_BACKEND"] = "tensorflow"
380+
os.environ["KERAS_BACKEND"] = "jax"
381381
382382
import keras_nlp
383383
@@ -386,6 +386,21 @@ backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
386386
`,
387387
];
388388

389+
export const keras_hub = (model: ModelData): string[] => [
390+
`# Available backend options are: "jax", "torch", "tensorflow".
391+
import os
392+
os.environ["KERAS_BACKEND"] = "jax"
393+
394+
import keras_hub
395+
396+
# Load a task-specific model (*replace CausalLM with your task*)
397+
model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
398+
399+
# Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
400+
# full list here: https://keras.io/api/keras_hub/models/#api-documentation
401+
`,
402+
];
403+
389404
export const llama_cpp_python = (model: ModelData): string[] => [
390405
`from llama_cpp import Llama
391406

packages/tasks/src/model-libraries.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,16 +345,23 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
345345
repoUrl: "https://github.com/keras-team/tf-keras",
346346
docsUrl: "https://huggingface.co/docs/hub/tf-keras",
347347
snippets: snippets.tf_keras,
348-
filter: true,
349348
countDownloads: `path:"saved_model.pb"`,
350349
},
351350
"keras-nlp": {
352351
prettyLabel: "KerasNLP",
353352
repoName: "KerasNLP",
354-
repoUrl: "https://keras.io/keras_nlp/",
355-
docsUrl: "https://github.com/keras-team/keras-nlp",
353+
repoUrl: "https://github.com/keras-team/keras-nlp",
354+
docsUrl: "https://keras.io/keras_nlp/",
356355
snippets: snippets.keras_nlp,
357356
},
357+
"keras-hub": {
358+
prettyLabel: "KerasHub",
359+
repoName: "KerasHub",
360+
repoUrl: "https://github.com/keras-team/keras-hub",
361+
docsUrl: "https://keras.io/keras_hub/",
362+
snippets: snippets.keras_hub,
363+
filter: true,
364+
},
358365
k2: {
359366
prettyLabel: "K2",
360367
repoName: "k2",

0 commit comments

Comments
 (0)