Skip to content

Commit 430a517

Browse files
committed
fix: unit tests, use verify_model_region_and_return_specs in notebook utils
1 parent e2daefc commit 430a517

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
)
3535
from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, evaluate_filter_expression
3636
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
37-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version
37+
from sagemaker.jumpstart.utils import (
38+
get_jumpstart_content_bucket,
39+
get_sagemaker_version,
40+
verify_model_region_and_return_specs,
41+
)
3842
from sagemaker.session import Session
3943

4044
MAX_SEARCH_WORKERS = int(100 * 1e6 / 25 * 1e3) # max 100MB total memory, 25kB per thread)
@@ -221,11 +225,12 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
221225
filter=filter, region=region, sagemaker_session=sagemaker_session
222226
):
223227
scripts.add(JumpStartScriptScope.INFERENCE)
224-
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
228+
model_specs = verify_model_region_and_return_specs(
225229
region=region,
226230
model_id=model_id,
227231
version=version,
228-
s3_client=sagemaker_session.s3_client,
232+
sagemaker_session=sagemaker_session,
233+
scope=JumpStartScriptScope.INFERENCE,
229234
)
230235
if model_specs.training_supported:
231236
scripts.add(JumpStartScriptScope.TRAINING)
@@ -462,11 +467,12 @@ def get_model_url(
462467
to retrieve the model url.
463468
"""
464469

465-
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
470+
model_specs = verify_model_region_and_return_specs(
466471
region=region,
467472
model_id=model_id,
468473
version=model_version,
469-
s3_client=sagemaker_session.s3_client,
474+
sagemaker_session=sagemaker_session,
475+
scope=JumpStartScriptScope.INFERENCE,
470476
)
471477
return model_specs.url
472478

@@ -488,10 +494,11 @@ def _get_model_eula_key(
488494
to retrieve the EULA S3 key.
489495
"""
490496

491-
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
497+
model_specs = verify_model_region_and_return_specs(
492498
region=region,
493499
model_id=model_id,
494500
version=model_version,
495-
s3_client=sagemaker_session.s3_client,
501+
sagemaker_session=sagemaker_session,
502+
scope=JumpStartScriptScope.INFERENCE,
496503
)
497504
return model_specs.hosting_eula_key

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ def test__get_model_eula_key(
713713
assert "fmhMetadata/eula/llamaEula.txt" == _get_model_eula_key(model_id, version)
714714

715715
model_id, version = "variant-model", "1.0.0"
716-
assert None == _get_model_eula_key(model_id, version)
716+
assert None is _get_model_eula_key(model_id, version)
717717

718718
region = "fake-region"
719719

0 commit comments

Comments
 (0)