@@ -57,6 +57,19 @@ class ProblemTypes(str, Enum):
57
57
TABULAR_REGRESSION = "Regression"
58
58
TABULAR_CLASSIFICATION = "Classification"
59
59
60
+ class Frameworks (str , Enum ):
61
+ """Possible frameworks for JumpStart models"""
62
+
63
+ TENSORFLOW = "Tensorflow Hub"
64
+ PYTORCH = "Pytorch Hub"
65
+ HUGGINGFACE = "HuggingFace"
66
+ CATBOOST = "Catboost"
67
+ GLUONCV = "GluonCV"
68
+ LIGHTGBM = "LightGBM"
69
+ XGBOOST = "XGBoost"
70
+ SCIKIT_LEARN = "ScikitLearn"
71
+ SOURCE = "Source"
72
+
60
73
61
74
JUMPSTART_REGION = "eu-west-2"
62
75
SDK_MANIFEST_FILE = "models_manifest.json"
@@ -82,6 +95,40 @@ class ProblemTypes(str, Enum):
82
95
Tasks .TABULAR_CLASSIFICATION : ProblemTypes .TABULAR_CLASSIFICATION ,
83
96
}
84
97
98
+ TO_FRAMEWORK = {
99
+ "Tensorflow Hub" : Frameworks .TENSORFLOW ,
100
+ "Pytorch Hub" : Frameworks .PYTORCH ,
101
+ "HuggingFace" : Frameworks .HUGGINGFACE ,
102
+ "Catboost" : Frameworks .CATBOOST ,
103
+ "GluonCV" : Frameworks .GLUONCV ,
104
+ "LightGBM" : Frameworks .LIGHTGBM ,
105
+ "XGBoost" : Frameworks .XGBOOST ,
106
+ "ScikitLearn" : Frameworks .SCIKIT_LEARN ,
107
+ "Source" : Frameworks .SOURCE
108
+ }
109
+
110
+
111
+ MODALITY_MAP = {
112
+ (Tasks .IC , Frameworks .PYTORCH ): "algorithms/vision/image_classification_pytorch.rst" ,
113
+ (Tasks .IC , Frameworks .TENSORFLOW ): "algorithms/vision/image_classification_tensorflow.rst" ,
114
+ (Tasks .IC_EMBEDDING , Frameworks .TENSORFLOW ): "algorithms/vision/image_embedding_tensorflow.rst" ,
115
+ (Tasks .IS , Frameworks .GLUONCV ): "algorithms/vision/instance_segmentation_mxnet.rst" ,
116
+ (Tasks .OD , Frameworks .GLUONCV ): "algorithms/vision/object_detection_mxnet.rst" ,
117
+ (Tasks .OD , Frameworks .PYTORCH ): "algorithms/vision/object_detection_pytorch.rst" ,
118
+ (Tasks .OD , Frameworks .TENSORFLOW ): "algorithms/vision/object_detection_tensorflow.rst" ,
119
+ (Tasks .SEMSEG , Frameworks .GLUONCV ): "algorithms/vision/semantic_segmentation_mxnet.rst" ,
120
+ (Tasks .TRANSLATION , Frameworks .HUGGINGFACE ): "algorithms/text/machine_translation_hugging_face.rst" ,
121
+ (Tasks .NER , Frameworks .GLUONCV ): "algorithms/text/named_entity_recognition_hugging_face.rst" ,
122
+ (Tasks .EQA , Frameworks .PYTORCH ): "algorithms/text/question_answering_pytorch.rst" ,
123
+ (Tasks .SPC , Frameworks .HUGGINGFACE ): "algorithms/text/sentence_pair_classification_hugging_face.rst" ,
124
+ (Tasks .SPC , Frameworks .TENSORFLOW ): "algorithms/text/sentence_pair_classification_tensorflow.rst" ,
125
+ (Tasks .TC , Frameworks .TENSORFLOW ): "algorithms/text/text_classification_tensorflow.rst" ,
126
+ (Tasks .TC_EMBEDDING , Frameworks .GLUONCV ): "algorithms/text/text_embedding_tensorflow_mxnet.rst" ,
127
+ (Tasks .TC_EMBEDDING , Frameworks .TENSORFLOW ): "algorithms/text/text_embedding_tensorflow_mxnet.rst" ,
128
+ (Tasks .TEXT_GENERATION , Frameworks .HUGGINGFACE ): "algorithms/text/text_generation_hugging_face.rst" ,
129
+ (Tasks .SUMMARIZATION , Frameworks .HUGGINGFACE ): "algorithms/text/text_summarization_hugging_face.rst" ,
130
+ }
131
+
85
132
86
133
def get_jumpstart_sdk_manifest ():
87
134
url = "{}/{}" .format (JUMPSTART_BUCKET_BASE_URL , SDK_MANIFEST_FILE )
@@ -102,6 +149,10 @@ def get_model_task(id):
102
149
return TASK_MAP [task_short ] if task_short in TASK_MAP else "Source"
103
150
104
151
152
+ def get_string_model_task (id ):
153
+ return id .split ("-" )[1 ]
154
+
155
+
105
156
def get_model_source (url ):
106
157
if "tfhub" in url :
107
158
return "Tensorflow Hub"
@@ -113,8 +164,6 @@ def get_model_source(url):
113
164
return "Catboost"
114
165
if "gluon" in url :
115
166
return "GluonCV"
116
- if "catboost" in url :
117
- return "Catboost"
118
167
if "lightgbm" in url :
119
168
return "LightGBM"
120
169
if "xgboost" in url :
@@ -138,58 +187,94 @@ def create_jumpstart_model_table():
138
187
) < Version (model ["version" ]):
139
188
sdk_manifest_top_versions_for_models [model ["model_id" ]] = model
140
189
141
- file_content = []
190
+ file_content_intro = []
142
191
143
- file_content .append (".. _all-pretrained-models:\n \n " )
144
- file_content .append (".. |external-link| raw:: html\n \n " )
145
- file_content .append (' <i class="fa fa-external-link"></i>\n \n ' )
192
+ file_content_intro .append (".. _all-pretrained-models:\n \n " )
193
+ file_content_intro .append (".. |external-link| raw:: html\n \n " )
194
+ file_content_intro .append (' <i class="fa fa-external-link"></i>\n \n ' )
146
195
147
- file_content .append ("================================================\n " )
148
- file_content .append ("Built-in Algorithms with pre-trained Model Table\n " )
149
- file_content .append ("================================================\n " )
150
- file_content .append (
196
+ file_content_intro .append ("================================================\n " )
197
+ file_content_intro .append ("Built-in Algorithms with pre-trained Model Table\n " )
198
+ file_content_intro .append ("================================================\n " )
199
+ file_content_intro .append (
151
200
"""
152
201
The SageMaker Python SDK uses model IDs and model versions to access the necessary
153
202
utilities for pre-trained models. This table serves to provide the core material plus
154
203
some extra information that can be useful in selecting the correct model ID and
155
204
corresponding parameters.\n """
156
205
)
157
- file_content .append (
206
+ file_content_intro .append (
158
207
"""
159
208
If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute.
160
209
We highly suggest pinning an exact model version however.\n """
161
210
)
162
- file_content .append (
211
+ file_content_intro .append (
163
212
"""
164
213
These models are also available through the
165
214
`JumpStart UI in SageMaker Studio <https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html>`__\n """
166
215
)
167
- file_content .append ("\n " )
168
- file_content .append (".. list-table:: Available Models\n " )
169
- file_content .append (" :widths: 50 20 20 20 30 20\n " )
170
- file_content .append (" :header-rows: 1\n " )
171
- file_content .append (" :class: datatable\n " )
172
- file_content .append ("\n " )
173
- file_content .append (" * - Model ID\n " )
174
- file_content .append (" - Fine Tunable?\n " )
175
- file_content .append (" - Latest Version\n " )
176
- file_content .append (" - Min SDK Version\n " )
177
- file_content .append (" - Problem Type\n " )
178
- file_content .append (" - Source\n " )
216
+ file_content_intro .append ("\n " )
217
+ file_content_intro .append (".. list-table:: Available Models\n " )
218
+ file_content_intro .append (" :widths: 50 20 20 20 30 20\n " )
219
+ file_content_intro .append (" :header-rows: 1\n " )
220
+ file_content_intro .append (" :class: datatable\n " )
221
+ file_content_intro .append ("\n " )
222
+ file_content_intro .append (" * - Model ID\n " )
223
+ file_content_intro .append (" - Fine Tunable?\n " )
224
+ file_content_intro .append (" - Latest Version\n " )
225
+ file_content_intro .append (" - Min SDK Version\n " )
226
+ file_content_intro .append (" - Problem Type\n " )
227
+ file_content_intro .append (" - Source\n " )
228
+
229
+ dynamic_table_files = []
230
+ file_content_entries = []
179
231
180
232
for model in sdk_manifest_top_versions_for_models .values ():
181
233
model_spec = get_jumpstart_sdk_spec (model ["spec_key" ])
182
234
model_task = get_model_task (model_spec ["model_id" ])
235
+ string_model_task = get_string_model_task (model_spec ["model_id" ])
183
236
model_source = get_model_source (model_spec ["url" ])
184
- file_content .append (" * - {}\n " .format (model_spec ["model_id" ]))
185
- file_content .append (" - {}\n " .format (model_spec ["training_supported" ]))
186
- file_content .append (" - {}\n " .format (model ["version" ]))
187
- file_content .append (" - {}\n " .format (model ["min_version" ]))
188
- file_content .append (" - {}\n " .format (model_task ))
189
- file_content .append (
237
+ file_content_entries .append (" * - {}\n " .format (model_spec ["model_id" ]))
238
+ file_content_entries .append (" - {}\n " .format (model_spec ["training_supported" ]))
239
+ file_content_entries .append (" - {}\n " .format (model ["version" ]))
240
+ file_content_entries .append (" - {}\n " .format (model ["min_version" ]))
241
+ file_content_entries .append (" - {}\n " .format (model_task ))
242
+ file_content_entries .append (
190
243
" - `{} <{}>`__ |external-link|\n " .format (model_source , model_spec ["url" ])
191
244
)
192
245
193
- f = open ("doc_utils/pretrainedmodels.rst" , "w" )
194
- f .writelines (file_content )
246
+ if (string_model_task , TO_FRAMEWORK [model_source ]) in MODALITY_MAP :
247
+ file_content_single_entry = []
248
+
249
+ if MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])] not in dynamic_table_files :
250
+ file_content_single_entry .append ("\n " )
251
+ file_content_single_entry .append (".. list-table:: Available Models\n " )
252
+ file_content_single_entry .append (" :widths: 50 20 20 20 30 20\n " )
253
+ file_content_single_entry .append (" :header-rows: 1\n " )
254
+ file_content_single_entry .append (" :class: datatable\n " )
255
+ file_content_single_entry .append ("\n " )
256
+ file_content_single_entry .append (" * - Model ID\n " )
257
+ file_content_single_entry .append (" - Fine Tunable?\n " )
258
+ file_content_single_entry .append (" - Latest Version\n " )
259
+ file_content_single_entry .append (" - Min SDK Version\n " )
260
+ file_content_single_entry .append (" - Problem Type\n " )
261
+ file_content_single_entry .append (" - Source\n " )
262
+
263
+ dynamic_table_files .append (MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])])
264
+
265
+ file_content_single_entry .append (" * - {}\n " .format (model_spec ["model_id" ]))
266
+ file_content_single_entry .append (" - {}\n " .format (model_spec ["training_supported" ]))
267
+ file_content_single_entry .append (" - {}\n " .format (model ["version" ]))
268
+ file_content_single_entry .append (" - {}\n " .format (model ["min_version" ]))
269
+ file_content_single_entry .append (" - {}\n " .format (model_task ))
270
+ file_content_single_entry .append (
271
+ " - `{} <{}>`__ \n " .format (model_source , model_spec ["url" ])
272
+ )
273
+ f = open (MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])], "a" )
274
+ f .writelines (file_content_single_entry )
275
+ f .close ()
276
+
277
+ f = open ("doc_utils/pretrainedmodels.rst" , "a" )
278
+ f .writelines (file_content_intro )
279
+ f .writelines (file_content_entries )
195
280
f .close ()
0 commit comments