Generate KerasHub snippets based on tasks from metadata.json
#1118
Changes from all commits
@@ -394,32 +394,103 @@ model = keras.saving.load_model("hf://${model.id}")
 `,
 ];
 
-export const keras_nlp = (model: ModelData): string[] => [
-	`# Available backend options are: "jax", "torch", "tensorflow".
-import os
-os.environ["KERAS_BACKEND"] = "jax"
-
-import keras_nlp
-
-tokenizer = keras_nlp.models.Tokenizer.from_preset("hf://${model.id}")
-backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
-`,
-];
-
-export const keras_hub = (model: ModelData): string[] => [
-	`# Available backend options are: "jax", "torch", "tensorflow".
-import os
-os.environ["KERAS_BACKEND"] = "jax"
-
-import keras_hub
-
-# Load a task-specific model (*replace CausalLM with your task*)
-model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
-
-# Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
-# full list here: https://keras.io/api/keras_hub/models/#api-documentation
-`,
-];
+const _keras_hub_causal_lm = (modelId: string): string => `
+import keras_hub
+
+# Load CausalLM model (optional: use half precision for inference)
+causal_lm = keras_hub.models.CausalLM.from_preset(${modelId}, dtype="bfloat16")
+causal_lm.compile(sampler="greedy")  # (optional) specify a sampler
+
+# Generate text
+causal_lm.generate("Keras: deep learning for", max_length=64)
+`;
+
+const _keras_hub_text_to_image = (modelId: string): string => `
+import keras_hub
+
+# Load TextToImage model (optional: use half precision for inference)
+text_to_image = keras_hub.models.TextToImage.from_preset(${modelId}, dtype="bfloat16")
+
+# Generate images with a TextToImage model.
+text_to_image.generate("Astronaut in a jungle")
+`;
+
+const _keras_hub_text_classifier = (modelId: string): string => `
+import keras_hub
+
+# Load TextClassifier model
+text_classifier = keras_hub.models.TextClassifier.from_preset(
+    ${modelId},
+    num_classes=2,
+)
+# Fine-tune
+text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])
+# Classify text
+text_classifier.predict(["Not my cup of tea."])
+`;
+
+const _keras_hub_image_classifier = (modelId: string): string => `
+import keras_hub
+import keras
+
+# Load ImageClassifier model
+image_classifier = keras_hub.models.ImageClassifier.from_preset(
+    ${modelId},
+    num_classes=2,
+)
+# Fine-tune
+image_classifier.fit(
+    x=keras.random.randint((32, 64, 64, 3), 0, 256),
+    y=keras.random.randint((32, 1), 0, 2),
+)
+# Classify image
+image_classifier.predict(keras.random.randint((1, 64, 64, 3), 0, 256))
+`;
+
+const _keras_hub_tasks_with_example = {
+	CausalLM: _keras_hub_causal_lm,
+	TextToImage: _keras_hub_text_to_image,
+	TextClassifier: _keras_hub_text_classifier,
+	ImageClassifier: _keras_hub_image_classifier,
+};
+
+const _keras_hub_task_without_example = (task: string, modelId: string): string => `
+import keras_hub
+
+# Create a ${task} model
+task = keras_hub.models.${task}.from_preset(${modelId})
+`;
+
+const _keras_hub_generic_backbone = (modelId: string): string => `
+import keras_hub
+
+# Create a Backbone model unspecialized for any task
+backbone = keras_hub.models.Backbone.from_preset(${modelId})
+`;
+
+export const keras_hub = (model: ModelData): string[] => {
+	const modelId = model.id;
+	const tasks = model.config?.keras_hub?.tasks ?? [];
+
+	const snippets: string[] = [];
+
+	// First, generate tasks with examples
+	for (const [task, snippet] of Object.entries(_keras_hub_tasks_with_example)) {
+		if (tasks.includes(task)) {
+			snippets.push(snippet(modelId));
+		}
+	}
+	// Then, add remaining tasks
+	for (const task of tasks) {
+		if (!Object.keys(_keras_hub_tasks_with_example).includes(task)) {
+			snippets.push(_keras_hub_task_without_example(task, modelId));
+		}
+	}
+	// Finally, add generic backbone snippet
+	snippets.push(_keras_hub_generic_backbone(modelId));
+
+	return snippets;
+};
 
 export const llama_cpp_python = (model: ModelData): string[] => {
 	const snippets = [

Inline review discussion on the _keras_hub_tasks_with_example mapping:

"we have a few more for ImageSegmenter, ImageToImage, Inpaint, ImageObjectDetector"

"That'd be great! It would be nice to settle this PR with the existing snippets first, and once we've agreed on the format / what we put in the snippets, open a follow-up PR with the remaining tasks."
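For tasks without a dedicated example (such as the ImageSegmenter mentioned above), the _keras_hub_task_without_example fallback would render to roughly the following, assuming the interpolated model id ends up as a quoted "hf://..." string as in the existing keras snippet; the model id below is a hypothetical placeholder:

import keras_hub

# Create an ImageSegmenter model from a Hub preset (placeholder model id)
task = keras_hub.models.ImageSegmenter.from_preset("hf://your-org/your-segmentation-model")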
Because our most popular classifiers won't have pretrained heads (thinking of the BERT-likes), I think it makes sense to show this with some toy fine-tuning, e.g.:
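A minimal sketch of the kind of toy fine-tuning being described, mirroring the TextClassifier snippet in the diff; the preset id and the two labelled sentences are illustrative placeholders, not part of the PR:

import keras_hub

# Hypothetical BERT-like preset with no pretrained classification head
text_classifier = keras_hub.models.TextClassifier.from_preset(
    "hf://your-org/your-bert-model",
    num_classes=2,
)
# Toy fine-tuning so the freshly attached head is no longer random
text_classifier.fit(x=["Thrilling adventure!", "Total snoozefest."], y=[1, 0])
# Classify new text
text_classifier.predict(["Not my cup of tea."])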
Is there a reason why we need to provide num_classes=2? Shouldn't this depend on the model we chose? I feel we should aim at showcasing how to use the model on the specific task it has been trained on (i.e. the simplest use case). If we add a num_classes=2 + a .fit(...) step, then a normal user won't be able to just use the model; they would need to know that they have to remove the num_classes parameter. @mattdangerw If we want to show how to fine-tune + other use cases, what do you think about linking to the base class documentation page like this?
no super strong opinion but in our experience it's good to keep snippets as minimal as possible (and focus on predict rather than fit)
Your call in the end though!
Thanks!

The issue is really with the BERT-likes (but those are fairly prominent in quick starts, etc.). Since these have no built-in head, if you try to create one without specifying num_classes you actually get an error. Something like...

I know the transformers answer here is to default these as "fill mask" models, and we could consider something like that, but I think it's probably better done as a follow-up. Almost everyone will be using these models for classification, and the code I'm showing above will work generally for classifiers (unlike the code currently in the PR).

Let me know if that makes sense! We could also ditch the fit() line, but we'd be showing random predictions because the heads won't be trained.
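A sketch of the failure mode being described, assuming a hypothetical BERT-like preset id; the exact exception KerasHub raises is not reproduced here, so the failing call is only indicated in a comment:

import keras_hub

preset = "hf://your-org/your-bert-model"  # hypothetical preset without a pretrained head

# With no built-in classification head, omitting num_classes is expected to raise an error:
# keras_hub.models.TextClassifier.from_preset(preset)

# Passing num_classes attaches a freshly initialised head, which is why the snippet
# fine-tunes with .fit(...) before calling .predict(...):
text_classifier = keras_hub.models.TextClassifier.from_preset(preset, num_classes=2)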
Makes perfect sense! Then we'll go with what you've suggested, e.g. num_classes + .fit example. Thanks for the explanation :)
updated in e608ac0