-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: retrieve jumpstart estimator and predictor without specifying model id (infer from tags) #4304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
knikure
merged 35 commits into
aws:master
from
evakravi:feat/retrieve-js-estimator-predictor-without-model-id
Feb 1, 2024
Merged
feat: retrieve jumpstart estimator and predictor without specifying model id (infer from tags) #4304
Changes from all commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
21177c4
feat: retrieve jumpstart estimator and predictor without specifying m…
evakravi 8086f6e
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi 69b2929
fix: pylint
evakravi 3988710
chore: add support for ic-based endpoints
evakravi c4f7ffc
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi 1dd1df4
chore: update docstrings
evakravi 73c17a0
chore: add integ tests
evakravi a5de8ce
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi ab60cb4
chore: add support in conftest for ic endpoints
evakravi ec523be
fix: delete inference components
evakravi e73bd8c
chore: address tagging and ic determination comments
evakravi 4d963d6
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi aa0d354
chore: address PR comments
evakravi caf4844
fix: docstring
evakravi 2292586
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 995eb6d
chore: improve docs
evakravi c5285df
Merge remote-tracking branch 'origin' into feat/retrieve-js-estimator…
evakravi aa363c6
chore: address comments
evakravi 43cc2e1
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 73dad91
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi e4cef3a
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 8d73539
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 6c3c965
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi a658b17
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi edd128f
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 133d3b7
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 93c9adf
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 5b2e0f9
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 8082dfa
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi ea8506d
fix: list_and_paginate_inference_component_names_associated_with_endp…
evakravi 7101638
fix: boto3 session region
evakravi 254f147
fix: boto_session
evakravi dc245d7
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi 13b7b2c
fix: sagemaker session to delete IC
evakravi 62a9262
Merge branch 'master' into feat/retrieve-js-estimator-predictor-witho…
evakravi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module stores SageMaker Session utilities for JumpStart models.""" | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
from __future__ import absolute_import | ||
|
||
from typing import Optional, Tuple | ||
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION | ||
|
||
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn | ||
from sagemaker.session import Session | ||
from sagemaker.utils import aws_partition | ||
|
||
|
||
def get_model_id_version_from_endpoint( | ||
endpoint_name: str, | ||
inference_component_name: Optional[str] = None, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> Tuple[str, str, Optional[str]]: | ||
"""Given an endpoint and optionally inference component names, return the model ID and version. | ||
|
||
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID | ||
and version. A third string element is included in the tuple for any inferred inference | ||
component name, or 'None' if it's a model-based endpoint. | ||
|
||
JumpStart adds tags automatically to endpoints, models, endpoint configs, and inference | ||
components launched in SageMaker Studio and programmatically with the SageMaker Python SDK. | ||
|
||
Raises: | ||
ValueError: If model ID and version cannot be inferred from the endpoint. | ||
""" | ||
if inference_component_name or sagemaker_session.is_inference_component_based_endpoint( | ||
endpoint_name | ||
): | ||
if inference_component_name: | ||
( | ||
model_id, | ||
model_version, | ||
) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 | ||
inference_component_name, sagemaker_session | ||
) | ||
|
||
else: | ||
( | ||
model_id, | ||
model_version, | ||
inference_component_name, | ||
) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 | ||
endpoint_name, sagemaker_session | ||
) | ||
|
||
else: | ||
model_id, model_version = _get_model_id_version_from_model_based_endpoint( | ||
endpoint_name, inference_component_name, sagemaker_session | ||
) | ||
return model_id, model_version, inference_component_name | ||
|
||
|
||
def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( | ||
endpoint_name: str, sagemaker_session: Session | ||
) -> Tuple[str, str, str]: | ||
"""Given an endpoint name, derives the model ID, version, and inferred inference component name. | ||
|
||
This function assumes the endpoint corresponds to an inference-component-based endpoint. | ||
An endpoint is inference-component-based if and only if the associated endpoint config | ||
has a role associated with it and no production variants with a ``ModelName`` field. | ||
|
||
Raises: | ||
ValueError: If there is not a single inference component associated with the endpoint. | ||
""" | ||
inference_component_names = ( | ||
sagemaker_session.list_and_paginate_inference_component_names_associated_with_endpoint( | ||
endpoint_name=endpoint_name | ||
) | ||
) | ||
|
||
if len(inference_component_names) == 0: | ||
raise ValueError( | ||
f"No inference component found for the following endpoint: {endpoint_name}. " | ||
"Use ``SageMaker.CreateInferenceComponent`` to add inference components to " | ||
"your endpoint." | ||
) | ||
if len(inference_component_names) > 1: | ||
raise ValueError( | ||
f"Multiple inference components found for the following endpoint: {endpoint_name}. " | ||
"Provide an 'inference_component_name' to retrieve the model ID and version " | ||
"associated with a particular inference component." | ||
) | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
inference_component_name = inference_component_names[0] | ||
return ( | ||
*_get_model_id_version_from_inference_component_endpoint_with_inference_component_name( | ||
inference_component_name, sagemaker_session | ||
), | ||
inference_component_name, | ||
) | ||
|
||
|
||
def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( | ||
inference_component_name: str, sagemaker_session: Session | ||
): | ||
"""Returns the model ID and version inferred from a SageMaker inference component. | ||
|
||
Raises: | ||
ValueError: If the inference component does not have tags from which the model ID | ||
and version can be inferred. | ||
""" | ||
region: str = sagemaker_session.boto_region_name | ||
partition: str = aws_partition(region) | ||
account_id: str = sagemaker_session.account_id() | ||
|
||
inference_component_arn = ( | ||
f"arn:{partition}:sagemaker:{region}:{account_id}:" | ||
f"inference-component/{inference_component_name}" | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( | ||
inference_component_arn, sagemaker_session | ||
) | ||
|
||
if not model_id: | ||
raise ValueError( | ||
"Cannot infer JumpStart model ID from inference component " | ||
f"'{inference_component_name}'. Please specify JumpStart `model_id` " | ||
"when retrieving default predictor for this inference component." | ||
) | ||
|
||
return model_id, model_version | ||
|
||
|
||
def _get_model_id_version_from_model_based_endpoint( | ||
endpoint_name: str, | ||
inference_component_name: Optional[str], | ||
sagemaker_session: Session, | ||
) -> Tuple[str, str]: | ||
"""Returns the model ID and version inferred from a model-based endpoint. | ||
|
||
Raises: | ||
ValueError: If an inference component name is supplied, or if the endpoint does | ||
not have tags from which the model ID and version can be inferred. | ||
""" | ||
|
||
if inference_component_name: | ||
raise ValueError("Cannot specify inference component name for model-based endpoints.") | ||
|
||
region: str = sagemaker_session.boto_region_name | ||
partition: str = aws_partition(region) | ||
account_id: str = sagemaker_session.account_id() | ||
|
||
# SageMaker Tagging requires endpoint names to be lower cased | ||
endpoint_name = endpoint_name.lower() | ||
|
||
endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( | ||
endpoint_arn, sagemaker_session | ||
) | ||
|
||
if not model_id: | ||
raise ValueError( | ||
f"Cannot infer JumpStart model ID from endpoint '{endpoint_name}'. " | ||
"Please specify JumpStart `model_id` when retrieving default " | ||
"predictor for this endpoint." | ||
) | ||
|
||
return model_id, model_version | ||
|
||
|
||
def get_model_id_version_from_training_job( | ||
training_job_name: str, | ||
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> Tuple[str, str]: | ||
"""Returns the model ID and version inferred from a training job. | ||
|
||
Raises: | ||
ValueError: If the training job does not have tags from which the model ID | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
and version can be inferred. JumpStart adds tags automatically to training jobs | ||
launched in SageMaker Studio and programmatically with the SageMaker Python SDK. | ||
""" | ||
region: str = sagemaker_session.boto_region_name | ||
partition: str = aws_partition(region) | ||
account_id: str = sagemaker_session.account_id() | ||
|
||
training_job_arn = ( | ||
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}" | ||
) | ||
|
||
model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn( | ||
training_job_arn, sagemaker_session | ||
) | ||
|
||
model_version = inferred_model_version or None | ||
|
||
if not model_id: | ||
raise ValueError( | ||
f"Cannot infer JumpStart model ID from training job '{training_job_name}'. " | ||
"Please specify JumpStart `model_id` when retrieving Estimator " | ||
"for this training job." | ||
) | ||
|
||
return model_id, model_version |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.