Skip to content

Commit 3988710

Browse files
committed
chore: add support for ic-based endpoints
1 parent 69b2929 commit 3988710

File tree

11 files changed

+652
-200
lines changed

11 files changed

+652
-200
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,9 @@
235235
"Unable to create default JumpStart SageMaker Session due to the following error: %s.",
236236
str(e),
237237
)
238+
239+
EXTRA_MODEL_ID_TAGS = ["sm-jumpstart-id", "sagemaker-studio:jumpstart-model-id"]
240+
EXTRA_MODEL_VERSION_TAGS = [
241+
"sm-jumpstart-model-version",
242+
"sagemaker-studio:jumpstart-model-version",
243+
]

src/sagemaker/jumpstart/estimator.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
3535
from sagemaker.jumpstart.factory.model import get_default_predictor
3636
from sagemaker.jumpstart.utils import (
37-
get_jumpstart_model_id_version_from_training_job,
37+
get_jumpstart_model_id_version_from_resource_arn,
3838
is_valid_model_id,
3939
resolve_model_sagemaker_config_field,
4040
)
41-
from sagemaker.utils import stringify_object
41+
from sagemaker.utils import aws_partition, stringify_object
4242
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
4343
from sagemaker.predictor import PredictorBase
4444

@@ -711,10 +711,19 @@ def attach(
711711
"""
712712

713713
if model_id is None:
714-
model_id, inferred_model_version = get_jumpstart_model_id_version_from_training_job(
715-
training_job_name=training_job_name,
716-
sagemaker_session=sagemaker_session,
714+
715+
region: str = sagemaker_session.boto_region_name
716+
partition: str = aws_partition(region)
717+
account_id: str = sagemaker_session.account_id()
718+
719+
training_job_arn = (
720+
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
717721
)
722+
723+
model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
724+
training_job_arn, sagemaker_session
725+
)
726+
718727
model_version = model_version or inferred_model_version
719728
if not model_id:
720729
raise ValueError(
src/sagemaker/jumpstart/session_utils.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores SageMaker Session utilities for JumpStart models."""
14+
15+
from __future__ import absolute_import
16+
17+
from typing import Optional, Tuple
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
20+
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
21+
from sagemaker.session import Session
22+
from sagemaker.utils import aws_partition
23+
24+
25+
def get_model_id_version_from_endpoint(
    endpoint_name: str,
    inference_component_name: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str, Optional[str]]:
    """Resolve the JumpStart model ID and version behind an endpoint.

    Dispatches to one of three private helpers, depending on whether an
    inference component (IC) name was supplied and on whether the session
    reports the endpoint as IC-based.

    Args:
        endpoint_name (str): Name of the endpoint to inspect.
        inference_component_name (Optional[str]): Name of a specific inference
            component. When truthy, it is used directly and the session is not
            asked whether the endpoint is IC-based. (Default: None).
        sagemaker_session (Session): Session used for all lookups.
            (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        Tuple[str, str, Optional[str]]: The model ID, the model version, and the
        inference component name. The last element is the supplied name, the
        name inferred for an IC-based endpoint, or the caller's original falsy
        value for a model-based endpoint.

    Raises:
        ValueError: If model ID and version cannot be inferred from the endpoint.
    """
    # An explicit IC name wins: resolve straight from that component.
    if inference_component_name:
        model_id, model_version = _get_model_id_version_from_ic_endpoint_with_ic_name(
            inference_component_name, sagemaker_session
        )
        return model_id, model_version, inference_component_name

    # No IC name given: for an IC-based endpoint the helper also infers the name.
    if sagemaker_session.is_ic_based_endpoint(endpoint_name):
        (
            model_id,
            model_version,
            inferred_component_name,
        ) = _get_model_id_version_from_ic_endpoint_without_ic_name(endpoint_name, sagemaker_session)
        return model_id, model_version, inferred_component_name

    # Model-based endpoint: the IC name is passed through (falsy here) and echoed back.
    model_id, model_version = _get_model_id_version_from_non_ic_endpoint(
        endpoint_name, inference_component_name, sagemaker_session
    )
    return model_id, model_version, inference_component_name
59+
60+
61+
def _get_model_id_version_from_ic_endpoint_without_ic_name(
    endpoint_name: str, sagemaker_session: Session
) -> Tuple[str, str, str]:
    """Derive model ID, version, and the inference component name for an endpoint.

    The endpoint is assumed to be inference component-based. The session is asked
    for the components attached to the endpoint, and exactly one must be listed
    so that the component can be chosen unambiguously.

    Args:
        endpoint_name (str): Name of the endpoint whose components are listed.
        sagemaker_session (Session): Session used to list the components and to
            resolve the model ID and version.

    Returns:
        Tuple[str, str, str]: The model ID, the model version, and the name of the
        single inference component found on the endpoint.

    Raises:
        ValueError: If there is not a single inference component associated with
            the endpoint.
    """
    listing = sagemaker_session.list_inference_components(endpoint_name_equals=endpoint_name)
    components = listing["InferenceComponents"]
    component_count = len(components)

    if component_count == 0:
        raise ValueError(
            f"No inference components found for the following endpoint: {endpoint_name}. "
            "Use ``SageMaker.CreateInferenceComponent`` to add inference components to "
            "your endpoint."
        )
    if component_count > 1:
        raise ValueError(
            "Must provide inference component name for endpoint with "
            "multiple inference components."
        )

    # Exactly one component remains; its name is returned alongside the model info.
    sole_component_name = components[0]["InferenceComponentName"]
    model_id, model_version = _get_model_id_version_from_ic_endpoint_with_ic_name(
        sole_component_name, sagemaker_session
    )
    return model_id, model_version, sole_component_name
93+
94+
95+
def _get_model_id_version_from_ic_endpoint_with_ic_name(
    inference_component_name: str, sagemaker_session: Session
):
    """Infer the JumpStart model ID and version from a named inference component.

    Builds the inference component ARN from the session's region, partition, and
    account ID, then delegates to
    ``get_jumpstart_model_id_version_from_resource_arn`` for the lookup.

    Args:
        inference_component_name (str): Name of the inference component.
        sagemaker_session (Session): Session supplying the region and account ID,
            and passed on to the ARN-based lookup.

    Returns:
        A ``(model_id, model_version)`` tuple. Only the model ID is validated here;
        the version is returned exactly as the lookup produced it.

    Raises:
        ValueError: If the inference component does not have tags from which the
            model ID can be inferred.
    """
    aws_region: str = sagemaker_session.boto_region_name
    aws_part: str = aws_partition(aws_region)
    aws_account: str = sagemaker_session.account_id()

    resource_path = f"inference-component/{inference_component_name}"
    component_arn = f"arn:{aws_part}:sagemaker:{aws_region}:{aws_account}:{resource_path}"

    inferred_id, inferred_version = get_jumpstart_model_id_version_from_resource_arn(
        component_arn, sagemaker_session
    )

    if inferred_id:
        return inferred_id, inferred_version

    raise ValueError(
        "Cannot infer JumpStart model ID from inference component "
        f"'{inference_component_name}'. Please specify JumpStart `model_id` "
        "when retrieving default predictor for this inference component."
    )
125+
126+
127+
def _get_model_id_version_from_non_ic_endpoint(
    endpoint_name: str,
    inference_component_name: Optional[str],
    sagemaker_session: Session,
) -> Tuple[str, str]:
    """Infer the JumpStart model ID and version from a model-based endpoint.

    Builds the endpoint ARN from the session's region, partition, and account ID,
    then delegates to ``get_jumpstart_model_id_version_from_resource_arn`` for
    the lookup.

    Args:
        endpoint_name (str): Name of the model-based endpoint.
        inference_component_name (Optional[str]): Must be falsy; accepted only so
            that a mistakenly supplied component name is rejected explicitly.
        sagemaker_session (Session): Session supplying the region and account ID,
            and passed on to the ARN-based lookup.

    Returns:
        Tuple[str, str]: The model ID and model version. Only the model ID is
        validated here; the version is returned as the lookup produced it.

    Raises:
        ValueError: If a non-None inference component name is supplied, or if the
            endpoint does not have tags from which the model ID can be inferred.
    """
    if inference_component_name:
        raise ValueError("Cannot specify inference component name for model-based endpoints.")

    aws_region: str = sagemaker_session.boto_region_name
    aws_part: str = aws_partition(aws_region)
    aws_account: str = sagemaker_session.account_id()

    resource_path = f"endpoint/{endpoint_name}"
    endpoint_arn = f"arn:{aws_part}:sagemaker:{aws_region}:{aws_account}:{resource_path}"

    inferred_id, inferred_version = get_jumpstart_model_id_version_from_resource_arn(
        endpoint_arn, sagemaker_session
    )

    if inferred_id:
        return inferred_id, inferred_version

    raise ValueError(
        f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
        "Please specify JumpStart `model_id` when retrieving default "
        "predictor for this endpoint."
    )

src/sagemaker/jumpstart/utils.py

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from sagemaker.session import Session
4343
from sagemaker.config import load_sagemaker_config
44-
from sagemaker.utils import aws_partition, resolve_value_from_config
44+
from sagemaker.utils import resolve_value_from_config
4545
from sagemaker.workflow import is_pipeline_variable
4646

4747

@@ -764,7 +764,7 @@ def is_valid_model_id(
764764
raise ValueError(f"Unsupported script: {script}")
765765

766766

767-
def _get_jumpstart_model_id_version_from_resource_arn(
767+
def get_jumpstart_model_id_version_from_resource_arn(
768768
resource_arn: str,
769769
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
770770
) -> Tuple[Optional[str], Optional[str]]:
@@ -778,54 +778,35 @@ def _get_jumpstart_model_id_version_from_resource_arn(
778778
model_id: Optional[str] = None
779779
model_version: Optional[str] = None
780780

781-
if tag_key_in_array(enums.JumpStartTag.MODEL_ID, list_tags_result):
781+
model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
782+
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
783+
784+
for model_id_key in model_id_keys:
782785
try:
783-
model_id = get_tag_value(enums.JumpStartTag.MODEL_ID, list_tags_result)
786+
model_id_from_tag = get_tag_value(model_id_key, list_tags_result)
784787
except KeyError:
785-
model_id = None
788+
continue
789+
if model_id_from_tag is not None:
790+
if model_id is not None and model_id_from_tag != model_id:
791+
constants.JUMPSTART_LOGGER.warning(
792+
"Found multiple model ID tags on the following resource: %s", resource_arn
793+
)
794+
model_id = None
795+
break
796+
model_id = model_id_from_tag
786797

787-
if tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, list_tags_result):
798+
for model_version_key in model_version_keys:
788799
try:
789-
model_version = get_tag_value(enums.JumpStartTag.MODEL_VERSION, list_tags_result)
800+
model_version_from_tag = get_tag_value(model_version_key, list_tags_result)
790801
except KeyError:
791-
model_version = None
802+
continue
803+
if model_version_from_tag is not None:
804+
if model_version is not None and model_version_from_tag != model_version:
805+
constants.JUMPSTART_LOGGER.warning(
806+
"Found multiple model version tags on the following resource: %s", resource_arn
807+
)
808+
model_version = None
809+
break
810+
model_version = model_version_from_tag
792811

793812
return model_id, model_version
794-
795-
796-
def get_jumpstart_model_id_version_from_training_job(
797-
training_job_name: str,
798-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
799-
) -> Tuple[Optional[str], Optional[str]]:
800-
"""Inspects tags of training job to return JumpStart model ID and version.
801-
802-
Returns None if information cannot be inferred.
803-
"""
804-
805-
region: str = sagemaker_session.boto_region_name
806-
partition: str = aws_partition(region)
807-
account_id: str = sagemaker_session.account_id()
808-
809-
training_job_arn = (
810-
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
811-
)
812-
813-
return _get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)
814-
815-
816-
def get_jumpstart_model_id_version_from_endpoint(
817-
endpoint_name: str,
818-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
819-
) -> Tuple[Optional[str], Optional[str]]:
820-
"""Inspects tags of endpoint to return JumpStart model ID and version.
821-
822-
Returns None if information cannot be inferred.
823-
"""
824-
825-
region: str = sagemaker_session.boto_region_name
826-
partition: str = aws_partition(region)
827-
account_id: str = sagemaker_session.account_id()
828-
829-
endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
830-
831-
return _get_jumpstart_model_id_version_from_resource_arn(endpoint_arn, sagemaker_session)

src/sagemaker/predictor.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1818

1919
from sagemaker.jumpstart.factory.model import get_default_predictor
20-
from sagemaker.jumpstart.utils import (
21-
get_jumpstart_model_id_version_from_endpoint,
22-
)
20+
from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint
21+
2322

2423
from sagemaker.session import Session
2524

@@ -35,6 +34,7 @@
3534

3635
def retrieve_default(
3736
endpoint_name: str,
37+
inference_component_name: Optional[str] = None,
3838
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3939
region: Optional[str] = None,
4040
model_id: Optional[str] = None,
@@ -46,6 +46,8 @@ def retrieve_default(
4646
4747
Args:
4848
endpoint_name (str): Endpoint name for which to create a predictor.
49+
inference_component_name (str): Name of the Amazon SageMaker inference component
50+
from which to optionally create a predictor. (Default: None).
4951
sagemaker_session (Session): The SageMaker Session to attach to the Predictor.
5052
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
5153
region (str): The AWS Region for which to retrieve the default predictor.
@@ -69,21 +71,32 @@ def retrieve_default(
6971
"""
7072

7173
if model_id is None:
72-
model_id, inferred_model_version = get_jumpstart_model_id_version_from_endpoint(
73-
endpoint_name=endpoint_name,
74-
sagemaker_session=sagemaker_session,
74+
(
75+
inferred_model_id,
76+
inferred_model_version,
77+
inferred_inference_component_name,
78+
) = get_model_id_version_from_endpoint(
79+
endpoint_name, inference_component_name, sagemaker_session
7580
)
76-
model_version = model_version or inferred_model_version
77-
if not model_id:
81+
82+
if not inferred_model_id:
7883
raise ValueError(
7984
f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
8085
"Please specify JumpStart `model_id` when retrieving default "
8186
"predictor for this endpoint."
8287
)
8388

84-
model_version = model_version or "*"
89+
model_id = inferred_model_id
90+
model_version = model_version or inferred_model_version or "*"
91+
inference_component_name = inference_component_name or inferred_inference_component_name
92+
else:
93+
model_version = model_version or "*"
8594

86-
predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)
95+
predictor = Predictor(
96+
endpoint_name=endpoint_name,
97+
component_name=inference_component_name,
98+
sagemaker_session=sagemaker_session,
99+
)
87100

88101
return get_default_predictor(
89102
predictor=predictor,

0 commit comments

Comments
 (0)