Skip to content

Commit 7203862

Browse files
evakravimufaddal-rohawalaBasil BeiroutiPayton Staubahsan-z-khan
authored
feature: JumpStart Integration (#2870)
Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Basil Beirouti <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: Shreya Pandit <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Basil Beirouti <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: Mohamed Ali Jamaoui <[email protected]> Co-authored-by: ci <ci> Co-authored-by: Jeniya Tabassum <[email protected]> Co-authored-by: sreedes <[email protected]> Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Miyoung <[email protected]> Co-authored-by: Ameen Khan <[email protected]> Co-authored-by: Zhankui Lu <[email protected]> Co-authored-by: Xiaoguang Chen <[email protected]> Co-authored-by: Jonathan Guinegagne <[email protected]> Co-authored-by: Zhankui Lu <[email protected]> Co-authored-by: Yifei Zhu <[email protected]> Co-authored-by: Qingzi-Lan <[email protected]> Co-authored-by: Navin Soni <[email protected]> Co-authored-by: marckarp <[email protected]> Co-authored-by: chenxy <[email protected]> Co-authored-by: Xinghan Chen <[email protected]> Co-authored-by: Tulio Casagrande <[email protected]> Co-authored-by: jerrypeng7773 <[email protected]> Co-authored-by: marckarp <[email protected]> Co-authored-by: jayatalr <[email protected]> Co-authored-by: bhaoz <[email protected]> Co-authored-by: Ethan Cheng <[email protected]> Co-authored-by: Xiaoguang Chen <[email protected]> Co-authored-by: keerthanvasist <[email protected]> Co-authored-by: Shreya Pandit <[email protected]>
1 parent 4886405 commit 7203862

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+10531
-292
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ venv/
2828
.docker/
2929
env/
3030
.vscode/
31+
**/tmp
3132
.python-version

src/sagemaker/algorithm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __init__(
174174
self.validate_train_spec()
175175
self.hyperparameter_definitions = self._parse_hyperparameters()
176176

177-
self.hyperparam_dict = {}
177+
self._hyperparameters = {}
178178
if hyperparameters:
179179
self.set_hyperparameters(**hyperparameters)
180180

@@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs):
215215
"""Placeholder docstring"""
216216
for k, v in kwargs.items():
217217
value = self._validate_and_cast_hyperparameter(k, v)
218-
self.hyperparam_dict[k] = value
218+
self._hyperparameters[k] = value
219219

220220
self._validate_and_set_default_hyperparameters()
221221

@@ -225,7 +225,7 @@ def hyperparameters(self):
225225
The fit() method, that does the model training, calls this method to
226226
find the hyperparameters you specified.
227227
"""
228-
return self.hyperparam_dict
228+
return self._hyperparameters
229229

230230
def training_image_uri(self):
231231
"""Returns the docker image to use for training.
@@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self):
464464
# Check if all the required hyperparameters are set. If there is a default value
465465
# for one, set it.
466466
for name, definition in self.hyperparameter_definitions.items():
467-
if name not in self.hyperparam_dict:
467+
if name not in self._hyperparameters:
468468
spec = definition["spec"]
469469
if "DefaultValue" in spec:
470-
self.hyperparam_dict[name] = spec["DefaultValue"]
470+
self._hyperparameters[name] = spec["DefaultValue"]
471471
elif "IsRequired" in spec and spec["IsRequired"]:
472472
raise ValueError("Required hyperparameter: %s is not set" % name)
473473

src/sagemaker/chainer/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import logging
1717

18-
from sagemaker.estimator import Framework
18+
from sagemaker.estimator import Framework, EstimatorBase
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
@@ -158,7 +158,9 @@ def hyperparameters(self):
158158

159159
# remove unset keys.
160160
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
161-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
161+
hyperparameters.update(
162+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
163+
)
162164
return hyperparameters
163165

164166
def create_model(

src/sagemaker/chainer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
168168
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
169169
self._upload_code(deploy_key_prefix)
170170
deploy_env = dict(self.env)
171-
deploy_env.update(self._framework_env_vars())
171+
deploy_env.update(self._script_mode_env_vars())
172172

173173
if self.model_server_workers:
174174
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve environment variables for hosting containers."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import Dict
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
27+
region=None,
28+
model_id=None,
29+
model_version=None,
30+
) -> Dict[str, str]:
31+
"""Retrieves the default container environment variables for the model matching the arguments.
32+
33+
Args:
34+
region (str): Optional. Region for which to retrieve default environment variables.
35+
(Default: None).
36+
model_id (str): Optional. Model ID of the model for which to
37+
retrieve the default environment variables. (Default: None).
38+
model_version (str): Optional. Version of the model for which to retrieve the
39+
default environment variables. (Default: None).
40+
Returns:
41+
dict: the variables to use for the model.
42+
43+
Raises:
44+
ValueError: If the combination of arguments specified is not supported.
45+
"""
46+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
47+
raise ValueError(
48+
"Must specify `model_id` and `model_version` when retrieving environment variables."
49+
)
50+
51+
return artifacts._retrieve_default_environment_variables(model_id, model_version, region)

0 commit comments

Comments
 (0)