Generate KerasHub snippets based on tasks from metadata.json #1118


Merged — 4 commits, merged on Jan 23, 2025
3 changes: 3 additions & 0 deletions packages/tasks/src/model-data.ts
@@ -66,6 +66,9 @@ export interface ModelData {
base_model_name_or_path?: string;
task_type?: string;
};
keras_hub?: {
tasks?: string[];
};
};
/**
* all the model tags
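The new `keras_hub.tasks` field above is read from a model repo's `config.json`. As a minimal sketch of how a consumer can branch on the declared tasks (the interface name and the sample task list below are illustrative assumptions, not taken from the PR):

```typescript
// Minimal sketch of a config object carrying the new metadata;
// the concrete task names are illustrative, not from a real repo.
interface KerasHubConfig {
	keras_hub?: {
		tasks?: string[];
	};
}

const config: KerasHubConfig = {
	keras_hub: { tasks: ["CausalLM", "TextClassifier"] },
};

// Snippet generation can then branch on the declared tasks,
// falling back to an empty list when the field is absent.
const tasks = config.keras_hub?.tasks ?? [];
console.log(tasks.includes("CausalLM")); // prints: true
```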
109 changes: 90 additions & 19 deletions packages/tasks/src/model-libraries-snippets.ts
@@ -394,32 +394,103 @@ model = keras.saving.load_model("hf://${model.id}")
`,
];

-export const keras_nlp = (model: ModelData): string[] => [
-	`# Available backend options are: "jax", "torch", "tensorflow".
-import os
-os.environ["KERAS_BACKEND"] = "jax"
-
-import keras_nlp
-
-tokenizer = keras_nlp.models.Tokenizer.from_preset("hf://${model.id}")
-backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
-`,
-];

const _keras_hub_causal_lm = (modelId: string): string => `
import keras_hub

# Load CausalLM model (optional: use half precision for inference)
causal_lm = keras_hub.models.CausalLM.from_preset("hf://${modelId}", dtype="bfloat16")
causal_lm.compile(sampler="greedy")  # (optional) specify a sampler

# Generate text
causal_lm.generate("Keras: deep learning for", max_length=64)
`;

-export const keras_hub = (model: ModelData): string[] => [
-	`# Available backend options are: "jax", "torch", "tensorflow".
-import os
-os.environ["KERAS_BACKEND"] = "jax"
-
-import keras_hub
-
-# Load a task-specific model (*replace CausalLM with your task*)
-model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")

const _keras_hub_text_to_image = (modelId: string): string => `
import keras_hub

# Load TextToImage model (optional: use half precision for inference)
text_to_image = keras_hub.models.TextToImage.from_preset("hf://${modelId}", dtype="bfloat16")

# Generate images with a TextToImage model.
text_to_image.generate("Astronaut in a jungle")
`;

const _keras_hub_text_classifier = (modelId: string): string => `
import keras_hub

# Load TextClassifier model
@mattdangerw commented on Jan 22, 2025:
Because our most popular classifiers won't have pretrained heads (thinking of bert-likes), I think it makes sense to show this with some toy fine-tuning. E.g.

import keras_hub
import keras

# Load TextClassifier model
text_classifier = keras_hub.models.TextClassifier.from_preset(
    ${modelId},
    num_classes=2,
)
# Fine-tune
text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])
# Classify text
text_classifier.predict(["Not my cup of tea."])

@Wauplin (Contributor, Author) commented on Jan 22, 2025:

Is there a reason why we need to provide num_classes=2? Shouldn't this depend on the model we chose? I feel we should aim at showcasing how to use the model on the specific task it has been trained on (i.e. the simplest use case). If we add num_classes=2 + a .fit(...) step, then a normal user won't be able to just use the model: they would need to know they have to remove the num_classes parameter.

@mattdangerw If we want to show how to fine-tune + other use cases, what do you think about linking to the base class documentation page like this?

# Check out https://keras.io/keras_hub/api/base_classes/text_classifier/ for more examples.
import keras_hub

# Load TextClassifier model
text_classifier = keras_hub.models.TextClassifier.from_preset(${modelId})

# Classify text
text_classifier.predict("Keras is a multi-backend ML framework.")

A maintainer (Member) commented:

no super strong opinion but in our experience it's good to keep snippets as minimal as possible (and focus on predict rather than fit)

Your call in the end though!

@mattdangerw replied:

Thanks!

The issue is really with the bert-likes (but those are fairly prominent in quick starts, etc.). Since these have no built-in head, if you try to create one without specifying num_classes you actually get an error. Something like...

>>> keras_hub.models.TextClassifier.from_preset("roberta_base_en")
Error you must supply num_classes to the classifier

I know the transformer answer here is to default these as "fill mask" models. And we could consider something like that but I think it's probably better done as a follow up. Almost everyone will be using these models for classification, and the code I'm showing above will work generally for classifiers (unlike the code in the PR currently).

Let me know if that makes sense! We could also ditch the fit() line, but we'd be showing random predictions because the heads won't be trained.

@Wauplin (Contributor, Author) replied:

Makes perfect sense! Then we'll go with what you've suggested, e.g. num_classes + .fit example. Thanks for the explanation :)

@Wauplin (Contributor, Author) added:

updated in e608ac0

text_classifier = keras_hub.models.TextClassifier.from_preset(
    "hf://${modelId}",
    num_classes=2,
)
# Fine-tune
text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])
# Classify text
text_classifier.predict(["Not my cup of tea."])
`;

-# Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
-# full list here: https://keras.io/api/keras_hub/models/#api-documentation
-`,
-];
const _keras_hub_image_classifier = (modelId: string): string => `
import keras_hub
import keras

# Load ImageClassifier model
image_classifier = keras_hub.models.ImageClassifier.from_preset(
    "hf://${modelId}",
num_classes=2,
)
# Fine-tune
image_classifier.fit(
x=keras.random.randint((32, 64, 64, 3), 0, 256),
y=keras.random.randint((32, 1), 0, 2),
)
# Classify image
image_classifier.predict(keras.random.randint((1, 64, 64, 3), 0, 256))
`;

const _keras_hub_tasks_with_example = {
CausalLM: _keras_hub_causal_lm,
TextToImage: _keras_hub_text_to_image,
TextClassifier: _keras_hub_text_classifier,
ImageClassifier: _keras_hub_image_classifier,
@divyashreepathihalli commented on Jan 21, 2025:
we have a few more for ImageSegmenter, ImageToImage, Inpaint, ImageObjectDetector

@Wauplin (Contributor, Author) replied:

That'd be great! It would be nice to settle this PR with the existing snippets first; once we agree on the format / what we put in the snippets, we can open a new PR with the remaining tasks.

};

const _keras_hub_task_without_example = (task: string, modelId: string): string => `
import keras_hub

# Create a ${task} model
task = keras_hub.models.${task}.from_preset("hf://${modelId}")
`;

const _keras_hub_generic_backbone = (modelId: string): string => `
import keras_hub

# Create a Backbone model unspecialized for any task
backbone = keras_hub.models.Backbone.from_preset("hf://${modelId}")
`;

export const keras_hub = (model: ModelData): string[] => {
const modelId = model.id;
const tasks = model.config?.keras_hub?.tasks ?? [];

const snippets: string[] = [];

// First, generate tasks with examples
for (const [task, snippet] of Object.entries(_keras_hub_tasks_with_example)) {
if (tasks.includes(task)) {
snippets.push(snippet(modelId));
}
}
// Then, add remaining tasks
for (const task of tasks) {
if (!Object.keys(_keras_hub_tasks_with_example).includes(task)) {
snippets.push(_keras_hub_task_without_example(task, modelId));
}
}
// Finally, add generic backbone snippet
snippets.push(_keras_hub_generic_backbone(modelId));

return snippets;
};
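To make the selection order in `keras_hub` concrete, here is a self-contained mini version of the logic (the helper name `snippetKinds` and the sample task list are illustrative assumptions): tasks with curated examples come first in the map's order, remaining tasks get a bare snippet, and the generic Backbone snippet always closes the list.

```typescript
// Self-contained sketch of the snippet-ordering logic above.
const tasksWithExample = ["CausalLM", "TextToImage", "TextClassifier", "ImageClassifier"];

function snippetKinds(tasks: string[]): string[] {
	const kinds: string[] = [];
	// 1. Tasks that have a curated example, in the map's order.
	for (const task of tasksWithExample) {
		if (tasks.includes(task)) kinds.push(`example:${task}`);
	}
	// 2. Remaining tasks get a bare from_preset snippet.
	for (const task of tasks) {
		if (!tasksWithExample.includes(task)) kinds.push(`generic:${task}`);
	}
	// 3. Always end with the generic Backbone snippet.
	kinds.push("backbone");
	return kinds;
}

console.log(snippetKinds(["ImageSegmenter", "CausalLM"]));
// → ["example:CausalLM", "generic:ImageSegmenter", "backbone"]
```

Note that a model declaring no tasks at all still yields the Backbone snippet, so the function never returns an empty list.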

export const llama_cpp_python = (model: ModelData): string[] => {
const snippets = [
7 changes: 0 additions & 7 deletions packages/tasks/src/model-libraries.ts
@@ -404,13 +404,6 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
snippets: snippets.tf_keras,
countDownloads: `path:"saved_model.pb"`,
},
"keras-nlp": {
prettyLabel: "KerasNLP",
repoName: "KerasNLP",
repoUrl: "https://github.com/keras-team/keras-nlp",
docsUrl: "https://keras.io/keras_nlp/",
snippets: snippets.keras_nlp,
},
"keras-hub": {
prettyLabel: "KerasHub",
repoName: "KerasHub",