Skip to content

Commit 50dd33c

Browse files
committed
chore: add sagemaker session to notebook utils
1 parent 1f80808 commit 50dd33c

File tree

2 files changed

+112
-32
lines changed

2 files changed

+112
-32
lines changed

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, evaluate_filter_expression
3636
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
3737
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version
38+
from sagemaker.session import Session
3839

3940

4041
def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
@@ -137,6 +138,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
137138
def list_jumpstart_tasks( # pylint: disable=redefined-builtin
138139
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
139140
region: str = JUMPSTART_DEFAULT_REGION_NAME,
141+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
140142
) -> List[str]:
141143
"""List tasks for JumpStart, and optionally apply filters to result.
142144
@@ -148,10 +150,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
148150
(Default: Constant(BooleanValues.TRUE)).
149151
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
150152
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
153+
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
154+
use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
151155
"""
152156

153157
tasks: Set[str] = set()
154-
for model_id, _ in _generate_jumpstart_model_versions(filter=filter, region=region):
158+
for model_id, _ in _generate_jumpstart_model_versions(
159+
filter=filter, region=region, sagemaker_session=sagemaker_session
160+
):
155161
_, task, _ = extract_framework_task_model(model_id)
156162
tasks.add(task)
157163
return sorted(list(tasks))
@@ -160,6 +166,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
160166
def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
161167
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
162168
region: str = JUMPSTART_DEFAULT_REGION_NAME,
169+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
163170
) -> List[str]:
164171
"""List frameworks for JumpStart, and optionally apply filters to result.
165172
@@ -171,10 +178,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
171178
(Default: Constant(BooleanValues.TRUE)).
172179
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
173180
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
181+
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
182+
to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
174183
"""
175184

176185
frameworks: Set[str] = set()
177-
for model_id, _ in _generate_jumpstart_model_versions(filter=filter, region=region):
186+
for model_id, _ in _generate_jumpstart_model_versions(
187+
filter=filter, region=region, sagemaker_session=sagemaker_session
188+
):
178189
framework, _, _ = extract_framework_task_model(model_id)
179190
frameworks.add(framework)
180191
return sorted(list(frameworks))
@@ -183,6 +194,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
183194
def list_jumpstart_scripts( # pylint: disable=redefined-builtin
184195
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
185196
region: str = JUMPSTART_DEFAULT_REGION_NAME,
197+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
186198
) -> List[str]:
187199
"""List scripts for JumpStart, and optionally apply filters to result.
188200
@@ -194,19 +206,24 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
194206
(Default: Constant(BooleanValues.TRUE)).
195207
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
196208
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
209+
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
210+
use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
197211
"""
198212
if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or (
199213
isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower()
200214
):
201215
return sorted([e.value for e in JumpStartScriptScope])
202216

203217
scripts: Set[str] = set()
204-
for model_id, version in _generate_jumpstart_model_versions(filter=filter, region=region):
218+
for model_id, version in _generate_jumpstart_model_versions(
219+
filter=filter, region=region, sagemaker_session=sagemaker_session
220+
):
205221
scripts.add(JumpStartScriptScope.INFERENCE)
206222
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
207223
region=region,
208224
model_id=model_id,
209225
version=version,
226+
s3_client=sagemaker_session.s3_client,
210227
)
211228
if model_specs.training_supported:
212229
scripts.add(JumpStartScriptScope.TRAINING)
@@ -222,6 +239,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
222239
list_incomplete_models: bool = False,
223240
list_old_models: bool = False,
224241
list_versions: bool = False,
242+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
225243
) -> List[Union[Tuple[str], Tuple[str, str]]]:
226244
"""List models for JumpStart, and optionally apply filters to result.
227245
@@ -241,11 +259,16 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
241259
versions should be included in the returned result. (Default: False).
242260
list_versions (bool): Optional. True if versions for models should be returned in addition
243261
to the id of the model. (Default: False).
262+
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
263+
to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
244264
"""
245265

246266
model_id_version_dict: Dict[str, List[str]] = dict()
247267
for model_id, version in _generate_jumpstart_model_versions(
248-
filter=filter, region=region, list_incomplete_models=list_incomplete_models
268+
filter=filter,
269+
region=region,
270+
list_incomplete_models=list_incomplete_models,
271+
sagemaker_session=sagemaker_session,
249272
):
250273
if model_id not in model_id_version_dict:
251274
model_id_version_dict[model_id] = list()
@@ -271,6 +294,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
271294
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
272295
region: str = JUMPSTART_DEFAULT_REGION_NAME,
273296
list_incomplete_models: bool = False,
297+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
274298
) -> Generator:
275299
"""Generate models for JumpStart, and optionally apply filters to result.
276300
@@ -286,9 +310,13 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
286310
requested by the filter, and the filter cannot be resolved to a include/not include,
287311
whether the model should be included. By default, these models are omitted from
288312
results. (Default: False).
313+
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
314+
to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
289315
"""
290316

291-
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region)
317+
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
318+
region=region, s3_client=sagemaker_session.s3_client
319+
)
292320

293321
if isinstance(filter, str):
294322
filter = Identity(filter)
@@ -366,7 +394,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
366394

367395
model_specs = JumpStartModelSpecs(
368396
json.loads(
369-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file(
397+
sagemaker_session.read_s3_file(
370398
get_jumpstart_content_bucket(region), model_manifest.spec_key
371399
)
372400
)
@@ -418,7 +446,10 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
418446

419447

420448
def get_model_url(
421-
model_id: str, model_version: str, region: str = JUMPSTART_DEFAULT_REGION_NAME
449+
model_id: str,
450+
model_version: str,
451+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
452+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
422453
) -> str:
423454
"""Retrieve web url describing pretrained model.
424455
@@ -427,9 +458,14 @@ def get_model_url(
427458
model_version (str): The model version for which to retrieve the url.
428459
region (str): Optional. The region from which to retrieve metadata.
429460
(Default: JUMPSTART_DEFAULT_REGION_NAME)
461+
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
462+
to retrieve the model url.
430463
"""
431464

432465
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
433-
region=region, model_id=model_id, version=model_version
466+
region=region,
467+
model_id=model_id,
468+
version=model_version,
469+
s3_client=sagemaker_session.s3_client,
434470
)
435471
return model_specs.url

0 commit comments

Comments
 (0)