Skip to content

Commit 647acba

Browse files
authored
chore: cleanup jumpstart factory (#4840)
* chore: cleanup jumpstart factory * fix: typing * chore: address pr comments, fix formatting * fix: failing config tests
1 parent de5de9b commit 647acba

20 files changed

+322
-422
lines changed

src/sagemaker/environment_variables.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23-
from sagemaker.jumpstart.enums import JumpStartScriptScope
23+
from sagemaker.jumpstart.enums import JumpStartModelType, JumpStartScriptScope
2424
from sagemaker.session import Session
2525

2626
logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ def retrieve_default(
3838
instance_type: Optional[str] = None,
3939
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
4040
config_name: Optional[str] = None,
41+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4142
) -> Dict[str, str]:
4243
"""Retrieves the default container environment variables for the model matching the arguments.
4344
@@ -70,6 +71,8 @@ def retrieve_default(
7071
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
7172
variables.
7273
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
74+
model_type (JumpStartModelType): The type of the model, can be open weights model
75+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7376
Returns:
7477
dict: The variables to use for the model.
7578
@@ -94,4 +97,5 @@ def retrieve_default(
9497
instance_type=instance_type,
9598
script=script,
9699
config_name=config_name,
100+
model_type=model_type,
97101
)

src/sagemaker/hyperparameters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23-
from sagemaker.jumpstart.enums import HyperparameterValidationMode
23+
from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType
2424
from sagemaker.jumpstart.validators import validate_hyperparameters
2525
from sagemaker.session import Session
2626

@@ -38,6 +38,7 @@ def retrieve_default(
3838
tolerate_deprecated_model: bool = False,
3939
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4040
config_name: Optional[str] = None,
41+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4142
) -> Dict[str, str]:
4243
"""Retrieves the default training hyperparameters for the model matching the given arguments.
4344
@@ -71,6 +72,8 @@ def retrieve_default(
7172
specified, one is created using the default AWS configuration
7273
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
7374
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
75+
model_type (JumpStartModelType): The type of the model, can be open weights model
76+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7477
Returns:
7578
dict: The hyperparameters to use for the model.
7679
@@ -93,6 +96,7 @@ def retrieve_default(
9396
tolerate_deprecated_model=tolerate_deprecated_model,
9497
sagemaker_session=sagemaker_session,
9598
config_name=config_name,
99+
model_type=model_type,
96100
)
97101

98102

src/sagemaker/image_uris.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from sagemaker import utils
2424
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
25+
from sagemaker.jumpstart.enums import JumpStartModelType
2526
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2627
from sagemaker.spark import defaults
2728
from sagemaker.jumpstart import artifacts
@@ -72,6 +73,7 @@ def retrieve(
7273
serverless_inference_config=None,
7374
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
7475
config_name=None,
76+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
7577
) -> str:
7678
"""Retrieves the ECR URI for the Docker image matching the given arguments.
7779
@@ -128,6 +130,8 @@ def retrieve(
128130
specified, one is created using the default AWS configuration
129131
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
130132
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
133+
model_type (JumpStartModelType): The type of the model, can be open weights model
134+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
131135
132136
Returns:
133137
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -169,6 +173,7 @@ def retrieve(
169173
tolerate_deprecated_model,
170174
sagemaker_session=sagemaker_session,
171175
config_name=config_name,
176+
model_type=model_type,
172177
)
173178

174179
if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
2020
)
2121
from sagemaker.jumpstart.enums import (
22+
JumpStartModelType,
2223
JumpStartScriptScope,
2324
)
2425
from sagemaker.jumpstart.utils import (
@@ -41,6 +42,7 @@ def _retrieve_default_environment_variables(
4142
instance_type: Optional[str] = None,
4243
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
4344
config_name: Optional[str] = None,
45+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4446
) -> Dict[str, str]:
4547
"""Retrieves the inference environment variables for the model matching the given arguments.
4648
@@ -73,6 +75,8 @@ def _retrieve_default_environment_variables(
7375
script (JumpStartScriptScope): The JumpStart script for which to retrieve
7476
environment variables.
7577
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
78+
model_type (JumpStartModelType): The type of the model, can be open weights model
79+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7680
Returns:
7781
dict: the inference environment variables to use for the model.
7882
"""
@@ -91,6 +95,7 @@ def _retrieve_default_environment_variables(
9195
tolerate_deprecated_model=tolerate_deprecated_model,
9296
sagemaker_session=sagemaker_session,
9397
config_name=config_name,
98+
model_type=model_type,
9499
)
95100

96101
default_environment_variables: Dict[str, str] = {}
@@ -130,6 +135,7 @@ def _retrieve_default_environment_variables(
130135
sagemaker_session=sagemaker_session,
131136
instance_type=instance_type,
132137
config_name=config_name,
138+
model_type=model_type,
133139
)
134140
)
135141

@@ -178,6 +184,7 @@ def _retrieve_gated_model_uri_env_var_value(
178184
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
179185
instance_type: Optional[str] = None,
180186
config_name: Optional[str] = None,
187+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
181188
) -> Optional[str]:
182189
"""Retrieves the gated model env var URI matching the given arguments.
183190
@@ -204,7 +211,8 @@ def _retrieve_gated_model_uri_env_var_value(
204211
instance_type (str): An instance type to optionally supply in order to get
205212
environment variables specific for the instance type.
206213
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
207-
214+
model_type (JumpStartModelType): The type of the model, can be open weights model
215+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
208216
Returns:
209217
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
210218
have gated training artifacts.
@@ -227,6 +235,7 @@ def _retrieve_gated_model_uri_env_var_value(
227235
tolerate_deprecated_model=tolerate_deprecated_model,
228236
sagemaker_session=sagemaker_session,
229237
config_name=config_name,
238+
model_type=model_type,
230239
)
231240

232241
s3_key: Optional[str] = (

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1818
)
1919
from sagemaker.jumpstart.enums import (
20+
JumpStartModelType,
2021
JumpStartScriptScope,
2122
VariableScope,
2223
)
@@ -38,6 +39,7 @@ def _retrieve_default_hyperparameters(
3839
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3940
instance_type: Optional[str] = None,
4041
config_name: Optional[str] = None,
42+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4143
):
4244
"""Retrieves the training hyperparameters for the model matching the given arguments.
4345
@@ -71,6 +73,8 @@ def _retrieve_default_hyperparameters(
7173
instance_type (str): An instance type to optionally supply in order to get hyperparameters
7274
specific for the instance type.
7375
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
76+
model_type (JumpStartModelType): The type of the model, can be open weights model
77+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7478
Returns:
7579
dict: the hyperparameters to use for the model.
7680
"""
@@ -89,6 +93,7 @@ def _retrieve_default_hyperparameters(
8993
tolerate_deprecated_model=tolerate_deprecated_model,
9094
sagemaker_session=sagemaker_session,
9195
config_name=config_name,
96+
model_type=model_type,
9297
)
9398

9499
default_hyperparameters: Dict[str, str] = {}

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2020
)
2121
from sagemaker.jumpstart.enums import (
22+
JumpStartModelType,
2223
JumpStartScriptScope,
2324
ModelFramework,
2425
)
@@ -48,6 +49,7 @@ def _retrieve_image_uri(
4849
tolerate_deprecated_model: bool = False,
4950
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
5051
config_name: Optional[str] = None,
52+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
5153
):
5254
"""Retrieves the container image URI for JumpStart models.
5355
@@ -100,6 +102,8 @@ def _retrieve_image_uri(
100102
specified, one is created using the default AWS configuration
101103
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
102104
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
105+
model_type (JumpStartModelType): The type of the model, can be open weights model
106+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
103107
Returns:
104108
str: the ECR URI for the corresponding SageMaker Docker image.
105109
@@ -123,6 +127,7 @@ def _retrieve_image_uri(
123127
tolerate_deprecated_model=tolerate_deprecated_model,
124128
sagemaker_session=sagemaker_session,
125129
config_name=config_name,
130+
model_type=model_type,
126131
)
127132

128133
if image_scope == JumpStartScriptScope.INFERENCE:

src/sagemaker/jumpstart/artifacts/incremental_training.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1818
)
1919
from sagemaker.jumpstart.enums import (
20+
JumpStartModelType,
2021
JumpStartScriptScope,
2122
)
2223
from sagemaker.jumpstart.utils import (
@@ -35,6 +36,7 @@ def _model_supports_incremental_training(
3536
tolerate_deprecated_model: bool = False,
3637
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3738
config_name: Optional[str] = None,
39+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3840
) -> bool:
3941
"""Returns True if the model supports incremental training.
4042
@@ -59,6 +61,8 @@ def _model_supports_incremental_training(
5961
specified, one is created using the default AWS configuration
6062
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
6163
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
64+
model_type (JumpStartModelType): The type of the model, can be open weights model
65+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
6266
Returns:
6367
bool: the support status for incremental training.
6468
"""
@@ -77,6 +81,7 @@ def _model_supports_incremental_training(
7781
tolerate_deprecated_model=tolerate_deprecated_model,
7882
sagemaker_session=sagemaker_session,
7983
config_name=config_name,
84+
model_type=model_type,
8085
)
8186

8287
return model_specs.supports_incremental_training()

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _retrieve_estimator_init_kwargs(
167167
tolerate_deprecated_model: bool = False,
168168
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
169169
config_name: Optional[str] = None,
170+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
170171
) -> dict:
171172
"""Retrieves kwargs for `Estimator`.
172173
@@ -193,6 +194,8 @@ def _retrieve_estimator_init_kwargs(
193194
specified, one is created using the default AWS configuration
194195
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
195196
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
197+
model_type (JumpStartModelType): The type of the model, can be open weights model
198+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
196199
Returns:
197200
dict: the kwargs to use for the use case.
198201
"""
@@ -211,6 +214,7 @@ def _retrieve_estimator_init_kwargs(
211214
tolerate_deprecated_model=tolerate_deprecated_model,
212215
sagemaker_session=sagemaker_session,
213216
config_name=config_name,
217+
model_type=model_type,
214218
)
215219

216220
kwargs = deepcopy(model_specs.estimator_kwargs)
@@ -233,6 +237,7 @@ def _retrieve_estimator_fit_kwargs(
233237
tolerate_deprecated_model: bool = False,
234238
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
235239
config_name: Optional[str] = None,
240+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
236241
) -> dict:
237242
"""Retrieves kwargs for `Estimator.fit`.
238243
@@ -257,6 +262,8 @@ def _retrieve_estimator_fit_kwargs(
257262
specified, one is created using the default AWS configuration
258263
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
259264
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
265+
model_type (JumpStartModelType): The type of the model, can be open weights model
266+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
260267
261268
Returns:
262269
dict: the kwargs to use for the use case.
@@ -276,6 +283,7 @@ def _retrieve_estimator_fit_kwargs(
276283
tolerate_deprecated_model=tolerate_deprecated_model,
277284
sagemaker_session=sagemaker_session,
278285
config_name=config_name,
286+
model_type=model_type,
279287
)
280288

281289
return model_specs.fit_kwargs

src/sagemaker/jumpstart/artifacts/metric_definitions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1919
)
2020
from sagemaker.jumpstart.enums import (
21+
JumpStartModelType,
2122
JumpStartScriptScope,
2223
)
2324
from sagemaker.jumpstart.utils import (
@@ -37,6 +38,7 @@ def _retrieve_default_training_metric_definitions(
3738
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3839
instance_type: Optional[str] = None,
3940
config_name: Optional[str] = None,
41+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4042
) -> Optional[List[Dict[str, str]]]:
4143
"""Retrieves the default training metric definitions for the model.
4244
@@ -63,6 +65,8 @@ def _retrieve_default_training_metric_definitions(
6365
instance_type (str): An instance type to optionally supply in order to get
6466
metric definitions specific for the instance type.
6567
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
68+
model_type (JumpStartModelType): The type of the model, can be open weights model
69+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
6670
Returns:
6771
list: the default training metric definitions to use for the model or None.
6872
"""
@@ -81,6 +85,7 @@ def _retrieve_default_training_metric_definitions(
8185
tolerate_deprecated_model=tolerate_deprecated_model,
8286
sagemaker_session=sagemaker_session,
8387
config_name=config_name,
88+
model_type=model_type,
8489
)
8590

8691
default_metric_definitions = (

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
130130
tolerate_deprecated_model: bool = False,
131131
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
132132
config_name: Optional[str] = None,
133+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
133134
) -> Optional[str]:
134135
"""Retrieves s3 artifact uri associated with model package.
135136
@@ -156,6 +157,8 @@ def _retrieve_model_package_model_artifact_s3_uri(
156157
specified, one is created using the default AWS configuration
157158
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
158159
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
160+
model_type (JumpStartModelType): The type of the model, can be open weights model
161+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
159162
Returns:
160163
str: the model package artifact uri to use for the model or None.
161164
@@ -179,6 +182,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
179182
tolerate_deprecated_model=tolerate_deprecated_model,
180183
sagemaker_session=sagemaker_session,
181184
config_name=config_name,
185+
model_type=model_type,
182186
)
183187

184188
if model_specs.training_model_package_artifact_uris is None:

0 commit comments

Comments
 (0)