Skip to content

feat: retrieve jumpstart estimator and predictor without specifying model id (infer from tags) #4304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
21177c4
feat: retrieve jumpstart estimator and predictor without specifying m…
evakravi Dec 5, 2023
8086f6e
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi Dec 5, 2023
69b2929
fix: pylint
evakravi Dec 5, 2023
3988710
chore: add support for ic-based endpoints
evakravi Dec 20, 2023
c4f7ffc
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi Dec 20, 2023
1dd1df4
chore: update docstrings
evakravi Dec 20, 2023
73c17a0
chore: add integ tests
evakravi Dec 21, 2023
a5de8ce
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi Dec 21, 2023
ab60cb4
chore: add support in conftest for ic endpoints
evakravi Dec 21, 2023
ec523be
fix: delete inference components
evakravi Dec 21, 2023
e73bd8c
chore: address tagging and ic determination comments
evakravi Dec 26, 2023
4d963d6
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi Dec 26, 2023
aa0d354
chore: address PR comments
evakravi Dec 27, 2023
caf4844
fix: docstring
evakravi Dec 28, 2023
2292586
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 2, 2024
995eb6d
chore: improve docs
evakravi Jan 12, 2024
c5285df
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi Jan 12, 2024
aa363c6
chore: address comments
evakravi Jan 12, 2024
43cc2e1
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 16, 2024
73dad91
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 16, 2024
e4cef3a
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 22, 2024
8d73539
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 22, 2024
6c3c965
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 23, 2024
a658b17
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 23, 2024
edd128f
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 23, 2024
133d3b7
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 24, 2024
93c9adf
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 25, 2024
5b2e0f9
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 29, 2024
8082dfa
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 30, 2024
ea8506d
fix: list_and_paginate_inference_component_names_associated_with_endp…
evakravi Jan 31, 2024
7101638
fix: boto3 session region
evakravi Jan 31, 2024
254f147
fix: boto_session
evakravi Jan 31, 2024
dc245d7
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Jan 31, 2024
13b7b2c
fix: sagemaker session to delete IC
evakravi Feb 1, 2024
62a9262
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi Feb 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,9 @@
"Unable to create default JumpStart SageMaker Session due to the following error: %s.",
str(e),
)

EXTRA_MODEL_ID_TAGS = ["sm-jumpstart-id", "sagemaker-studio:jumpstart-model-id"]
EXTRA_MODEL_VERSION_TAGS = [
"sm-jumpstart-model-version",
"sagemaker-studio:jumpstart-model-version",
]
17 changes: 15 additions & 2 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
from sagemaker.jumpstart.factory.model import get_default_predictor
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
from sagemaker.jumpstart.utils import (
is_valid_model_id,
resolve_model_sagemaker_config_field,
Expand Down Expand Up @@ -668,8 +669,8 @@ def fit(
def attach(
cls,
training_job_name: str,
model_id: str,
model_version: str = "*",
model_id: Optional[str] = None,
model_version: Optional[str] = None,
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_channel_name: str = "model",
) -> "JumpStartEstimator":
Expand Down Expand Up @@ -711,8 +712,20 @@ def attach(
Returns:
Instance of the calling ``JumpStartEstimator`` Class with the attached
training job.

Raises:
ValueError: if the model ID or version cannot be inferred from the training job.

"""

if model_id is None:

model_id, model_version = get_model_id_version_from_training_job(
training_job_name=training_job_name, sagemaker_session=sagemaker_session
)

model_version = model_version or "*"

return cls._attach(
training_job_name=training_job_name,
sagemaker_session=sagemaker_session,
Expand Down
210 changes: 210 additions & 0 deletions src/sagemaker/jumpstart/session_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores SageMaker Session utilities for JumpStart models."""

from __future__ import absolute_import

from typing import Optional, Tuple
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION

from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
from sagemaker.session import Session
from sagemaker.utils import aws_partition


def get_model_id_version_from_endpoint(
endpoint_name: str,
inference_component_name: Optional[str] = None,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str, Optional[str]]:
"""Given an endpoint and optionally inference component names, return the model ID and version.

Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
and version. A third string element is included in the tuple for any inferred inference
component name, or 'None' if it's a model-based endpoint.

JumpStart adds tags automatically to endpoints, models, endpoint configs, and inference
components launched in SageMaker Studio and programmatically with the SageMaker Python SDK.

Raises:
ValueError: If model ID and version cannot be inferred from the endpoint.
"""
if inference_component_name or sagemaker_session.is_inference_component_based_endpoint(
endpoint_name
):
if inference_component_name:
(
model_id,
model_version,
) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
inference_component_name, sagemaker_session
)

else:
(
model_id,
model_version,
inference_component_name,
) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
endpoint_name, sagemaker_session
)

else:
model_id, model_version = _get_model_id_version_from_model_based_endpoint(
endpoint_name, inference_component_name, sagemaker_session
)
return model_id, model_version, inference_component_name


def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
endpoint_name: str, sagemaker_session: Session
) -> Tuple[str, str, str]:
"""Given an endpoint name, derives the model ID, version, and inferred inference component name.

This function assumes the endpoint corresponds to an inference-component-based endpoint.
An endpoint is inference-component-based if and only if the associated endpoint config
has a role associated with it and no production variants with a ``ModelName`` field.

Raises:
ValueError: If there is not a single inference component associated with the endpoint.
"""
inference_component_names = (
sagemaker_session.list_and_paginate_inference_component_names_associated_with_endpoint(
endpoint_name=endpoint_name
)
)

if len(inference_component_names) == 0:
raise ValueError(
f"No inference component found for the following endpoint: {endpoint_name}. "
"Use ``SageMaker.CreateInferenceComponent`` to add inference components to "
"your endpoint."
)
if len(inference_component_names) > 1:
raise ValueError(
f"Multiple inference components found for the following endpoint: {endpoint_name}. "
"Provide an 'inference_component_name' to retrieve the model ID and version "
"associated with a particular inference component."
)
inference_component_name = inference_component_names[0]
return (
*_get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
inference_component_name, sagemaker_session
),
inference_component_name,
)


def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
inference_component_name: str, sagemaker_session: Session
):
"""Returns the model ID and version inferred from a SageMaker inference component.

Raises:
ValueError: If the inference component does not have tags from which the model ID
and version can be inferred.
"""
region: str = sagemaker_session.boto_region_name
partition: str = aws_partition(region)
account_id: str = sagemaker_session.account_id()

inference_component_arn = (
f"arn:{partition}:sagemaker:{region}:{account_id}:"
f"inference-component/{inference_component_name}"
)

model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
inference_component_arn, sagemaker_session
)

if not model_id:
raise ValueError(
"Cannot infer JumpStart model ID from inference component "
f"'{inference_component_name}'. Please specify JumpStart `model_id` "
"when retrieving default predictor for this inference component."
)

return model_id, model_version


def _get_model_id_version_from_model_based_endpoint(
endpoint_name: str,
inference_component_name: Optional[str],
sagemaker_session: Session,
) -> Tuple[str, str]:
"""Returns the model ID and version inferred from a model-based endpoint.

Raises:
ValueError: If an inference component name is supplied, or if the endpoint does
not have tags from which the model ID and version can be inferred.
"""

if inference_component_name:
raise ValueError("Cannot specify inference component name for model-based endpoints.")

region: str = sagemaker_session.boto_region_name
partition: str = aws_partition(region)
account_id: str = sagemaker_session.account_id()

# SageMaker Tagging requires endpoint names to be lower cased
endpoint_name = endpoint_name.lower()

endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"

model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
endpoint_arn, sagemaker_session
)

if not model_id:
raise ValueError(
f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
"Please specify JumpStart `model_id` when retrieving default "
"predictor for this endpoint."
)

return model_id, model_version


def get_model_id_version_from_training_job(
training_job_name: str,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[str, str]:
"""Returns the model ID and version inferred from a training job.

Raises:
ValueError: If the training job does not have tags from which the model ID
and version can be inferred. JumpStart adds tags automatically to training jobs
launched in SageMaker Studio and programmatically with the SageMaker Python SDK.
"""
region: str = sagemaker_session.boto_region_name
partition: str = aws_partition(region)
account_id: str = sagemaker_session.account_id()

training_job_arn = (
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
)

model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
training_job_arn, sagemaker_session
)

model_version = inferred_model_version or None

if not model_id:
raise ValueError(
f"Cannot infer JumpStart model ID from training job '{training_job_name}'. "
"Please specify JumpStart `model_id` when retrieving Estimator "
"for this training job."
)

return model_id, model_version
50 changes: 49 additions & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import
import logging
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
import boto3
from packaging.version import Version
Expand Down Expand Up @@ -762,3 +762,51 @@ def is_valid_model_id(
if script == enums.JumpStartScriptScope.TRAINING:
return model_id in model_id_set
raise ValueError(f"Unsupported script: {script}")


def get_jumpstart_model_id_version_from_resource_arn(
resource_arn: str,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str]]:
"""Returns the JumpStart model ID and version if in resource tags.

Returns 'None' if model ID or version cannot be inferred from tags.
"""

list_tags_result = sagemaker_session.list_tags(resource_arn)

model_id: Optional[str] = None
model_version: Optional[str] = None

model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]

for model_id_key in model_id_keys:
try:
model_id_from_tag = get_tag_value(model_id_key, list_tags_result)
except KeyError:
continue
if model_id_from_tag is not None:
if model_id is not None and model_id_from_tag != model_id:
constants.JUMPSTART_LOGGER.warning(
"Found multiple model ID tags on the following resource: %s", resource_arn
)
model_id = None
break
model_id = model_id_from_tag

for model_version_key in model_version_keys:
try:
model_version_from_tag = get_tag_value(model_version_key, list_tags_result)
except KeyError:
continue
if model_version_from_tag is not None:
if model_version is not None and model_version_from_tag != model_version:
constants.JUMPSTART_LOGGER.warning(
"Found multiple model version tags on the following resource: %s", resource_arn
)
model_version = None
break
model_version = model_version_from_tag

return model_id, model_version
41 changes: 33 additions & 8 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION

from sagemaker.jumpstart.factory.model import get_default_predictor
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint


from sagemaker.session import Session

Expand All @@ -33,6 +34,7 @@

def retrieve_default(
endpoint_name: str,
inference_component_name: Optional[str] = None,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
region: Optional[str] = None,
model_id: Optional[str] = None,
Expand All @@ -44,7 +46,9 @@ def retrieve_default(

Args:
endpoint_name (str): Endpoint name for which to create a predictor.
sagemaker_session (Session): The SageMaker Session to attach to the Predictor.
inference_component_name (str): Name of the Amazon SageMaker inference component
from which to optionally create a predictor. (Default: None).
sagemaker_session (Session): The SageMaker Session to attach to the predictor.
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
region (str): The AWS Region for which to retrieve the default predictor.
(Default: None).
Expand All @@ -63,16 +67,37 @@ def retrieve_default(
Predictor: The default predictor to use for the model.

Raises:
ValueError: If the combination of arguments specified is not supported.
ValueError: If the combination of arguments specified is not supported, or if a model ID or
version cannot be inferred from the endpoint.
"""

if not is_jumpstart_model_input(model_id, model_version):
raise ValueError(
"Must specify JumpStart `model_id` and `model_version` "
"when retrieving default predictor."
if model_id is None:
(
inferred_model_id,
inferred_model_version,
inferred_inference_component_name,
) = get_model_id_version_from_endpoint(
endpoint_name, inference_component_name, sagemaker_session
)

predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)
if not inferred_model_id:
raise ValueError(
f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. "
"Please specify JumpStart `model_id` when retrieving default "
"predictor for this endpoint."
)

model_id = inferred_model_id
model_version = model_version or inferred_model_version or "*"
inference_component_name = inference_component_name or inferred_inference_component_name
else:
model_version = model_version or "*"

predictor = Predictor(
endpoint_name=endpoint_name,
component_name=inference_component_name,
sagemaker_session=sagemaker_session,
)

return get_default_predictor(
predictor=predictor,
Expand Down
Loading