Generate KerasHub snippets based on tasks from metadata.json #1118

Merged
merged 4 commits into from
Jan 23, 2025

Wauplin
Contributor

@Wauplin Wauplin commented Jan 20, 2025

This PR updates the keras-hub snippets based on the new tasks field in metadata.json. This field is now uploaded for all KerasHub models (see keras-team/keras-hub#1997) and contains the list of tasks compatible with a given model. This allows us to generate multiple snippets when relevant. For instance, keras/stable_diffusion_3.5_large_turbo is compatible with the ImageToImage, Inpaint and TextToImage tasks.
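
A minimal sketch of the task-driven generation described above, written in Python for illustration only (the actual generator is not shown in this thread). It assumes metadata.json exposes tasks as a list of task-name strings; TASKS_WITH_SNIPPETS, load_snippet and snippets_from_metadata are hypothetical names, and the rendered snippets are simplified to a bare from_preset call.

```python
from __future__ import annotations

import json

# Tasks assumed to have a dedicated snippet (the four listed in the reviewed diff).
TASKS_WITH_SNIPPETS = ("CausalLM", "TextToImage", "TextClassifier", "ImageClassifier")


def load_snippet(task: str, model_id: str) -> str:
    """Render a simplified 'load the model' snippet for one task."""
    return (
        "import keras_hub\n\n"
        f"# Load {task} model\n"
        f'model = keras_hub.models.{task}.from_preset("hf://{model_id}")\n'
    )


def snippets_from_metadata(metadata_json: str, model_id: str) -> list[str]:
    """Return one snippet per task in metadata.json that has a template."""
    metadata = json.loads(metadata_json)
    tasks = metadata.get("tasks", [])  # assumption: a list of task names
    return [load_snippet(t, model_id) for t in tasks if t in TASKS_WITH_SNIPPETS]


# Example: three compatible tasks, only one of which has a template in this sketch.
example_metadata = '{"tasks": ["ImageToImage", "Inpaint", "TextToImage"]}'
for snippet in snippets_from_metadata(
    example_metadata, "keras/stable_diffusion_3.5_large_turbo"
):
    print(snippet)
```

In this sketch, tasks without a template are skipped silently, so only the TextToImage snippet is printed for the example metadata. A real generator might prefer a generic fallback snippet instead.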

For this PR to work, we'll need to parse the metadata.json file server-side. This is done in https://github.com/huggingface-internal/moon-landing/pull/11693 (private PR). We can merge these 2 PRs independently.

cc @martin-gorner @mattdangerw @SamanehSaadat who coordinated this


Note: I also removed the legacy keras-nlp library (only 18 remaining models).

Member

@julien-c julien-c left a comment

let's wait for some reviews from the Keras side, but looks great to me!! 💥

Thanks for pushing this across the finish line, we will have great support for Keras on the Hugging Face Hub. cc @fchollet for visibility too

CausalLM: _keras_hub_causal_lm,
TextToImage: _keras_hub_text_to_image,
TextClassifier: _keras_hub_text_classifier,
ImageClassifier: _keras_hub_image_classifier,

@divyashreepathihalli divyashreepathihalli Jan 21, 2025

We have a few more for ImageSegmenter, ImageToImage, Inpaint, ImageObjectDetector

Contributor Author

That'd be great! It would be nice to settle this PR with the existing snippets first and, once we've agreed on the format and what we put in the snippets, open a new PR with the remaining tasks.

@mattdangerw mattdangerw left a comment

Looks great! Mostly comments on the classifiers. I think showing them with a fit call will make the snippets a little more general.

import keras_hub

# Load a task-specific model (*replace CausalLM with your task*)
model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
# Load TextClassifier model

@mattdangerw mattdangerw Jan 22, 2025

Because our most popular classifiers won't have pretrained heads (thinking of BERT-likes), I think it makes sense to show this with some toy fine-tuning. E.g.

import keras_hub
import keras

# Load TextClassifier model
text_classifier = keras_hub.models.TextClassifier.from_preset(
    ${modelId},
    num_classes=2,
)
# Fine-tune
text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])
# Classify text
text_classifier.predict(["Not my cup of tea."])

Contributor Author

@Wauplin Wauplin Jan 22, 2025

Is there a reason why we need to provide num_classes=2? Shouldn't this depend on the model we choose? I feel that we should aim at showcasing how to use the model on the specific task it has been trained on (i.e. the simplest use case). If we add num_classes=2 and a .fit(...) step, then a normal user won't be able to just use the model. They would need to know they have to remove the num_classes parameter.


@mattdangerw If we want to show how to finetune + other use cases, what do you think about linking to the base class documentation page like this?

# Check out https://keras.io/keras_hub/api/base_classes/text_classifier/ for more examples.
import keras_hub

# Load TextClassifier model
text_classifier = keras_hub.models.TextClassifier.from_preset(${modelId})

# Classify text
text_classifier.predict("Keras is a multi-backend ML framework.")

Member

no super strong opinion but in our experience it's good to keep snippets as minimal as possible (and focus on predict rather than fit)

Your call in the end though!

@mattdangerw

Thanks!

The issue is really with the BERT-likes (but those are fairly prominent in quick starts, etc). Since these have no built-in head, if you try to create them without specifying num_classes you actually get an error. Something like...

>>> keras_hub.models.TextClassifier.from_preset("roberta_base_en")
Error you must supply num_classes to the classifier

I know the transformers answer here is to default these to "fill mask" models. We could consider something like that, but I think it's probably better done as a follow-up. Almost everyone will be using these models for classification, and the code I'm showing above will work generally for classifiers (unlike the code currently in the PR).

Let me know if that makes sense! We could also ditch the fit() line, but we'd be showing random predictions because the heads won't be trained.

Contributor Author

Makes perfect sense! Then we'll go with what you've suggested, i.e. the num_classes + .fit example. Thanks for the explanation :)
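
A rough sketch of how the agreed classifier snippet might be rendered for a given model, written in Python for illustration (the real generator is not shown in this thread). The function name text_classifier_snippet and the model id some-org/some-model are hypothetical placeholders; the "hf://" prefix follows the CausalLM snippet quoted earlier in the review. The ast.parse call only checks that the rendered text is syntactically valid Python, it does not load any model.

```python
import ast


def text_classifier_snippet(model_id: str) -> str:
    """Render the num_classes + fit + predict snippet for one model id."""
    return f'''import keras_hub

# Load TextClassifier model
text_classifier = keras_hub.models.TextClassifier.from_preset(
    "hf://{model_id}",
    num_classes=2,
)
# Fine-tune
text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])
# Classify text
text_classifier.predict(["Not my cup of tea."])
'''


# Hypothetical placeholder id, not a real model.
snippet = text_classifier_snippet("some-org/some-model")
ast.parse(snippet)  # raises SyntaxError if the rendered snippet is malformed
print(snippet)
```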

Contributor Author

updated in e608ac0

@divyashreepathihalli divyashreepathihalli left a comment

Sounds good! LGTM!

@Wauplin
Contributor Author

Wauplin commented Jan 23, 2025

(failing test is unrelated)

@Wauplin Wauplin merged commit 69ceb3a into main Jan 23, 2025
4 of 5 checks passed
@Wauplin Wauplin deleted the keras-hub-snippets branch January 23, 2025 09:11
@Wauplin Wauplin mentioned this pull request Jan 24, 2025