|
1 | 1 | import type { ModelData } from "./model-data";
|
2 |
| -import type { WidgetExampleTextInput } from "./widget-example"; |
| 2 | +import type { WidgetExampleTextInput, WidgetExampleSentenceSimilarityInput } from "./widget-example"; |
3 | 3 | import { LIBRARY_TASK_MAPPING } from "./library-to-tasks";
|
4 | 4 |
|
5 | 5 | const TAG_CUSTOM_CODE = "custom_code";
|
@@ -704,13 +704,32 @@ export const sampleFactory = (model: ModelData): string[] => [
|
704 | 704 | `python -m sample_factory.huggingface.load_from_hub -r ${model.id} -d ./train_dir`,
|
705 | 705 | ];
|
706 | 706 |
|
| 707 | +function get_widget_examples_from_st_model(model: ModelData): string[] | undefined { |
| 708 | + const widgetExample = model.widgetData?.[0] as WidgetExampleSentenceSimilarityInput | undefined; |
| 709 | + if (widgetExample) { |
| 710 | + return [widgetExample.source_sentence, ...widgetExample.sentences]; |
| 711 | + } |
| 712 | +} |
| 713 | + |
707 | 714 | export const sentenceTransformers = (model: ModelData): string[] => {
|
708 | 715 | const remote_code_snippet = model.tags.includes(TAG_CUSTOM_CODE) ? ", trust_remote_code=True" : "";
|
| 716 | + const exampleSentences = get_widget_examples_from_st_model(model) ?? [ |
| 717 | + "The weather is lovely today.", |
| 718 | + "It's so sunny outside!", |
| 719 | + "He drove to the stadium.", |
| 720 | + ]; |
709 | 721 |
|
710 | 722 | return [
|
711 | 723 | `from sentence_transformers import SentenceTransformer
|
712 | 724 |
|
713 |
| -model = SentenceTransformer("${model.id}"${remote_code_snippet})`, |
| 725 | +model = SentenceTransformer("${model.id}"${remote_code_snippet}) |
| 726 | +
|
| 727 | +sentences = ${JSON.stringify(exampleSentences, null, 4)} |
| 728 | +embeddings = model.encode(sentences) |
| 729 | +
|
| 730 | +similarities = model.similarity(embeddings, embeddings) |
| 731 | +print(similarities.shape) |
| 732 | +# [${exampleSentences.length}, ${exampleSentences.length}]`, |
714 | 733 | ];
|
715 | 734 | };
|
716 | 735 |
|
|
0 commit comments