Skip to content

Commit 21177c4

Browse files
committed
feat: retrieve jumpstart estimator and predictor without specifying model id (infer from tags)
1 parent 57fd632 commit 21177c4

File tree

10 files changed

+481
-27
lines changed

10 files changed

+481
-27
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
3535
from sagemaker.jumpstart.factory.model import get_default_predictor
3636
from sagemaker.jumpstart.utils import (
37+
get_jumpstart_model_id_version_from_training_job,
3738
is_valid_model_id,
3839
resolve_model_sagemaker_config_field,
3940
)
@@ -664,8 +665,8 @@ def fit(
664665
def attach(
665666
cls,
666667
training_job_name: str,
667-
model_id: str,
668-
model_version: str = "*",
668+
model_id: Optional[str] = None,
669+
model_version: Optional[str] = None,
669670
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
670671
model_channel_name: str = "model",
671672
) -> "JumpStartEstimator":
@@ -709,6 +710,20 @@ def attach(
709710
training job.
710711
"""
711712

713+
if model_id is None:
714+
model_id, inferred_model_version = get_jumpstart_model_id_version_from_training_job(
715+
training_job_name=training_job_name,
716+
sagemaker_session=sagemaker_session,
717+
)
718+
model_version = model_version or inferred_model_version
719+
if not model_id:
720+
raise ValueError(
721+
f"Cannot infer JumpStart model ID from training job '{training_job_name}'. "
722+
"Please specify JumpStart `model_id` when retrieving Estimator for this training job."
723+
)
724+
725+
model_version = model_version or "*"
726+
712727
return cls._attach(
713728
training_job_name=training_job_name,
714729
sagemaker_session=sagemaker_session,

src/sagemaker/jumpstart/utils.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17-
from typing import Any, Dict, List, Optional, Union
17+
from typing import Any, Dict, List, Optional, Tuple, Union
1818
from urllib.parse import urlparse
1919
import boto3
2020
from packaging.version import Version
@@ -41,7 +41,7 @@
4141
)
4242
from sagemaker.session import Session
4343
from sagemaker.config import load_sagemaker_config
44-
from sagemaker.utils import resolve_value_from_config
44+
from sagemaker.utils import aws_partition, resolve_value_from_config
4545
from sagemaker.workflow import is_pipeline_variable
4646

4747

@@ -757,3 +757,66 @@ def is_valid_model_id(
757757
if script == enums.JumpStartScriptScope.TRAINING:
758758
return model_id in model_id_set
759759
raise ValueError(f"Unsupported script: {script}")
760+
761+
762+
def _get_jumpstart_model_id_version_from_resource_arn(
    resource_arn: str,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str]]:
    """Reads the JumpStart model ID and version from the tags of a resource.

    The tags are fetched once with ``sagemaker_session.list_tags`` and then
    searched for the ``JumpStartTag.MODEL_ID`` and ``JumpStartTag.MODEL_VERSION``
    keys.

    Args:
        resource_arn (str): ARN of the resource whose tags are listed.
        sagemaker_session (Session): Session used to list the tags.

    Returns:
        Tuple[Optional[str], Optional[str]]: ``(model_id, model_version)``. An
            element is None when its tag key is not found, or when
            ``get_tag_value`` raises ``KeyError`` for it.
    """
    resource_tags = sagemaker_session.list_tags(resource_arn)

    def _tag_value_or_none(tag_key) -> Optional[str]:
        # Absent key: nothing to read.
        if not tag_key_in_array(tag_key, resource_tags):
            return None
        # The key is present, but the lookup may still raise KeyError;
        # that case is reported as "unknown" rather than propagated.
        try:
            return get_tag_value(tag_key, resource_tags)
        except KeyError:
            return None

    return (
        _tag_value_or_none(enums.JumpStartTag.MODEL_ID),
        _tag_value_or_none(enums.JumpStartTag.MODEL_VERSION),
    )
785+
786+
787+
def get_jumpstart_model_id_version_from_training_job(
    training_job_name: str,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str]]:
    """Inspects the tags of a training job to find the JumpStart model ID and version.

    The training job ARN is assembled from the session's region, the AWS
    partition for that region, and the session's account ID. The tags attached
    to that ARN are then inspected.

    Args:
        training_job_name (str): Name of the training job to inspect.
        sagemaker_session (Session): Session supplying the region and account
            ID, and used to list the tags.

    Returns:
        Tuple[Optional[str], Optional[str]]: ``(model_id, model_version)``. An
            element is None when it cannot be inferred from the tags.
    """
    boto_region: str = sagemaker_session.boto_region_name
    arn_partition: str = aws_partition(boto_region)
    owner_account: str = sagemaker_session.account_id()

    return _get_jumpstart_model_id_version_from_resource_arn(
        f"arn:{arn_partition}:sagemaker:{boto_region}:{owner_account}"
        f":training-job/{training_job_name}",
        sagemaker_session,
    )
805+
806+
807+
def get_jumpstart_model_id_version_from_endpoint(
    endpoint_name: str,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str]]:
    """Inspects the tags of an endpoint to find the JumpStart model ID and version.

    The endpoint ARN is assembled from the session's region, the AWS partition
    for that region, and the session's account ID. The tags attached to that
    ARN are then inspected.

    Args:
        endpoint_name (str): Name of the endpoint to inspect.
        sagemaker_session (Session): Session supplying the region and account
            ID, and used to list the tags.

    Returns:
        Tuple[Optional[str], Optional[str]]: ``(model_id, model_version)``. An
            element is None when it cannot be inferred from the tags.
    """
    boto_region: str = sagemaker_session.boto_region_name
    arn_partition: str = aws_partition(boto_region)
    owner_account: str = sagemaker_session.account_id()

    return _get_jumpstart_model_id_version_from_resource_arn(
        f"arn:{arn_partition}:sagemaker:{boto_region}:{owner_account}:endpoint/{endpoint_name}",
        sagemaker_session,
    )

src/sagemaker/predictor.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1818

1919
from sagemaker.jumpstart.factory.model import get_default_predictor
20-
from sagemaker.jumpstart.utils import is_jumpstart_model_input
20+
from sagemaker.jumpstart.utils import (
21+
get_jumpstart_model_id_version_from_endpoint,
22+
)
2123

2224
from sagemaker.session import Session
2325

@@ -66,11 +68,19 @@ def retrieve_default(
6668
ValueError: If the combination of arguments specified is not supported.
6769
"""
6870

69-
if not is_jumpstart_model_input(model_id, model_version):
70-
raise ValueError(
71-
"Must specify JumpStart `model_id` and `model_version` "
72-
"when retrieving default predictor."
71+
if model_id is None:
72+
model_id, inferred_model_version = get_jumpstart_model_id_version_from_endpoint(
73+
endpoint_name=endpoint_name,
74+
sagemaker_session=sagemaker_session,
7375
)
76+
model_version = model_version or inferred_model_version
77+
if not model_id:
78+
raise ValueError(
79+
f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
80+
"Please specify JumpStart `model_id` when retrieving default predictor for this endpoint."
81+
)
82+
83+
model_version = model_version or "*"
7484

7585
predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)
7686

src/sagemaker/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def _botocore_resolver():
746746
return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
747747

748748

749-
def _aws_partition(region):
749+
def aws_partition(region):
750750
"""Given a region name (ex: "cn-north-1"), return the corresponding aws partition ("aws-cn").
751751
752752
Args:

tests/integ/kms_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _create_kms_key(
6464
):
6565
if role_arn:
6666
principal = PRINCIPAL_TEMPLATE.format(
67-
partition=utils._aws_partition(region),
67+
partition=utils.aws_partition(region),
6868
account_id=account_id,
6969
role_arn=role_arn,
7070
sagemaker_role=sagemaker_role,
@@ -95,7 +95,7 @@ def _add_role_to_policy(
9595

9696
if role_arn not in principal or sagemaker_role not in principal:
9797
principal = PRINCIPAL_TEMPLATE.format(
98-
partition=utils._aws_partition(region),
98+
partition=utils.aws_partition(region),
9999
account_id=account_id,
100100
role_arn=role_arn,
101101
sagemaker_role=sagemaker_role,
@@ -198,7 +198,7 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
198198
s3_client.put_bucket_policy(
199199
Bucket=bucket_name,
200200
Policy=KMS_BUCKET_POLICY.format(
201-
partition=utils._aws_partition(region), bucket_name=bucket_name
201+
partition=utils.aws_partition(region), bucket_name=bucket_name
202202
),
203203
)
204204

tests/integ/test_marketplace.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from sagemaker import AlgorithmEstimator, ModelPackage, Model
2828
from sagemaker.serializers import CSVSerializer
2929
from sagemaker.tuner import IntegerParameter, HyperparameterTuner
30-
from sagemaker.utils import sagemaker_timestamp, _aws_partition, unique_name_from_base
30+
from sagemaker.utils import sagemaker_timestamp, aws_partition, unique_name_from_base
3131
from tests.integ import DATA_DIR
3232
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
3333
from tests.integ.marketplace_utils import REGION_ACCOUNT_MAP
@@ -76,7 +76,7 @@ def test_marketplace_estimator(sagemaker_session, cpu_instance_type):
7676
region = sagemaker_session.boto_region_name
7777
account = REGION_ACCOUNT_MAP[region]
7878
algorithm_arn = ALGORITHM_ARN.format(
79-
partition=_aws_partition(region), region=region, account=account
79+
partition=aws_partition(region), region=region, account=account
8080
)
8181

8282
algo = AlgorithmEstimator(
@@ -118,7 +118,7 @@ def test_marketplace_attach(sagemaker_session, cpu_instance_type):
118118
region = sagemaker_session.boto_region_name
119119
account = REGION_ACCOUNT_MAP[region]
120120
algorithm_arn = ALGORITHM_ARN.format(
121-
partition=_aws_partition(region), region=region, account=account
121+
partition=aws_partition(region), region=region, account=account
122122
)
123123

124124
mktplace = AlgorithmEstimator(
@@ -170,7 +170,7 @@ def test_marketplace_model(sagemaker_session, cpu_instance_type):
170170
region = sagemaker_session.boto_region_name
171171
account = REGION_ACCOUNT_MAP[region]
172172
model_package_arn = MODEL_PACKAGE_ARN.format(
173-
partition=_aws_partition(region), region=region, account=account
173+
partition=aws_partition(region), region=region, account=account
174174
)
175175

176176
def predict_wrapper(endpoint, session):
@@ -337,7 +337,7 @@ def test_marketplace_tuning_job(sagemaker_session, cpu_instance_type):
337337
region = sagemaker_session.boto_region_name
338338
account = REGION_ACCOUNT_MAP[region]
339339
algorithm_arn = ALGORITHM_ARN.format(
340-
partition=_aws_partition(region), region=region, account=account
340+
partition=aws_partition(region), region=region, account=account
341341
)
342342

343343
mktplace = AlgorithmEstimator(
@@ -380,7 +380,7 @@ def test_marketplace_transform_job(sagemaker_session, cpu_instance_type):
380380
region = sagemaker_session.boto_region_name
381381
account = REGION_ACCOUNT_MAP[region]
382382
algorithm_arn = ALGORITHM_ARN.format(
383-
partition=_aws_partition(region), region=region, account=account
383+
partition=aws_partition(region), region=region, account=account
384384
)
385385

386386
algo = AlgorithmEstimator(
@@ -428,7 +428,7 @@ def test_marketplace_transform_job_from_model_package(sagemaker_session, cpu_ins
428428
region = sagemaker_session.boto_region_name
429429
account = REGION_ACCOUNT_MAP[region]
430430
model_package_arn = MODEL_PACKAGE_ARN.format(
431-
partition=_aws_partition(region), region=region, account=account
431+
partition=aws_partition(region), region=region, account=account
432432
)
433433

434434
model = ModelPackage(

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,83 @@ def test_jumpstart_estimator_tags(
971971
[{"Key": "blah", "Value": "blahagain"}] + js_tags,
972972
)
973973

974+
@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
@mock.patch("sagemaker.jumpstart.estimator.get_jumpstart_model_id_version_from_training_job")
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_jumpstart_estimator_attach_no_model_id_happy_case(
    self,
    mock_get_model_specs: mock.Mock,
    mock_is_valid_model_id: mock.Mock,
    mock_get_jumpstart_model_id_version_from_training_job: mock.Mock,
    mock_attach: mock.Mock,
):
    """attach() without a model ID forwards the ID/version inferred from the job."""
    job_name = "some-training-job-name"
    inferred_id = "js-trainable-model-prepacked"
    inferred_version = "1.0.0"

    mock_is_valid_model_id.return_value = True
    mock_get_model_specs.side_effect = get_special_model_spec
    # The tag lookup is stubbed to report a known model ID and version.
    mock_get_jumpstart_model_id_version_from_training_job.return_value = (
        inferred_id,
        inferred_version,
    )

    session_mock = mock.MagicMock(sagemaker_config={}, boto_region_name="us-west-2")

    JumpStartEstimator.attach(training_job_name=job_name, sagemaker_session=session_mock)

    mock_get_jumpstart_model_id_version_from_training_job.assert_called_once_with(
        training_job_name=job_name, sagemaker_session=session_mock
    )

    # The inferred values must reach _attach via additional_kwargs.
    mock_attach.assert_called_once_with(
        training_job_name=job_name,
        sagemaker_session=session_mock,
        model_channel_name="model",
        additional_kwargs={
            "model_id": inferred_id,
            "model_version": inferred_version,
        },
    )
1015+
1016+
@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
@mock.patch("sagemaker.jumpstart.estimator.get_jumpstart_model_id_version_from_training_job")
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_jumpstart_estimator_attach_no_model_id_sad_case(
    self,
    mock_get_model_specs: mock.Mock,
    mock_is_valid_model_id: mock.Mock,
    mock_get_jumpstart_model_id_version_from_training_job: mock.Mock,
    mock_attach: mock.Mock,
):
    """attach() without a model ID raises ValueError when nothing can be inferred."""
    job_name = "some-training-job-name"

    mock_is_valid_model_id.return_value = True
    mock_get_model_specs.side_effect = get_special_model_spec
    # The tag lookup is stubbed to report that neither ID nor version was found.
    mock_get_jumpstart_model_id_version_from_training_job.return_value = (None, None)

    session_mock = mock.MagicMock(sagemaker_config={}, boto_region_name="us-west-2")

    with pytest.raises(ValueError):
        JumpStartEstimator.attach(training_job_name=job_name, sagemaker_session=session_mock)

    mock_get_jumpstart_model_id_version_from_training_job.assert_called_once_with(
        training_job_name=job_name, sagemaker_session=session_mock
    )

    # The failure must happen before the estimator is attached.
    mock_attach.assert_not_called()
1050+
9741051
def test_jumpstart_estimator_kwargs_match_parent_class(self):
9751052

9761053
"""If you add arguments to <Estimator constructor>, this test will fail.

0 commit comments

Comments
 (0)