Skip to content

feat: jumpstart instance type variants #4068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2efa1e8
feat: jumpstart instance type variants
evakravi Aug 14, 2023
8d527f8
fix: typo
evakravi Aug 14, 2023
c21d2e1
fix: failing tests
evakravi Aug 14, 2023
e9e99c6
fix: docstring
evakravi Aug 14, 2023
fd41f0b
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Aug 15, 2023
9f912e4
chore: support for local and local_gpu
evakravi Aug 17, 2023
099d6f1
fix: variants -> regional_variants, properties -> regional_properties
evakravi Aug 25, 2023
3e0d4f6
chore: fallback to legacy logic if no match found
evakravi Aug 30, 2023
7666ee0
chore: add support for no ecr specs
evakravi Sep 11, 2023
78d5911
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 12, 2023
c319e65
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 12, 2023
001edaa
fix: flake8 line length
evakravi Sep 12, 2023
d9aa56f
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 14, 2023
0f136e7
chore: address PR comments, add support for bad metadata fallback
evakravi Sep 14, 2023
f84d5ae
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 15, 2023
76d5f4c
chore: remove unnecessary log, improve error message
evakravi Sep 15, 2023
71ee0ad
chore: slight adjustment to image uri variant logic
evakravi Sep 18, 2023
e1d8d30
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 19, 2023
6b6ccd1
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 19, 2023
4299b8f
chore: improve error message
evakravi Sep 19, 2023
2f22669
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 19, 2023
3f42c15
fix: instance type found but image uri for family
evakravi Sep 19, 2023
6ca6627
feat: instance type variants for environment variables
evakravi Sep 20, 2023
d87ce71
chore: address PR comments, fix formatting
evakravi Sep 20, 2023
814d0cd
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 20, 2023
fb16dbb
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 21, 2023
43947f5
Merge branch 'master' into feat/jumpstart-instance-type-variants
evakravi Sep 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.session import Session

logger = logging.getLogger(__name__)
Expand All @@ -33,6 +34,8 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
include_aws_sdk_env_vars: bool = True,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.

Expand All @@ -58,6 +61,10 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get environment
variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
variables.
Returns:
dict: The variables to use for the model.

Expand All @@ -78,4 +85,6 @@ def retrieve_default(
tolerate_deprecated_model,
include_aws_sdk_env_vars,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
script=script,
)
28 changes: 7 additions & 21 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,20 +270,6 @@ def retrieve(
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)


def _get_instance_type_family(instance_type):
"""Return the family of the instance type.

Regex matches either "ml.<family>.<size>" or "ml_<family>. If input is None
or there is no match, return an empty string.
"""
instance_type_family = ""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match is not None:
instance_type_family = match[1]
return instance_type_family


def _get_image_tag(
container_version,
distribution,
Expand All @@ -297,7 +283,7 @@ def _get_image_tag(
version,
):
"""Return image tag based on framework, container, and compute configuration(s)."""
instance_type_family = _get_instance_type_family(instance_type)
instance_type_family = utils.get_instance_type_family(instance_type)
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
_validate_arg(
Expand Down Expand Up @@ -385,7 +371,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non

def _validate_instance_deprecation(framework, instance_type, version):
"""Check if instance type is deprecated for a certain framework with a certain version"""
if _get_instance_type_family(instance_type) == "p2":
if utils.get_instance_type_family(instance_type) == "p2":
if (framework == "pytorch" and Version(version) >= Version("1.13")) or (
framework == "tensorflow" and Version(version) >= Version("2.12")
):
Expand All @@ -409,7 +395,7 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
# Validate for Graviton allowed frameworks
if (
instance_type is not None
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
):
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
Expand All @@ -426,7 +412,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
"""Return final image scope based on provided framework and instance type."""
if (
framework in GRAVITON_ALLOWED_FRAMEWORKS
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
):
return INFERENCE_GRAVITON
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
Expand All @@ -441,7 +427,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
def _get_inference_tool(inference_tool, instance_type):
"""Extract the inference tool name from instance type."""
if not inference_tool:
instance_type_family = _get_instance_type_family(instance_type)
instance_type_family = utils.get_instance_type_family(instance_type)
if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
return "neuron"
return inference_tool
Expand Down Expand Up @@ -529,7 +515,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
processor = "neuron"
else:
# looks for either "ml.<family>.<size>" or "ml_<family>"
family = _get_instance_type_family(instance_type)
family = utils.get_instance_type_family(instance_type)
if family:
# For some frameworks, we have optimized images for specific families, e.g c5 or p3.
# In those cases, we use the family name in the image tag. In other cases, we use
Expand Down Expand Up @@ -559,7 +545,7 @@ def _should_auto_select_container_version(instance_type, distribution):
p4d = False
if instance_type:
# looks for either "ml.<family>.<size>" or "ml_<family>"
family = _get_instance_type_family(instance_type)
family = utils.get_instance_type_family(instance_type)
if family:
p4d = family == "p4d"

Expand Down
36 changes: 31 additions & 5 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def _retrieve_default_environment_variables(
tolerate_deprecated_model: bool = False,
include_aws_sdk_env_vars: bool = True,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
) -> Dict[str, str]:
"""Retrieves the inference environment variables for the model matching the given arguments.

Expand All @@ -59,6 +61,10 @@ def _retrieve_default_environment_variables(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
environment variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve
environment variables.
Returns:
dict: the inference environment variables to use for the model.
"""
Expand All @@ -69,17 +75,37 @@ def _retrieve_default_environment_variables(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=JumpStartScriptScope.INFERENCE,
scope=script,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

default_environment_variables: Dict[str, str] = {}
for environment_variable in model_specs.inference_environment_variables:
if include_aws_sdk_env_vars or environment_variable.required_for_model_class:
default_environment_variables[environment_variable.name] = str(
environment_variable.default
if script == JumpStartScriptScope.INFERENCE:
for environment_variable in model_specs.inference_environment_variables:
if include_aws_sdk_env_vars or environment_variable.required_for_model_class:
default_environment_variables[environment_variable.name] = str(
environment_variable.default
)

if instance_type:
if script == JumpStartScriptScope.INFERENCE and getattr(
model_specs, "hosting_instance_type_variants", None
):
default_environment_variables.update(
model_specs.hosting_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
instance_type
)
)
elif script == JumpStartScriptScope.TRAINING and getattr(
model_specs, "training_instance_type_variants", None
):
default_environment_variables.update(
model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
instance_type
)
)

return default_environment_variables
27 changes: 26 additions & 1 deletion src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,35 @@ def _retrieve_image_uri(
)

if image_scope == JumpStartScriptScope.INFERENCE:
hosting_instance_type_variants = model_specs.hosting_instance_type_variants
if hosting_instance_type_variants:
image_uri = hosting_instance_type_variants.get_image_uri(
instance_type=instance_type, region=region
)
if image_uri is not None:
return image_uri
ecr_specs = model_specs.hosting_ecr_specs
if ecr_specs is None:
raise ValueError(
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
f"with {instance_type} instance type in {region}. "
"Please try another instance type or region."
)
elif image_scope == JumpStartScriptScope.TRAINING:
training_instance_type_variants = model_specs.training_instance_type_variants
if training_instance_type_variants:
image_uri = training_instance_type_variants.get_image_uri(
instance_type=instance_type, region=region
)
if image_uri is not None:
return image_uri
ecr_specs = model_specs.training_ecr_specs

if ecr_specs is None:
raise ValueError(
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
f"with {instance_type} instance type in {region}. "
"Please try another instance type or region."
)
if framework is not None and framework != ecr_specs.framework:
raise ValueError(
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
script=JumpStartScriptScope.INFERENCE,
instance_type=kwargs.instance_type,
)

for key, value in extra_env_vars.items():
Expand Down
Loading