@@ -394,32 +394,91 @@ model = keras.saving.load_model("hf://${model.id}")
394
394
` ,
395
395
] ;
396
396
397
- export const keras_nlp = ( model : ModelData ) : string [ ] => [
398
- `# Available backend options are: "jax", "torch", "tensorflow".
399
- import os
400
- os.environ["KERAS_BACKEND"] = "jax"
397
+ const _keras_hub_causal_lm = ( modelId : string ) : string => `
398
+ import keras_hub
401
399
402
- import keras_nlp
400
+ # Load CausalLM model (optional: use half precision for inference)
401
+ causal_lm = keras_hub.models.CausalLM.from_preset(${ modelId } , dtype="bfloat16")
402
+ causal_lm.compile(sampler="greedy") # (optional) specify a sampler
403
403
404
- tokenizer = keras_nlp.models.Tokenizer.from_preset("hf://${ model . id } ")
405
- backbone = keras_nlp.models.Backbone.from_preset("hf://${ model . id } ")
406
- ` ,
407
- ] ;
404
+ # Generate text
405
+ causal_lm.generate("Keras: deep learning for", max_length=64)
406
+ ` ;
408
407
409
- export const keras_hub = ( model : ModelData ) : string [ ] => [
410
- `# Available backend options are: "jax", "torch", "tensorflow".
411
- import os
412
- os.environ["KERAS_BACKEND"] = "jax"
408
+ const _keras_hub_text_to_image = ( modelId : string ) : string => `
409
+ import keras_hub
410
+
411
+ # Load TextToImage model (optional: use half precision for inference)
412
+ text_to_image = keras_hub.models.TextToImage.from_preset(${ modelId } , dtype="bfloat16")
413
+
414
+ # Generate images with a TextToImage model.
415
+ text_to_image.generate("Astronaut in a jungle")
416
+ ` ;
413
417
418
+ const _keras_hub_text_classifier = ( modelId : string ) : string => `
414
419
import keras_hub
415
420
416
- # Load a task-specific model (*replace CausalLM with your task*)
417
- model = keras_hub.models.CausalLM .from_preset("hf:// ${ model . id } ", dtype="bfloat16" )
421
+ # Load TextClassifier model
422
+ text_classifier = keras_hub.models.TextClassifier .from_preset(${ modelId } )
418
423
419
- # Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
420
- # full list here: https://keras.io/api/keras_hub/models/#api-documentation
421
- ` ,
422
- ] ;
424
+ # Classify text
425
+ text_classifier.predict("Keras is a multi-backend ML framework.")
426
+ ` ;
427
+
428
+ const _keras_hub_image_classifier = ( modelId : string ) : string => `
429
+ import keras_hub
430
+
431
+ # Load ImageClassifier model
432
+ text_classifier = keras_hub.models.ImageClassifier.from_preset(${ modelId } )
433
+
434
+ # Classify image
435
+ image_classifier.predict(keras.ops.ones((1, 64, 64, 3)))
436
+ ` ;
437
+
438
+ const _keras_hub_tasks_with_example = {
439
+ CausalLM : _keras_hub_causal_lm ,
440
+ TextToImage : _keras_hub_text_to_image ,
441
+ TextClassifier : _keras_hub_text_classifier ,
442
+ ImageClassifier : _keras_hub_image_classifier ,
443
+ } ;
444
+
445
+ const _keras_hub_task_without_example = ( task : string , modelId : string ) : string => `
446
+ import keras_hub
447
+
448
+ # Create a ${ task } model
449
+ task = keras_hub.models.${ task } .from_preset(${ modelId } )
450
+ ` ;
451
+
452
+ const _keras_hub_generic_backbone = ( modelId : string ) : string => `
453
+ import keras_hub
454
+
455
+ # Create a Backbone model unspecialized for any task
456
+ backbone = keras_hub.models.Backbone.from_preset(${ modelId } )
457
+ ` ;
458
+
459
+ export const keras_hub = ( model : ModelData ) : string [ ] => {
460
+ const modelId = model . id ;
461
+ const tasks = model . config ?. keras_hub ?. tasks ?? [ ] ;
462
+
463
+ const snippets : string [ ] = [ ] ;
464
+
465
+ // First, generate tasks with examples
466
+ for ( const [ task , snippet ] of Object . entries ( _keras_hub_tasks_with_example ) ) {
467
+ if ( tasks . includes ( task ) ) {
468
+ snippets . push ( snippet ( modelId ) ) ;
469
+ }
470
+ }
471
+ // Then, add remaining tasks
472
+ for ( const task in tasks ) {
473
+ if ( ! Object . keys ( _keras_hub_tasks_with_example ) . includes ( task ) ) {
474
+ snippets . push ( _keras_hub_task_without_example ( task , modelId ) ) ;
475
+ }
476
+ }
477
+ // Finally, add generic backbone snippet
478
+ snippets . push ( _keras_hub_generic_backbone ( modelId ) ) ;
479
+
480
+ return snippets ;
481
+ } ;
423
482
424
483
export const llama_cpp_python = ( model : ModelData ) : string [ ] => {
425
484
const snippets = [
0 commit comments