Skip to content

chore: cleanup jumpstart factory #4840

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.enums import JumpStartModelType, JumpStartScriptScope
from sagemaker.session import Session

logger = logging.getLogger(__name__)
Expand All @@ -38,6 +38,7 @@ def retrieve_default(
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.

Expand Down Expand Up @@ -70,6 +71,8 @@ def retrieve_default(
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: The variables to use for the model.

Expand All @@ -94,4 +97,5 @@ def retrieve_default(
instance_type=instance_type,
script=script,
config_name=config_name,
model_type=model_type,
)
6 changes: 5 additions & 1 deletion src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import HyperparameterValidationMode
from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType
from sagemaker.jumpstart.validators import validate_hyperparameters
from sagemaker.session import Session

Expand All @@ -38,6 +38,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
"""Retrieves the default training hyperparameters for the model matching the given arguments.

Expand Down Expand Up @@ -71,6 +72,8 @@ def retrieve_default(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: The hyperparameters to use for the model.

Expand All @@ -93,6 +96,7 @@ def retrieve_default(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)


Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from sagemaker import utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
from sagemaker.jumpstart import artifacts
Expand Down Expand Up @@ -72,6 +73,7 @@ def retrieve(
serverless_inference_config=None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name=None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -128,6 +130,8 @@ def retrieve(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
Expand Down Expand Up @@ -169,6 +173,7 @@ def retrieve(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
Expand Down
11 changes: 10 additions & 1 deletion src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
Expand All @@ -41,6 +42,7 @@ def _retrieve_default_environment_variables(
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
"""Retrieves the inference environment variables for the model matching the given arguments.

Expand Down Expand Up @@ -73,6 +75,8 @@ def _retrieve_default_environment_variables(
script (JumpStartScriptScope): The JumpStart script for which to retrieve
environment variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: the inference environment variables to use for the model.
"""
Expand All @@ -91,6 +95,7 @@ def _retrieve_default_environment_variables(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

default_environment_variables: Dict[str, str] = {}
Expand Down Expand Up @@ -130,6 +135,7 @@ def _retrieve_default_environment_variables(
sagemaker_session=sagemaker_session,
instance_type=instance_type,
config_name=config_name,
model_type=model_type,
)
)

Expand Down Expand Up @@ -178,6 +184,7 @@ def _retrieve_gated_model_uri_env_var_value(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
"""Retrieves the gated model env var URI matching the given arguments.

Expand All @@ -204,7 +211,8 @@ def _retrieve_gated_model_uri_env_var_value(
instance_type (str): An instance type to optionally supply in order to get
environment variables specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
have gated training artifacts.
Expand All @@ -227,6 +235,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

s3_key: Optional[str] = (
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
VariableScope,
)
Expand All @@ -38,6 +39,7 @@ def _retrieve_default_hyperparameters(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
):
"""Retrieves the training hyperparameters for the model matching the given arguments.

Expand Down Expand Up @@ -71,6 +73,8 @@ def _retrieve_default_hyperparameters(
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: the hyperparameters to use for the model.
"""
Expand All @@ -89,6 +93,7 @@ def _retrieve_default_hyperparameters(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

default_hyperparameters: Dict[str, str] = {}
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
ModelFramework,
)
Expand Down Expand Up @@ -48,6 +49,7 @@ def _retrieve_image_uri(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
):
"""Retrieves the container image URI for JumpStart models.

Expand Down Expand Up @@ -100,6 +102,8 @@ def _retrieve_image_uri(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Expand All @@ -123,6 +127,7 @@ def _retrieve_image_uri(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

if image_scope == JumpStartScriptScope.INFERENCE:
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
Expand All @@ -35,6 +36,7 @@ def _model_supports_incremental_training(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> bool:
"""Returns True if the model supports incremental training.

Expand All @@ -59,6 +61,8 @@ def _model_supports_incremental_training(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
bool: the support status for incremental training.
"""
Expand All @@ -77,6 +81,7 @@ def _model_supports_incremental_training(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

return model_specs.supports_incremental_training()
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/artifacts/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def _retrieve_estimator_init_kwargs(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
"""Retrieves kwargs for `Estimator`.

Expand All @@ -193,6 +194,8 @@ def _retrieve_estimator_init_kwargs(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
dict: the kwargs to use for the use case.
"""
Expand All @@ -211,6 +214,7 @@ def _retrieve_estimator_init_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

kwargs = deepcopy(model_specs.estimator_kwargs)
Expand All @@ -233,6 +237,7 @@ def _retrieve_estimator_fit_kwargs(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
"""Retrieves kwargs for `Estimator.fit`.

Expand All @@ -257,6 +262,8 @@ def _retrieve_estimator_fit_kwargs(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

Returns:
dict: the kwargs to use for the use case.
Expand All @@ -276,6 +283,7 @@ def _retrieve_estimator_fit_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

return model_specs.fit_kwargs
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import (
JumpStartModelType,
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
Expand All @@ -37,6 +38,7 @@ def _retrieve_default_training_metric_definitions(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[List[Dict[str, str]]]:
"""Retrieves the default training metric definitions for the model.

Expand All @@ -63,6 +65,8 @@ def _retrieve_default_training_metric_definitions(
instance_type (str): An instance type to optionally supply in order to get
metric definitions specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
list: the default training metric definitions to use for the model or None.
"""
Expand All @@ -81,6 +85,7 @@ def _retrieve_default_training_metric_definitions(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

default_metric_definitions = (
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
"""Retrieves s3 artifact uri associated with model package.

Expand All @@ -156,6 +157,8 @@ def _retrieve_model_package_model_artifact_s3_uri(
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
model_type (JumpStartModelType): The type of the model, can be open weights model
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
str: the model package artifact uri to use for the model or None.

Expand All @@ -179,6 +182,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
model_type=model_type,
)

if model_specs.training_model_package_artifact_uris is None:
Expand Down
Loading