Generate KerasHub snippets based on tasks from metadata.json
#1118
let's wait for some reviews from the Keras side, but looks great to me!! 💥
Thanks for pushing this across the finish line, we will have great support for Keras on the Hugging Face Hub. cc @fchollet for visibility too
CausalLM: _keras_hub_causal_lm,
TextToImage: _keras_hub_text_to_image,
TextClassifier: _keras_hub_text_classifier,
ImageClassifier: _keras_hub_image_classifier,
We have a few more: ImageSegmenter, ImageToImage, Inpaint, ImageObjectDetector
That'd be great! It would be nice to settle this PR with the existing snippets first. Once we've agreed on the format and on what we put in the snippets, we can open a new PR with the remaining tasks.
Looks great! Mostly comments on the classifiers. I think showing them with a fit call will make the snippets a little more general.
import keras_hub

# Load a task-specific model (*replace CausalLM with your task*)
model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
# Load TextClassifier model
Because our most popular classifiers won't have pretrained heads (thinking of bert likes), I think this makes sense to show with some toy fine-tuning. E.g.
import keras_hub
import keras
# Load TextClassifier model
text_classifier = keras_hub.models.TextClassifier.from_preset(
${modelId},
num_classes=2,
)
# Fine-tune
text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])
# Classify text
text_classifier.predict(["Not my cup of tea."])
Is there a reason why we need to provide num_classes=2? Shouldn't this depend on the model we chose? I feel that we should aim at showcasing how to use the model on the specific task it has been trained on (i.e. the simplest use case). If we add a num_classes=2 + a .fit(...) step, then a normal user won't be able to just use the model. They would need to know they have to remove the num_classes parameter.
@mattdangerw If we want to show how to finetune + other use cases, what do you think about linking to the base class documentation page like this?
# Check out https://keras.io/keras_hub/api/base_classes/text_classifier/ for more examples.
import keras_hub
# Load TextClassifier model
text_classifier = keras_hub.models.TextClassifier.from_preset(${modelId})
# Classify text
text_classifier.predict("Keras is a multi-backend ML framework.")
no super strong opinion but in our experience it's good to keep snippets as minimal as possible (and focus on predict rather than fit)
Your call in the end though!
Thanks!
The issue is really with the bert-likes (but those are fairly prominent in quick starts, etc). Since these have no built-in head, if you try to create them without specifying num_classes you actually get an error. Something like...
>>> keras_hub.models.TextClassifier.from_preset("roberta_base_en")
Error you must supply num_classes to the classifier
I know the transformer answer here is to default these as "fill mask" models. And we could consider something like that but I think it's probably better done as a follow up. Almost everyone will be using these models for classification, and the code I'm showing above will work generally for classifiers (unlike the code in the PR currently).
Let me know if that makes sense! We could also ditch the fit() line, but we'd be showing random predictions because the heads won't be trained.
Makes perfect sense! Then we'll go with what you've suggested, e.g. num_classes + .fit example. Thanks for the explanation :)
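For reference, the snippet shape agreed on above (num_classes plus a toy fit call, then predict) can be sketched as a small template function. This is only an illustrative Python sketch: the generator in this PR is written in TypeScript, the helper name text_classifier_snippet is hypothetical, the "hf://" prefix is borrowed from the CausalLM line in the diff, and the model id used at the end is a placeholder rather than a real repository.

```python
def text_classifier_snippet(model_id: str, num_classes: int = 2) -> str:
    """Render a TextClassifier usage snippet (hypothetical helper, not the PR's code)."""
    lines = [
        "import keras_hub",
        "",
        "# Load TextClassifier model",
        "text_classifier = keras_hub.models.TextClassifier.from_preset(",
        f'    "hf://{model_id}",',  # assumption: same hf:// prefix as the CausalLM snippet
        f"    num_classes={num_classes},",
        ")",
        "# Fine-tune",
        'text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])',
        "# Classify text",
        'text_classifier.predict(["Not my cup of tea."])',
    ]
    return "\n".join(lines)


# Placeholder model id, for illustration only.
print(text_classifier_snippet("some-user/some-classifier"))
```

Keeping num_classes as a parameter of the template makes it easy to drop or change later if head-less models get a different default.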
updated in e608ac0
Sounds good! LGTM!
(failing test is unrelated)
TIL https://stackoverflow.com/a/29286412. See https://huggingface.co/keras/stable_diffusion_3.5_large?library=keras-hub (invalid snippets) (follow-up PR after #1118)
Another oversight in my KerasHub PR: #1118 Currently on https://huggingface.co/keras/stable_diffusion_3.5_large?library=keras-hub: 
This PR updates the keras-hub snippets based on the new metadata.json > tasks field. This field is now uploaded for all KerasHub models (see keras-team/keras-hub#1997) and contains the list of tasks compatible with a given model. This allows us to generate multiple snippets when relevant. For instance, keras/stable_diffusion_3.5_large_turbo is compatible with the ImageToImage, Inpaint and TextToImage tasks.

For this PR to work, we'll need to parse the metadata.json file server-side. This is done in https://github.com/huggingface-internal/moon-landing/pull/11693 (private PR). We can merge these 2 PRs independently.
cc @martin-gorner @mattdangerw @SamanehSaadat who coordinated this
Note: I also removed the legacy keras-nlp library (only 18 remaining models).
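To make the idea concrete, here is a rough Python sketch of how one snippet per task could be derived from the tasks list. It is not the code in this PR (which is TypeScript and uses dedicated templates per task). The helper names load_tasks and snippets_for are hypothetical, and the JSON shape (a top-level "tasks" key holding a list of task class names) is an assumption based on the description above.

```python
import json


def load_tasks(metadata_text: str) -> list[str]:
    """Extract the task names from a metadata.json payload (assumed shape)."""
    metadata = json.loads(metadata_text)
    tasks = metadata.get("tasks", [])
    # Keep only string entries so a malformed file cannot break snippet generation.
    return [task for task in tasks if isinstance(task, str)]


def snippets_for(model_id: str, tasks: list[str]) -> list[str]:
    """Build one generic loading snippet per task (hypothetical helper)."""
    return [
        "import keras_hub\n\n"
        f"# Load {task} model\n"
        f'model = keras_hub.models.{task}.from_preset("hf://{model_id}")'
        for task in tasks
    ]


# Example payload mirroring the tasks listed above for this model (shape assumed).
example = '{"tasks": ["ImageToImage", "Inpaint", "TextToImage"]}'
for snippet in snippets_for("keras/stable_diffusion_3.5_large_turbo", load_tasks(example)):
    print(snippet, end="\n\n")
```

The sketch returns an empty list when no tasks are present. What the real generator does in that case is not covered here.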