35
35
from sagemaker .jumpstart .filters import Constant , ModelFilter , Operator , evaluate_filter_expression
36
36
from sagemaker .jumpstart .types import JumpStartModelHeader , JumpStartModelSpecs
37
37
from sagemaker .jumpstart .utils import get_jumpstart_content_bucket , get_sagemaker_version
38
+ from sagemaker .session import Session
38
39
39
40
40
41
def _compare_model_version_tuples ( # pylint: disable=too-many-return-statements
@@ -137,6 +138,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
137
138
def list_jumpstart_tasks ( # pylint: disable=redefined-builtin
138
139
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
139
140
region : str = JUMPSTART_DEFAULT_REGION_NAME ,
141
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
140
142
) -> List [str ]:
141
143
"""List tasks for JumpStart, and optionally apply filters to result.
142
144
@@ -148,10 +150,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
148
150
(Default: Constant(BooleanValues.TRUE)).
149
151
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
150
152
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
153
+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
154
+ use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
151
155
"""
152
156
153
157
tasks : Set [str ] = set ()
154
- for model_id , _ in _generate_jumpstart_model_versions (filter = filter , region = region ):
158
+ for model_id , _ in _generate_jumpstart_model_versions (
159
+ filter = filter , region = region , sagemaker_session = sagemaker_session
160
+ ):
155
161
_ , task , _ = extract_framework_task_model (model_id )
156
162
tasks .add (task )
157
163
return sorted (list (tasks ))
@@ -160,6 +166,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
160
166
def list_jumpstart_frameworks ( # pylint: disable=redefined-builtin
161
167
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
162
168
region : str = JUMPSTART_DEFAULT_REGION_NAME ,
169
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
163
170
) -> List [str ]:
164
171
"""List frameworks for JumpStart, and optionally apply filters to result.
165
172
@@ -171,10 +178,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
171
178
(Default: Constant(BooleanValues.TRUE)).
172
179
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
173
180
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
181
+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
182
+ to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
174
183
"""
175
184
176
185
frameworks : Set [str ] = set ()
177
- for model_id , _ in _generate_jumpstart_model_versions (filter = filter , region = region ):
186
+ for model_id , _ in _generate_jumpstart_model_versions (
187
+ filter = filter , region = region , sagemaker_session = sagemaker_session
188
+ ):
178
189
framework , _ , _ = extract_framework_task_model (model_id )
179
190
frameworks .add (framework )
180
191
return sorted (list (frameworks ))
@@ -183,6 +194,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
183
194
def list_jumpstart_scripts ( # pylint: disable=redefined-builtin
184
195
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
185
196
region : str = JUMPSTART_DEFAULT_REGION_NAME ,
197
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
186
198
) -> List [str ]:
187
199
"""List scripts for JumpStart, and optionally apply filters to result.
188
200
@@ -194,19 +206,24 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
194
206
(Default: Constant(BooleanValues.TRUE)).
195
207
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
196
208
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
209
+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
210
+ use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
197
211
"""
198
212
if (isinstance (filter , Constant ) and filter .resolved_value == BooleanValues .TRUE ) or (
199
213
isinstance (filter , str ) and filter .lower () == BooleanValues .TRUE .lower ()
200
214
):
201
215
return sorted ([e .value for e in JumpStartScriptScope ])
202
216
203
217
scripts : Set [str ] = set ()
204
- for model_id , version in _generate_jumpstart_model_versions (filter = filter , region = region ):
218
+ for model_id , version in _generate_jumpstart_model_versions (
219
+ filter = filter , region = region , sagemaker_session = sagemaker_session
220
+ ):
205
221
scripts .add (JumpStartScriptScope .INFERENCE )
206
222
model_specs = accessors .JumpStartModelsAccessor .get_model_specs (
207
223
region = region ,
208
224
model_id = model_id ,
209
225
version = version ,
226
+ s3_client = sagemaker_session .s3_client ,
210
227
)
211
228
if model_specs .training_supported :
212
229
scripts .add (JumpStartScriptScope .TRAINING )
@@ -222,6 +239,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
222
239
list_incomplete_models : bool = False ,
223
240
list_old_models : bool = False ,
224
241
list_versions : bool = False ,
242
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
225
243
) -> List [Union [Tuple [str ], Tuple [str , str ]]]:
226
244
"""List models for JumpStart, and optionally apply filters to result.
227
245
@@ -241,11 +259,16 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
241
259
versions should be included in the returned result. (Default: False).
242
260
list_versions (bool): Optional. True if versions for models should be returned in addition
243
261
to the id of the model. (Default: False).
262
+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
263
+ to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
244
264
"""
245
265
246
266
model_id_version_dict : Dict [str , List [str ]] = dict ()
247
267
for model_id , version in _generate_jumpstart_model_versions (
248
- filter = filter , region = region , list_incomplete_models = list_incomplete_models
268
+ filter = filter ,
269
+ region = region ,
270
+ list_incomplete_models = list_incomplete_models ,
271
+ sagemaker_session = sagemaker_session ,
249
272
):
250
273
if model_id not in model_id_version_dict :
251
274
model_id_version_dict [model_id ] = list ()
@@ -271,6 +294,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
271
294
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
272
295
region : str = JUMPSTART_DEFAULT_REGION_NAME ,
273
296
list_incomplete_models : bool = False ,
297
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
274
298
) -> Generator :
275
299
"""Generate models for JumpStart, and optionally apply filters to result.
276
300
@@ -286,9 +310,13 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
286
310
requested by the filter, and the filter cannot be resolved to a include/not include,
287
311
whether the model should be included. By default, these models are omitted from
288
312
results. (Default: False).
313
+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
314
+ to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
289
315
"""
290
316
291
- models_manifest_list = accessors .JumpStartModelsAccessor ._get_manifest (region = region )
317
+ models_manifest_list = accessors .JumpStartModelsAccessor ._get_manifest (
318
+ region = region , s3_client = sagemaker_session .s3_client
319
+ )
292
320
293
321
if isinstance (filter , str ):
294
322
filter = Identity (filter )
@@ -366,7 +394,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
366
394
367
395
model_specs = JumpStartModelSpecs (
368
396
json .loads (
369
- DEFAULT_JUMPSTART_SAGEMAKER_SESSION .read_s3_file (
397
+ sagemaker_session .read_s3_file (
370
398
get_jumpstart_content_bucket (region ), model_manifest .spec_key
371
399
)
372
400
)
@@ -418,7 +446,10 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
418
446
419
447
420
448
def get_model_url (
421
- model_id : str , model_version : str , region : str = JUMPSTART_DEFAULT_REGION_NAME
449
+ model_id : str ,
450
+ model_version : str ,
451
+ region : str = JUMPSTART_DEFAULT_REGION_NAME ,
452
+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
422
453
) -> str :
423
454
"""Retrieve web url describing pretrained model.
424
455
@@ -427,9 +458,14 @@ def get_model_url(
427
458
model_version (str): The model version for which to retrieve the url.
428
459
region (str): Optional. The region from which to retrieve metadata.
429
460
(Default: JUMPSTART_DEFAULT_REGION_NAME)
461
+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
462
+ to retrieve the model url.
430
463
"""
431
464
432
465
model_specs = accessors .JumpStartModelsAccessor .get_model_specs (
433
- region = region , model_id = model_id , version = model_version
466
+ region = region ,
467
+ model_id = model_id ,
468
+ version = model_version ,
469
+ s3_client = sagemaker_session .s3_client ,
434
470
)
435
471
return model_specs .url
0 commit comments