Skip to content

Commit 2418598

Browse files
committed
feat: jsch jumpstart estimator support (aws#4439)
1 parent 700c16d commit 2418598

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1167
-98
lines changed

src/sagemaker/environment_variables.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve_default(
3030
region: Optional[str] = None,
3131
model_id: Optional[str] = None,
3232
model_version: Optional[str] = None,
33+
hub_arn: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
include_aws_sdk_env_vars: bool = True,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default environment variables. (Default: None).
4748
model_version (str): Optional. The version of the model for which to retrieve the
4849
default environment variables. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5053
specifications should be tolerated (exception not raised). If False, raises an
5154
exception if the script used by this version of the model has dependencies with known
@@ -80,6 +83,7 @@ def retrieve_default(
8083
return artifacts._retrieve_default_environment_variables(
8184
model_id,
8285
model_version,
86+
hub_arn,
8387
region,
8488
tolerate_vulnerable_model,
8589
tolerate_deprecated_model,

src/sagemaker/hyperparameters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def retrieve_default(
3131
region: Optional[str] = None,
3232
model_id: Optional[str] = None,
3333
model_version: Optional[str] = None,
34+
hub_arn: Optional[str] = None,
3435
instance_type: Optional[str] = None,
3536
include_container_hyperparameters: bool = False,
3637
tolerate_vulnerable_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default hyperparameters. (Default: None).
4748
model_version (str): The version of the model for which to retrieve the
4849
default hyperparameters. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
instance_type (str): An instance type to optionally supply in order to get hyperparameters
5053
specific for the instance type.
5154
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
8083
return artifacts._retrieve_default_hyperparameters(
8184
model_id=model_id,
8285
model_version=model_version,
86+
hub_arn=hub_arn,
8387
instance_type=instance_type,
8488
region=region,
8589
include_container_hyperparameters=include_container_hyperparameters,
@@ -93,6 +97,7 @@ def validate(
9397
region: Optional[str] = None,
9498
model_id: Optional[str] = None,
9599
model_version: Optional[str] = None,
100+
hub_arn: Optional[str] = None,
96101
hyperparameters: Optional[dict] = None,
97102
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
98103
tolerate_vulnerable_model: bool = False,
@@ -107,6 +112,8 @@ def validate(
107112
(Default: None).
108113
model_version (str): The version of the model for which to validate hyperparameters.
109114
(Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from. (default: None).
110117
hyperparameters (dict): Hyperparameters to validate.
111118
(Default: None).
112119
validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
148155
return validate_hyperparameters(
149156
model_id=model_id,
150157
model_version=model_version,
158+
hub_arn=hub_arn,
151159
hyperparameters=hyperparameters,
152160
validation_mode=validation_mode,
153161
region=region,

src/sagemaker/image_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def retrieve(
6262
training_compiler_config=None,
6363
model_id=None,
6464
model_version=None,
65+
hub_arn=None,
6566
tolerate_vulnerable_model=False,
6667
tolerate_deprecated_model=False,
6768
sdk_version=None,
@@ -102,6 +103,8 @@ def retrieve(
102103
(default: None).
103104
model_version (str): The version of the JumpStart model for which to retrieve the
104105
image URI (default: None).
106+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
107+
model details from. (default: None).
105108
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
106109
should be tolerated without an exception raised. If ``False``, raises an exception if
107110
the script used by this version of the model has dependencies with known security
@@ -147,6 +150,7 @@ def retrieve(
147150
model_id,
148151
model_version,
149152
image_scope,
153+
hub_arn,
150154
framework,
151155
region,
152156
version,

src/sagemaker/instance_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve_default(
3030
region: Optional[str] = None,
3131
model_id: Optional[str] = None,
3232
model_version: Optional[str] = None,
33+
hub_arn: Optional[str] = None,
3334
scope: Optional[str] = None,
3435
tolerate_vulnerable_model: bool = False,
3536
tolerate_deprecated_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default instance type. (Default: None).
4748
model_version (str): The version of the model for which to retrieve the
4849
default instance type. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
scope (str): The model type, i.e. what it is used for.
5053
Valid values: "training" and "inference".
5154
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -82,6 +85,7 @@ def retrieve_default(
8285
model_id,
8386
model_version,
8487
scope,
88+
hub_arn,
8589
region,
8690
tolerate_vulnerable_model,
8791
tolerate_deprecated_model,
@@ -95,6 +99,7 @@ def retrieve(
9599
region: Optional[str] = None,
96100
model_id: Optional[str] = None,
97101
model_version: Optional[str] = None,
102+
hub_arn: Optional[str] = None,
98103
scope: Optional[str] = None,
99104
tolerate_vulnerable_model: bool = False,
100105
tolerate_deprecated_model: bool = False,
@@ -110,6 +115,8 @@ def retrieve(
110115
retrieve the supported instance types. (Default: None).
111116
model_version (str): The version of the model for which to retrieve the
112117
supported instance types. (Default: None).
118+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
119+
model details from. (default: None).
113120
tolerate_vulnerable_model (bool): True if vulnerable versions of model
114121
specifications should be tolerated (exception not raised). If False, raises an
115122
exception if the script used by this version of the model has dependencies with known
@@ -145,6 +152,7 @@ def retrieve(
145152
model_id,
146153
model_version,
147154
scope,
155+
hub_arn,
148156
region,
149157
tolerate_vulnerable_model,
150158
tolerate_deprecated_model,

src/sagemaker/jumpstart/accessors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
2121
from sagemaker.jumpstart.enums import JumpStartModelType
2222
from sagemaker.jumpstart import cache
23+
from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs
2324
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2425

2526

@@ -255,6 +256,7 @@ def get_model_specs(
255256
version: str,
256257
s3_client: Optional[boto3.client] = None,
257258
model_type=JumpStartModelType.OPEN_WEIGHTS,
259+
hub_arn: Optional[str] = None,
258260
) -> JumpStartModelSpecs:
259261
"""Returns model specs from JumpStart models cache.
260262
@@ -274,6 +276,13 @@ def get_model_specs(
274276
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
275277
)
276278
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
279+
280+
if hub_arn:
281+
hub_model_arn = construct_hub_model_arn_from_inputs(
282+
hub_arn=hub_arn, model_name=model_id, version=version
283+
)
284+
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)
285+
277286
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
278287
model_id=model_id, version_str=version, model_type=model_type
279288
)

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
def _retrieve_default_environment_variables(
3333
model_id: str,
3434
model_version: str,
35+
hub_arn: Optional[str] = None,
3536
region: Optional[str] = None,
3637
tolerate_vulnerable_model: bool = False,
3738
tolerate_deprecated_model: bool = False,
@@ -47,6 +48,8 @@ def _retrieve_default_environment_variables(
4748
retrieve the default environment variables.
4849
model_version (str): Version of the JumpStart model for which to retrieve the
4950
default environment variables.
51+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
52+
model details from. (default: None).
5053
region (Optional[str]): Region for which to retrieve default environment variables.
5154
(Default: None).
5255
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -78,6 +81,7 @@ def _retrieve_default_environment_variables(
7881
model_specs = verify_model_region_and_return_specs(
7982
model_id=model_id,
8083
version=model_version,
84+
hub_arn=hub_arn,
8185
scope=script,
8286
region=region,
8387
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -116,6 +120,7 @@ def _retrieve_default_environment_variables(
116120
] = lambda instance_type: _retrieve_gated_model_uri_env_var_value(
117121
model_id=model_id,
118122
model_version=model_version,
123+
hub_arn=hub_arn,
119124
region=region,
120125
tolerate_vulnerable_model=tolerate_vulnerable_model,
121126
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -161,6 +166,7 @@ def _retrieve_default_environment_variables(
161166
def _retrieve_gated_model_uri_env_var_value(
162167
model_id: str,
163168
model_version: str,
169+
hub_arn: Optional[str] = None,
164170
region: Optional[str] = None,
165171
tolerate_vulnerable_model: bool = False,
166172
tolerate_deprecated_model: bool = False,
@@ -174,6 +180,8 @@ def _retrieve_gated_model_uri_env_var_value(
174180
retrieve the gated model env var URI.
175181
model_version (str): Version of the JumpStart model for which to retrieve the
176182
gated model env var URI.
183+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
184+
model details from. (default: None).
177185
region (Optional[str]): Region for which to retrieve the gated model env var URI.
178186
(Default: None).
179187
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -204,6 +212,7 @@ def _retrieve_gated_model_uri_env_var_value(
204212
model_specs = verify_model_region_and_return_specs(
205213
model_id=model_id,
206214
version=model_version,
215+
hub_arn=hub_arn,
207216
scope=JumpStartScriptScope.TRAINING,
208217
region=region,
209218
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
def _retrieve_default_hyperparameters(
3131
model_id: str,
3232
model_version: str,
33+
hub_arn: Optional[str] = None,
3334
region: Optional[str] = None,
3435
include_container_hyperparameters: bool = False,
3536
tolerate_vulnerable_model: bool = False,
@@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters(
4445
retrieve the default hyperparameters.
4546
model_version (str): Version of the JumpStart model for which to retrieve the
4647
default hyperparameters.
48+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
49+
model details from. (default: None).
4750
region (str): Region for which to retrieve default hyperparameters.
4851
(Default: None).
4952
include_container_hyperparameters (bool): True if container hyperparameters
@@ -76,6 +79,7 @@ def _retrieve_default_hyperparameters(
7679
model_specs = verify_model_region_and_return_specs(
7780
model_id=model_id,
7881
version=model_version,
82+
hub_arn=hub_arn,
7983
scope=JumpStartScriptScope.TRAINING,
8084
region=region,
8185
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _retrieve_image_uri(
3333
model_id: str,
3434
model_version: str,
3535
image_scope: str,
36+
hub_arn: Optional[str] = None,
3637
framework: Optional[str] = None,
3738
region: Optional[str] = None,
3839
version: Optional[str] = None,
@@ -57,6 +58,8 @@ def _retrieve_image_uri(
5758
model_id (str): JumpStart model ID for which to retrieve image URI.
5859
model_version (str): Version of the JumpStart model for which to retrieve
5960
the image URI.
61+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
62+
model details from. (default: None).
6063
image_scope (str): The image type, i.e. what it is used for.
6164
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
6265
``image_scope`` is ignored.
@@ -110,6 +113,7 @@ def _retrieve_image_uri(
110113
model_specs = verify_model_region_and_return_specs(
111114
model_id=model_id,
112115
version=model_version,
116+
hub_arn=hub_arn,
113117
scope=image_scope,
114118
region=region,
115119
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/incremental_training.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def _model_supports_incremental_training(
3030
model_id: str,
3131
model_version: str,
3232
region: Optional[str],
33+
hub_arn: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -43,6 +44,8 @@ def _model_supports_incremental_training(
4344
support status for incremental training.
4445
region (Optional[str]): Region for which to retrieve the
4546
support status for incremental training.
47+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
48+
model details from. (default: None).
4649
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4750
specifications should be tolerated (exception not raised). If False, raises an
4851
exception if the script used by this version of the model has dependencies with known
@@ -64,6 +67,7 @@ def _model_supports_incremental_training(
6467
model_specs = verify_model_region_and_return_specs(
6568
model_id=model_id,
6669
version=model_version,
70+
hub_arn=hub_arn,
6771
scope=JumpStartScriptScope.TRAINING,
6872
region=region,
6973
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _retrieve_default_instance_type(
3434
model_id: str,
3535
model_version: str,
3636
scope: str,
37+
hub_arn: Optional[str] = None,
3738
region: Optional[str] = None,
3839
tolerate_vulnerable_model: bool = False,
3940
tolerate_deprecated_model: bool = False,
@@ -50,6 +51,8 @@ def _retrieve_default_instance_type(
5051
default instance type.
5152
scope (str): The script type, i.e. what it is used for.
5253
Valid values: "training" and "inference".
54+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
55+
model details from. (default: None).
5356
region (Optional[str]): Region for which to retrieve default instance type.
5457
(Default: None).
5558
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -82,6 +85,7 @@ def _retrieve_default_instance_type(
8285
model_specs = verify_model_region_and_return_specs(
8386
model_id=model_id,
8487
version=model_version,
88+
hub_arn=hub_arn,
8589
scope=scope,
8690
region=region,
8791
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -122,6 +126,7 @@ def _retrieve_instance_types(
122126
model_id: str,
123127
model_version: str,
124128
scope: str,
129+
hub_arn: Optional[str] = None,
125130
region: Optional[str] = None,
126131
tolerate_vulnerable_model: bool = False,
127132
tolerate_deprecated_model: bool = False,
@@ -137,6 +142,8 @@ def _retrieve_instance_types(
137142
supported instance types.
138143
scope (str): The script type, i.e. what it is used for.
139144
Valid values: "training" and "inference".
145+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
146+
model details from. (default: None).
140147
region (Optional[str]): Region for which to retrieve supported instance types.
141148
(Default: None).
142149
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -169,6 +176,7 @@ def _retrieve_instance_types(
169176
model_specs = verify_model_region_and_return_specs(
170177
model_id=model_id,
171178
version=model_version,
179+
hub_arn=hub_arn,
172180
scope=scope,
173181
region=region,
174182
tolerate_vulnerable_model=tolerate_vulnerable_model,

0 commit comments

Comments
 (0)