Skip to content

Commit 468dadd

Browse files
author
Raghav Dhall
committed
documentation: Add dynamic model tables
1 parent f35d0d3 commit 468dadd

File tree

1 file changed

+117
-32
lines changed

1 file changed

+117
-32
lines changed

doc/doc_utils/jumpstart_doc_utils.py

Lines changed: 117 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ class ProblemTypes(str, Enum):
5757
TABULAR_REGRESSION = "Regression"
5858
TABULAR_CLASSIFICATION = "Classification"
5959

60+
class Frameworks(str, Enum):
    """Display names of the frameworks/sources a JumpStart model can come from.

    Members subclass ``str``, so each one compares equal to its display name
    (``Frameworks.PYTORCH == "Pytorch Hub"``), and ``Frameworks("Pytorch Hub")``
    returns the matching member.
    """

    # Member order is preserved when iterating over the enum.
    TENSORFLOW = "Tensorflow Hub"
    PYTORCH = "Pytorch Hub"
    HUGGINGFACE = "HuggingFace"
    CATBOOST = "Catboost"
    GLUONCV = "GluonCV"
    LIGHTGBM = "LightGBM"
    XGBOOST = "XGBoost"
    SCIKIT_LEARN = "ScikitLearn"
    SOURCE = "Source"
72+
6073

6174
JUMPSTART_REGION = "eu-west-2"
6275
SDK_MANIFEST_FILE = "models_manifest.json"
@@ -82,6 +95,40 @@ class ProblemTypes(str, Enum):
8295
Tasks.TABULAR_CLASSIFICATION: ProblemTypes.TABULAR_CLASSIFICATION,
8396
}
8497

98+
# Lookup from a framework display string to its ``Frameworks`` member.
# Built from the enum itself so the two cannot drift apart: keys are the
# member values ("Tensorflow Hub", "Pytorch Hub", ..., "Source") in
# definition order, exactly as the previous hand-written literal listed them.
TO_FRAMEWORK = {framework.value: framework for framework in Frameworks}
109+
110+
111+
# Maps a (task, framework) pair to the per-algorithm .rst page that should
# receive a model table for that combination. Pairs that are not listed here
# get no per-algorithm table.
MODALITY_MAP = {
    # Vision
    (Tasks.IC, Frameworks.PYTORCH): "algorithms/vision/image_classification_pytorch.rst",
    (Tasks.IC, Frameworks.TENSORFLOW): "algorithms/vision/image_classification_tensorflow.rst",
    (Tasks.IC_EMBEDDING, Frameworks.TENSORFLOW): "algorithms/vision/image_embedding_tensorflow.rst",
    (Tasks.IS, Frameworks.GLUONCV): "algorithms/vision/instance_segmentation_mxnet.rst",
    (Tasks.OD, Frameworks.GLUONCV): "algorithms/vision/object_detection_mxnet.rst",
    (Tasks.OD, Frameworks.PYTORCH): "algorithms/vision/object_detection_pytorch.rst",
    (Tasks.OD, Frameworks.TENSORFLOW): "algorithms/vision/object_detection_tensorflow.rst",
    (Tasks.SEMSEG, Frameworks.GLUONCV): "algorithms/vision/semantic_segmentation_mxnet.rst",
    # Text
    (Tasks.TRANSLATION, Frameworks.HUGGINGFACE): "algorithms/text/machine_translation_hugging_face.rst",
    # Keyed on HUGGINGFACE (was GLUONCV) so the key agrees with the
    # hugging_face page it points at.
    # NOTE(review): assumes NER models resolve to the HuggingFace source - confirm.
    (Tasks.NER, Frameworks.HUGGINGFACE): "algorithms/text/named_entity_recognition_hugging_face.rst",
    (Tasks.EQA, Frameworks.PYTORCH): "algorithms/text/question_answering_pytorch.rst",
    (Tasks.SPC, Frameworks.HUGGINGFACE): "algorithms/text/sentence_pair_classification_hugging_face.rst",
    (Tasks.SPC, Frameworks.TENSORFLOW): "algorithms/text/sentence_pair_classification_tensorflow.rst",
    (Tasks.TC, Frameworks.TENSORFLOW): "algorithms/text/text_classification_tensorflow.rst",
    # Both text-embedding frameworks share a single page.
    (Tasks.TC_EMBEDDING, Frameworks.GLUONCV): "algorithms/text/text_embedding_tensorflow_mxnet.rst",
    (Tasks.TC_EMBEDDING, Frameworks.TENSORFLOW): "algorithms/text/text_embedding_tensorflow_mxnet.rst",
    (Tasks.TEXT_GENERATION, Frameworks.HUGGINGFACE): "algorithms/text/text_generation_hugging_face.rst",
    (Tasks.SUMMARIZATION, Frameworks.HUGGINGFACE): "algorithms/text/text_summarization_hugging_face.rst",
}
131+
85132

86133
def get_jumpstart_sdk_manifest():
87134
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE)
@@ -102,6 +149,10 @@ def get_model_task(id):
102149
return TASK_MAP[task_short] if task_short in TASK_MAP else "Source"
103150

104151

152+
def get_string_model_task(id):
    """Return the second hyphen-separated segment of a model ID.

    For an ID such as ``"pytorch-ic-mobilenet-v2"`` this returns ``"ic"``.
    The caller uses the result as the task half of a ``MODALITY_MAP`` key.

    Args:
        id: Model ID string. The name shadows the ``id`` builtin but is kept
            so existing keyword callers continue to work.

    Returns:
        The text between the first and second ``"-"`` (or everything after
        the first ``"-"`` when there is only one).

    Raises:
        IndexError: If ``id`` contains no ``"-"``.
    """
    segments = id.split("-")
    return segments[1]
154+
155+
105156
def get_model_source(url):
106157
if "tfhub" in url:
107158
return "Tensorflow Hub"
@@ -113,8 +164,6 @@ def get_model_source(url):
113164
return "Catboost"
114165
if "gluon" in url:
115166
return "GluonCV"
116-
if "catboost" in url:
117-
return "Catboost"
118167
if "lightgbm" in url:
119168
return "LightGBM"
120169
if "xgboost" in url:
@@ -138,58 +187,94 @@ def create_jumpstart_model_table():
138187
) < Version(model["version"]):
139188
sdk_manifest_top_versions_for_models[model["model_id"]] = model
140189

141-
file_content = []
190+
file_content_intro = []
142191

143-
file_content.append(".. _all-pretrained-models:\n\n")
144-
file_content.append(".. |external-link| raw:: html\n\n")
145-
file_content.append(' <i class="fa fa-external-link"></i>\n\n')
192+
file_content_intro.append(".. _all-pretrained-models:\n\n")
193+
file_content_intro.append(".. |external-link| raw:: html\n\n")
194+
file_content_intro.append(' <i class="fa fa-external-link"></i>\n\n')
146195

147-
file_content.append("================================================\n")
148-
file_content.append("Built-in Algorithms with pre-trained Model Table\n")
149-
file_content.append("================================================\n")
150-
file_content.append(
196+
file_content_intro.append("================================================\n")
197+
file_content_intro.append("Built-in Algorithms with pre-trained Model Table\n")
198+
file_content_intro.append("================================================\n")
199+
file_content_intro.append(
151200
"""
152201
The SageMaker Python SDK uses model IDs and model versions to access the necessary
153202
utilities for pre-trained models. This table serves to provide the core material plus
154203
some extra information that can be useful in selecting the correct model ID and
155204
corresponding parameters.\n"""
156205
)
157-
file_content.append(
206+
file_content_intro.append(
158207
"""
159208
If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute.
160209
We highly suggest pinning an exact model version however.\n"""
161210
)
162-
file_content.append(
211+
file_content_intro.append(
163212
"""
164213
These models are also available through the
165214
`JumpStart UI in SageMaker Studio <https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html>`__\n"""
166215
)
167-
file_content.append("\n")
168-
file_content.append(".. list-table:: Available Models\n")
169-
file_content.append(" :widths: 50 20 20 20 30 20\n")
170-
file_content.append(" :header-rows: 1\n")
171-
file_content.append(" :class: datatable\n")
172-
file_content.append("\n")
173-
file_content.append(" * - Model ID\n")
174-
file_content.append(" - Fine Tunable?\n")
175-
file_content.append(" - Latest Version\n")
176-
file_content.append(" - Min SDK Version\n")
177-
file_content.append(" - Problem Type\n")
178-
file_content.append(" - Source\n")
216+
file_content_intro.append("\n")
217+
file_content_intro.append(".. list-table:: Available Models\n")
218+
file_content_intro.append(" :widths: 50 20 20 20 30 20\n")
219+
file_content_intro.append(" :header-rows: 1\n")
220+
file_content_intro.append(" :class: datatable\n")
221+
file_content_intro.append("\n")
222+
file_content_intro.append(" * - Model ID\n")
223+
file_content_intro.append(" - Fine Tunable?\n")
224+
file_content_intro.append(" - Latest Version\n")
225+
file_content_intro.append(" - Min SDK Version\n")
226+
file_content_intro.append(" - Problem Type\n")
227+
file_content_intro.append(" - Source\n")
228+
229+
dynamic_table_files = []
230+
file_content_entries = []
179231

180232
for model in sdk_manifest_top_versions_for_models.values():
181233
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
182234
model_task = get_model_task(model_spec["model_id"])
235+
string_model_task = get_string_model_task(model_spec["model_id"])
183236
model_source = get_model_source(model_spec["url"])
184-
file_content.append(" * - {}\n".format(model_spec["model_id"]))
185-
file_content.append(" - {}\n".format(model_spec["training_supported"]))
186-
file_content.append(" - {}\n".format(model["version"]))
187-
file_content.append(" - {}\n".format(model["min_version"]))
188-
file_content.append(" - {}\n".format(model_task))
189-
file_content.append(
237+
file_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
238+
file_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
239+
file_content_entries.append(" - {}\n".format(model["version"]))
240+
file_content_entries.append(" - {}\n".format(model["min_version"]))
241+
file_content_entries.append(" - {}\n".format(model_task))
242+
file_content_entries.append(
190243
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"])
191244
)
192245

193-
f = open("doc_utils/pretrainedmodels.rst", "w")
194-
f.writelines(file_content)
246+
if (string_model_task, TO_FRAMEWORK[model_source]) in MODALITY_MAP:
247+
file_content_single_entry = []
248+
249+
if MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])] not in dynamic_table_files:
250+
file_content_single_entry.append("\n")
251+
file_content_single_entry.append(".. list-table:: Available Models\n")
252+
file_content_single_entry.append(" :widths: 50 20 20 20 30 20\n")
253+
file_content_single_entry.append(" :header-rows: 1\n")
254+
file_content_single_entry.append(" :class: datatable\n")
255+
file_content_single_entry.append("\n")
256+
file_content_single_entry.append(" * - Model ID\n")
257+
file_content_single_entry.append(" - Fine Tunable?\n")
258+
file_content_single_entry.append(" - Latest Version\n")
259+
file_content_single_entry.append(" - Min SDK Version\n")
260+
file_content_single_entry.append(" - Problem Type\n")
261+
file_content_single_entry.append(" - Source\n")
262+
263+
dynamic_table_files.append(MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])])
264+
265+
file_content_single_entry.append(" * - {}\n".format(model_spec["model_id"]))
266+
file_content_single_entry.append(" - {}\n".format(model_spec["training_supported"]))
267+
file_content_single_entry.append(" - {}\n".format(model["version"]))
268+
file_content_single_entry.append(" - {}\n".format(model["min_version"]))
269+
file_content_single_entry.append(" - {}\n".format(model_task))
270+
file_content_single_entry.append(
271+
" - `{} <{}>`__ \n".format(model_source, model_spec["url"])
272+
)
273+
f = open(MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])], "a")
274+
f.writelines(file_content_single_entry)
275+
f.close()
276+
277+
f = open("doc_utils/pretrainedmodels.rst", "a")
278+
f.writelines(file_content_intro)
279+
f.writelines(file_content_entries)
195280
f.close()

0 commit comments

Comments
 (0)