Skip to content

feat: jumpstart model artifact instance type variants #4172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions src/sagemaker/jumpstart/artifacts/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,68 @@
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
from sagemaker.jumpstart.types import JumpStartModelSpecs


def _retrieve_hosting_prepacked_artifact_key(
    model_specs: JumpStartModelSpecs, instance_type: str
) -> Optional[str]:
    """Returns instance specific hosting prepacked artifact key or default one as fallback.

    Args:
        model_specs (JumpStartModelSpecs): Model specs from which the
            ``hosting_instance_type_variants`` and ``hosting_prepacked_artifact_key``
            attributes are read. Either attribute may be absent or None.
        instance_type (str): Instance type used to look up an instance specific key.
            When falsy, the instance specific lookup is skipped.

    Returns:
        Optional[str]: The instance specific prepacked artifact key when it is truthy,
        otherwise the default ``hosting_prepacked_artifact_key``. None when neither
        is available.
    """
    instance_specific_key: Optional[str] = None

    if instance_type:
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            instance_specific_key = variants.get_instance_specific_prepacked_artifact_key(
                instance_type=instance_type
            )

    # Use an explicit ``None`` default: ``getattr`` with a constant name and no default
    # is plain attribute access and would raise AttributeError when the attribute is
    # absent, unlike the variants lookup above which tolerates a missing attribute.
    default_key: Optional[str] = getattr(model_specs, "hosting_prepacked_artifact_key", None)

    return instance_specific_key or default_key


def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
    """Returns instance specific hosting artifact key or default one as fallback.

    Args:
        model_specs (JumpStartModelSpecs): Model specs from which the
            ``hosting_instance_type_variants`` and ``hosting_artifact_key``
            attributes are read.
        instance_type (str): Instance type used to look up an instance specific key.
            When falsy, the instance specific lookup is skipped.

    Returns:
        str: The instance specific artifact key when it is truthy, otherwise
        ``model_specs.hosting_artifact_key``.
    """
    variant_key: Optional[str] = None

    if instance_type:
        # The variants attribute may be missing or None on the specs object.
        variants = getattr(model_specs, "hosting_instance_type_variants", None)
        if variants is not None:
            variant_key = variants.get_instance_specific_artifact_key(
                instance_type=instance_type
            )

    fallback_key: str = model_specs.hosting_artifact_key

    if variant_key:
        return variant_key
    return fallback_key


def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
    """Returns instance specific training artifact key or default one as fallback.

    Args:
        model_specs (JumpStartModelSpecs): Model specs from which the
            ``training_instance_type_variants`` and ``training_artifact_key``
            attributes are read.
        instance_type (str): Instance type used to look up an instance specific key.
            When falsy, the instance specific lookup is skipped.

    Returns:
        str: The instance specific artifact key when it is truthy, otherwise
        ``model_specs.training_artifact_key``.
    """
    training_variants = (
        getattr(model_specs, "training_instance_type_variants", None) if instance_type else None
    )

    if training_variants is None:
        per_instance_key: Optional[str] = None
    else:
        per_instance_key = training_variants.get_instance_specific_artifact_key(
            instance_type=instance_type
        )

    # The default is always read, even when an instance specific key was found.
    default_key: str = model_specs.training_artifact_key

    return per_instance_key if per_instance_key else default_key


def _retrieve_model_uri(
model_id: str,
model_version: str,
model_scope: Optional[str] = None,
instance_type: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -50,6 +106,7 @@ def _retrieve_model_uri(
artifact S3 URI.
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
region (str): Region for which to retrieve model S3 URI. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
Expand Down Expand Up @@ -84,14 +141,21 @@ def _retrieve_model_uri(
sagemaker_session=sagemaker_session,
)

model_artifact_key: str

if model_scope == JumpStartScriptScope.INFERENCE:

is_prepacked = not model_specs.use_inference_script_uri()

model_artifact_key = (
getattr(model_specs, "hosting_prepacked_artifact_key", None)
or model_specs.hosting_artifact_key
_retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
if is_prepacked
else _retrieve_hosting_artifact_key(model_specs, instance_type)
)

elif model_scope == JumpStartScriptScope.TRAINING:
model_artifact_key = model_specs.training_artifact_key

model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)

bucket = os.environ.get(
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
)

if (
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
)

if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
Expand Down
56 changes: 56 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,62 @@ def to_json(self) -> Dict[str, Any]:
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
return json_obj

def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
"""Returns instance specific model artifact key.

Returns None if a model, instance type tuple does not have specific
artifact key.
"""

return self._get_instance_specific_property(
instance_type=instance_type, property_name="prepacked_artifact_key"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should the property name include s3?

)

def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str]:
    """Returns the ``artifact_key`` property specific to an instance type.

    Delegates to ``_get_instance_specific_property`` with the property name
    ``"artifact_key"``.

    Args:
        instance_type (str): Instance type to look up.

    Returns:
        Optional[str]: Whatever the delegated lookup returns; None when no
        instance specific artifact key is found.
    """
    artifact_key: Optional[str] = self._get_instance_specific_property(
        instance_type=instance_type,
        property_name="artifact_key",
    )
    return artifact_key

def _get_instance_specific_property(
    self, instance_type: str, property_name: str
) -> Optional[str]:
    """Returns the value of a named property for an instance type.

    The lookup first tries the entry of ``self.variants`` keyed by the exact
    instance type, then the entry keyed by the instance type family as
    computed by ``get_instance_type_family``. In both cases the value is read
    from the entry's ``"properties"`` mapping.

    If a value exists for both the instance family and instance type,
    the instance type value is chosen.

    Args:
        instance_type (str): Instance type to look up.
        property_name (str): Name of the property to read.

    Returns:
        Optional[str]: The property value, or None if ``self.variants`` is None,
        the family is empty or None, or no value is found for the
        (instance type, property name) pair.
    """
    variants = self.variants
    if variants is None:
        return None

    exact_match: Optional[str] = (
        variants.get(instance_type, {}).get("properties", {}).get(property_name, None)
    )
    if exact_match:
        return exact_match

    family = get_instance_type_family(instance_type)
    if family in {"", None}:
        # No usable family name, so there is nothing further to look up.
        return None

    family_properties = variants.get(family, {}).get("properties", {})
    return family_properties.get(property_name, None)

def get_instance_specific_hyperparameters(
self, instance_type: str
) -> List[JumpStartHyperparameter]:
Expand Down
15 changes: 9 additions & 6 deletions src/sagemaker/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def retrieve(
model_id: Optional[str] = None,
model_version: Optional[str] = None,
model_scope: Optional[str] = None,
instance_type: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -44,6 +45,7 @@ def retrieve(
the model artifact S3 URI.
model_scope (str): The model type.
Valid values: "training" and "inference".
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
specifications should be tolerated without raising an exception. If ``False``, raises an
exception if the script used by this version of the model has dependencies with known
Expand Down Expand Up @@ -71,11 +73,12 @@ def retrieve(
)

return artifacts._retrieve_model_uri(
model_id,
model_version, # type: ignore
model_scope,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version, # type: ignore
model_scope=model_scope,
instance_type=instance_type,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
Loading