Skip to content

Commit 6949f03

Browse files
committed
Generate KerasHub snippets based on tasks
1 parent 8542173 commit 6949f03

File tree

2 files changed

+81
-19
lines changed

2 files changed

+81
-19
lines changed

packages/tasks/src/model-data.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ export interface ModelData {
6666
base_model_name_or_path?: string;
6767
task_type?: string;
6868
};
69+
keras_hub?: {
70+
tasks?: string[];
71+
};
6972
};
7073
/**
7174
* all the model tags

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -394,32 +394,91 @@ model = keras.saving.load_model("hf://${model.id}")
394394
`,
395395
];
396396

397-
export const keras_nlp = (model: ModelData): string[] => [
398-
`# Available backend options are: "jax", "torch", "tensorflow".
399-
import os
400-
os.environ["KERAS_BACKEND"] = "jax"
397+
const _keras_hub_causal_lm = (modelId: string): string => `
398+
import keras_hub
401399
402-
import keras_nlp
400+
# Load CausalLM model (optional: use half precision for inference)
401+
causal_lm = keras_hub.models.CausalLM.from_preset(${modelId}, dtype="bfloat16")
402+
causal_lm.compile(sampler="greedy") # (optional) specify a sampler
403403
404-
tokenizer = keras_nlp.models.Tokenizer.from_preset("hf://${model.id}")
405-
backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
406-
`,
407-
];
404+
# Generate text
405+
causal_lm.generate("Keras: deep learning for", max_length=64)
406+
`;
408407

409-
export const keras_hub = (model: ModelData): string[] => [
410-
`# Available backend options are: "jax", "torch", "tensorflow".
411-
import os
412-
os.environ["KERAS_BACKEND"] = "jax"
408+
const _keras_hub_text_to_image = (modelId: string): string => `
409+
import keras_hub
410+
411+
# Load TextToImage model (optional: use half precision for inference)
412+
text_to_image = keras_hub.models.TextToImage.from_preset(${modelId}, dtype="bfloat16")
413+
414+
# Generate images with a TextToImage model.
415+
text_to_image.generate("Astronaut in a jungle")
416+
`;
413417

418+
const _keras_hub_text_classifier = (modelId: string): string => `
414419
import keras_hub
415420
416-
# Load a task-specific model (*replace CausalLM with your task*)
417-
model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
421+
# Load TextClassifier model
422+
text_classifier = keras_hub.models.TextClassifier.from_preset(${modelId})
418423
419-
# Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
420-
# full list here: https://keras.io/api/keras_hub/models/#api-documentation
421-
`,
422-
];
424+
# Classify text
425+
text_classifier.predict("Keras is a multi-backend ML framework.")
426+
`;
427+
428+
const _keras_hub_image_classifier = (modelId: string): string => `
429+
import keras_hub
430+
431+
# Load ImageClassifier model
432+
text_classifier = keras_hub.models.ImageClassifier.from_preset(${modelId})
433+
434+
# Classify image
435+
image_classifier.predict(keras.ops.ones((1, 64, 64, 3)))
436+
`;
437+
438+
const _keras_hub_tasks_with_example = {
439+
CausalLM: _keras_hub_causal_lm,
440+
TextToImage: _keras_hub_text_to_image,
441+
TextClassifier: _keras_hub_text_classifier,
442+
ImageClassifier: _keras_hub_image_classifier,
443+
};
444+
445+
const _keras_hub_task_without_example = (task: string, modelId: string): string => `
446+
import keras_hub
447+
448+
# Create a ${task} model
449+
task = keras_hub.models.${task}.from_preset(${modelId})
450+
`;
451+
452+
const _keras_hub_generic_backbone = (modelId: string): string => `
453+
import keras_hub
454+
455+
# Create a Backbone model unspecialized for any task
456+
backbone = keras_hub.models.Backbone.from_preset(${modelId})
457+
`;
458+
459+
export const keras_hub = (model: ModelData): string[] => {
460+
const modelId = model.id;
461+
const tasks = model.config?.keras_hub?.tasks ?? [];
462+
463+
const snippets: string[] = [];
464+
465+
// First, generate tasks with examples
466+
for (const [task, snippet] of Object.entries(_keras_hub_tasks_with_example)) {
467+
if (tasks.includes(task)) {
468+
snippets.push(snippet(modelId));
469+
}
470+
}
471+
// Then, add remaining tasks
472+
for (const task in tasks) {
473+
if (!Object.keys(_keras_hub_tasks_with_example).includes(task)) {
474+
snippets.push(_keras_hub_task_without_example(task, modelId));
475+
}
476+
}
477+
// Finally, add generic backbone snippet
478+
snippets.push(_keras_hub_generic_backbone(modelId));
479+
480+
return snippets;
481+
};
423482

424483
export const llama_cpp_python = (model: ModelData): string[] => {
425484
const snippets = [

0 commit comments

Comments
 (0)