better snippets for KerasHub models #1021

Closed
wants to merge 3 commits into from
10 changes: 10 additions & 0 deletions packages/tasks/src/model-data.ts
@@ -66,6 +66,16 @@ export interface ModelData {
       base_model_name_or_path?: string;
       task_type?: string;
     };
+    keras_hub_task_json?: {
+      class_name: string;
+      alt_class_names?: string[];
+    };
+    keras_hub_config_json?: {
+      class_name: string;
+    };
+    keras_hub_tokenizer_json?: {
+      class_name: string;
+    };
Comment on lines +69 to +78
Contributor

Slight preference for a more concise shape:

Suggested change
-keras_hub_task_json?: {
-  class_name: string;
-  alt_class_names?: string[];
-};
-keras_hub_config_json?: {
-  class_name: string;
-};
-keras_hub_tokenizer_json?: {
-  class_name: string;
-};
+keras_hub?: {
+  // relevant task.json content
+};

From what I understand, task.json is the future-proof way of getting this info correctly. And getting things from config.json/tokenizer.json is more of a default for previous models up to now. Is my assumption correct or not? If that's the case, then let's focus on parsing only task.json to only promote the "correct" way.

In any case (no matter if the config comes from task.json, config.json or tokenizer.json), I think that having a single field with nested values is better than exposing 3 different high-level fields related to keras_hub.
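
To make the single-field idea concrete, here is a rough sketch of how the three top-level fields could be folded into one nested `keras_hub` field. The sub-keys `task`, `config` and `tokenizer`, the interface names and the helper `nestKerasHubFields` are all hypothetical and not part of the PR; the class names are sample values only.

```typescript
// Shape currently proposed in the diff (three top-level fields).
interface SeparateKerasHubFields {
  keras_hub_task_json?: { class_name: string; alt_class_names?: string[] };
  keras_hub_config_json?: { class_name: string };
  keras_hub_tokenizer_json?: { class_name: string };
}

// Hypothetical single nested field; the sub-key names are an assumption.
interface NestedKerasHubField {
  keras_hub?: {
    task?: { class_name: string; alt_class_names?: string[] };
    config?: { class_name: string };
    tokenizer?: { class_name: string };
  };
}

// Fold the three top-level fields into one nested field, skipping absent entries.
function nestKerasHubFields(config: SeparateKerasHubFields): NestedKerasHubField {
  const nested: NonNullable<NestedKerasHubField["keras_hub"]> = {};
  if (config.keras_hub_task_json) nested.task = config.keras_hub_task_json;
  if (config.keras_hub_config_json) nested.config = config.keras_hub_config_json;
  if (config.keras_hub_tokenizer_json) nested.tokenizer = config.keras_hub_tokenizer_json;
  return Object.keys(nested).length > 0 ? { keras_hub: nested } : {};
}

// Sample input: a model that ships config.json and tokenizer.json but no task.json.
const nestedExample = nestKerasHubFields({
  keras_hub_config_json: { class_name: "GemmaBackbone" },
  keras_hub_tokenizer_json: { class_name: "GemmaTokenizer" },
});
console.log(JSON.stringify(nestedExample));
```

With a shape like this, consumers would only need to check one optional field, whichever file the values came from.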

Contributor

Taking a look at this recent keras-hub model (https://huggingface.co/evandrarf/health-care-gemma2-kagglex/tree/main), I can see that task.json, config.json, tokenizer.json and preprocessor.json are all set, and the content of task.json is strictly a superset of the other 3. Do you know if the other files are kept for backward compatibility?

Contributor Author

I agree with you that standardizing on a single config file would be best. Let me ask the Keras team.

Contributor Author

Matt on the Keras team responded: task.json is not always present, and that is by design rather than a legacy artifact. I recommend we deploy the currently implemented logic while we continue the discussion with Matt and possibly simplify later.

Contributor

I'd rather wait for a simplification before merging, unless it's time-sensitive.

Member

Yes, and IMO we can influence the standardization by supporting the simpler, single-field version that Wauplin mentions.

   };
   /**
    * all the model tags
67 changes: 59 additions & 8 deletions packages/tasks/src/model-libraries-snippets.ts
@@ -403,20 +403,71 @@ backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
 `,
 ];

-export const keras_hub = (model: ModelData): string[] => [
-  `# Available backend options are: "jax", "torch", "tensorflow".
+export function keras_hub(model: ModelData): string[] {
+  let class_name =
+    // If the model has a task.json config, then the base Task class is known
+    model.config?.keras_hub_task_json?.class_name ??
+    // If only a config.json is present, the base class will be a "backbone"
+    model.config?.keras_hub_config_json?.class_name;
Comment on lines +407 to +411
Contributor

Related to my comment above, if we can get rid of some logic by parsing only task.json, that would be for the best.
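
For reference, the two lookups being compared can be sketched in isolation as follows. The field names follow the diff; the interface `KerasHubConfig`, the function names and the sample class names are assumptions made for illustration.

```typescript
// Minimal stand-in for the relevant part of ModelData["config"].
interface KerasHubConfig {
  keras_hub_task_json?: { class_name: string };
  keras_hub_config_json?: { class_name: string };
}

// Lookup as written in the PR: task.json wins, config.json is the fallback.
function classNameWithFallback(config?: KerasHubConfig): string | undefined {
  return config?.keras_hub_task_json?.class_name ?? config?.keras_hub_config_json?.class_name;
}

// Lookup if only task.json were parsed, as proposed in this thread.
function classNameTaskOnly(config?: KerasHubConfig): string | undefined {
  return config?.keras_hub_task_json?.class_name;
}

// Sample inputs (class names chosen for illustration).
const withTaskJson: KerasHubConfig = {
  keras_hub_task_json: { class_name: "GemmaCausalLM" },
  keras_hub_config_json: { class_name: "GemmaBackbone" },
};
const backboneOnly: KerasHubConfig = {
  keras_hub_config_json: { class_name: "GemmaBackbone" },
};

console.log(classNameWithFallback(withTaskJson)); // GemmaCausalLM
console.log(classNameWithFallback(backboneOnly)); // GemmaBackbone
console.log(classNameTaskOnly(backboneOnly)); // undefined
```

The last call shows the trade-off: parsing only task.json removes a branch, but models without task.json would then need some generic default in the snippet.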


+  // Fallback heuristic until task.json is populated in more keras-hub models. For
+  // text-generation models only, display "XXXCausalLM" base class instead of XXXBackbone.
+  if (model.pipeline_tag == "text-generation" && class_name?.endsWith("Backbone"))
+    class_name = class_name.replace("Backbone", "CausalLM");
+
+  // optional generation snippets
+  const optional_snippets = [
+    ["text-generation", 'model.generate("Keras: deep learning for", max_length=64)'],
+    [
+      "image-text-to-text",
+      `output = model.generate(
+    inputs={
+        "images": image,
+        "prompts": prompt,
+    }
+)`,
+    ],
+  ];
+  // Select a text generation snippet based on pipeline_tag
+  const selected_snippet_row = optional_snippets.filter((cols) => cols[0] == model.pipeline_tag);
+  const optional_snippet = selected_snippet_row.length == 0 ? "" : selected_snippet_row[0][1];
+
+  // de-duplicate possible alt classes
+  // from task.json
+  const alt_class_names = new Set(model.config?.keras_hub_task_json?.alt_class_names);
+  if (class_name) alt_class_names.delete(class_name);
+  // and from tokenizer.json
+  if (model.config?.keras_hub_tokenizer_json?.class_name)
+    alt_class_names.add(model.config?.keras_hub_tokenizer_json?.class_name);
+  // generate possible alternative class.from_preset() calls.
+  let alt_model_component_snippets = undefined;
+  if (alt_class_names.size > 0) {
+    const alt_model_component_snippet_lines = Array.from(alt_class_names).map(
+      (k) => `model = keras_hub.models.${k}.from_preset("hf://${model.id}")`
+    );
+    alt_model_component_snippets =
+      "# Individual model components can also be loaded from this preset:\n" +
+      alt_model_component_snippet_lines.join("\n");
+  }
+
+  const main_snippet = ` # Available backend options are: "jax", "torch", "tensorflow".
 import os
 os.environ["KERAS_BACKEND"] = "jax"

 import keras_hub

-# Load a task-specific model (*replace CausalLM with your task*)
-model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
+model = keras_hub.models.${class_name}.from_preset("hf://${model.id}")
+${optional_snippet}

-# Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
-# full list here: https://keras.io/api/keras_hub/models/#api-documentation
-`,
-];
+# All Keras models support: model(data), model.compile, model.fit, model.predict, model.evaluate.
+# More info on this model: https://keras.io/search.html?query=${class_name}%20keras_hub
+`;
+  const snippets = [main_snippet];
+  if (alt_model_component_snippets) {
+    snippets.push(alt_model_component_snippets);
+  }
+  return snippets;
+}

export const llama_cpp_python = (model: ModelData): string[] => [
`from llama_cpp import Llama
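
The Backbone-to-CausalLM heuristic and the alternative-class de-duplication in the new `keras_hub` function can be exercised on their own. The sketch below mirrors that logic as standalone helpers; the helper names, the model id `user/some-model` and the class names are hypothetical and used only for illustration.

```typescript
// Swap "Backbone" for "CausalLM", but only for text-generation models.
function applyCausalLMHeuristic(className: string | undefined, pipelineTag?: string): string | undefined {
  if (pipelineTag === "text-generation" && className !== undefined && className.endsWith("Backbone")) {
    return className.replace("Backbone", "CausalLM");
  }
  return className;
}

// Alternative classes: task.json alternatives minus the main class, plus the tokenizer class.
function collectAltClassNames(mainClass: string | undefined, taskAlts: string[] = [], tokenizerClass?: string): string[] {
  const names = new Set<string>(taskAlts); // the Set drops duplicate entries
  if (mainClass) names.delete(mainClass);
  if (tokenizerClass) names.add(tokenizerClass);
  return Array.from(names); // a Set iterates in insertion order
}

// One from_preset() line per alternative class.
function altSnippetLines(modelId: string, names: string[]): string[] {
  return names.map((k) => `model = keras_hub.models.${k}.from_preset("hf://${modelId}")`);
}

const mainClass = applyCausalLMHeuristic("GemmaBackbone", "text-generation");
const altNames = collectAltClassNames(mainClass, ["GemmaBackbone", "GemmaCausalLM", "GemmaBackbone"], "GemmaTokenizer");
const altLines = altSnippetLines("user/some-model", altNames);

console.log(mainClass); // GemmaCausalLM
console.log(altLines.join("\n"));
```

Note that the heuristic only fires for the `text-generation` pipeline tag; for any other tag the backbone class name is passed through unchanged.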