Skip to content

feature: JumpStart Integration #2870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
b09793a
feature: client cache for jumpstart models
evakravi Dec 7, 2021
1da6b98
feature: allow conditional parellel builds (#2727)
mufaddal-rohawala Nov 4, 2021
1489363
fix endpoint bug (#2772)
BasilBeirouti Dec 6, 2021
ba55962
fix: local mode - support relative file structure (#2768)
mufaddal-rohawala Dec 7, 2021
221411c
prepare release v2.72.0
Dec 13, 2021
395bc24
update development version to v2.72.1.dev0
Dec 13, 2021
b689689
fix: Set ProcessingStep upload locations deterministically to avoid c…
staubhp Dec 8, 2021
5cd25f3
fix: Prevent repack_model script from referencing nonexistent directo…
staubhp Dec 9, 2021
2e8710d
fix: S3Input - add support for instance attributes (#2754)
mufaddal-rohawala Dec 15, 2021
05478a0
fix: typos and broken link (#2765)
mohamed-ali Dec 16, 2021
7da8254
prepare release v2.72.1
Dec 20, 2021
51404ba
update development version to v2.72.2.dev0
Dec 20, 2021
5e83d5c
fix: Model Registration with BYO scripts (#2797)
sreedes Dec 17, 2021
7def439
fix: Add ContentType in test_auto_ml_describe
navinns Dec 27, 2021
cd870ec
fix: Re-deploy static integ test endpoint if it is not found
Dec 27, 2021
116ce8a
documentation :SageMaker model parallel library 1.6.0 API doc (#2814)
mchoi8739 Dec 30, 2021
9ebfaf1
fix: fix kmeans test deletion sequence, increment lineage statics (#2…
mufaddal-rohawala Dec 31, 2021
071228a
fix: Increment static lineage pipeline (#2817)
mufaddal-rohawala Jan 3, 2022
d36d9aa
fix: Update CHANGELOG.md (#2832)
ahsan-z-khan Jan 6, 2022
02e0b15
prepare release v2.72.2
Jan 6, 2022
ee247c0
update development version to v2.72.3.dev0
Jan 6, 2022
20740cb
change: update master from dev (#2836)
ahsan-z-khan Jan 10, 2022
3d491ac
prepare release v2.72.3
Jan 10, 2022
068ad64
update development version to v2.72.4.dev0
Jan 10, 2022
3859a94
fix: fixes unnecessary session call while generating pipeline definit…
xchen909 Jan 10, 2022
77bf04f
feature: Add models_v2 under lineage context (#2800)
yzhu0 Jan 10, 2022
5da5364
feature: enable python 3.9 (#2802)
mufaddal-rohawala Jan 10, 2022
b105b19
change: Update CHANGELOG.md (#2842)
shreyapandit Jan 11, 2022
ac57772
fix: update pricing link (#2805)
ahsan-z-khan Jan 11, 2022
b691d3d
feature: Adding Jumpstart retrieval functions (#2789)
evakravi Jan 12, 2022
00f23e6
feature: jumpstart hyperparameters and environment variables (#2850)
evakravi Jan 14, 2022
167b723
feature: script mode for model class (#2841)
evakravi Jan 19, 2022
d9d8c68
feat: Script mode support for Estimator class (#2834)
evakravi Jan 21, 2022
c03efb2
feature: jumpstart vulnerability and deprecated check (#2855)
evakravi Jan 23, 2022
63b0372
Feat: tagging jumpstart models (#2860)
evakravi Jan 24, 2022
c9aa29b
feat: Integ tests for jumpstart model and estimator (#2865)
evakravi Jan 25, 2022
423f389
feat: Hyperparameter validation (#2856)
evakravi Jan 25, 2022
b8d0bad
Syncing master-jumpstart with dev (#2887)
evakravi Feb 3, 2022
d25ed94
Merge remote-tracking branch 'origin/master-jumpstart' into dev
evakravi Feb 3, 2022
3582132
fix: git diff
evakravi Feb 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ venv/
.docker/
env/
.vscode/
**/tmp
.python-version
10 changes: 5 additions & 5 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
self.validate_train_spec()
self.hyperparameter_definitions = self._parse_hyperparameters()

self.hyperparam_dict = {}
self._hyperparameters = {}
if hyperparameters:
self.set_hyperparameters(**hyperparameters)

Expand Down Expand Up @@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs):
"""Placeholder docstring"""
for k, v in kwargs.items():
value = self._validate_and_cast_hyperparameter(k, v)
self.hyperparam_dict[k] = value
self._hyperparameters[k] = value

self._validate_and_set_default_hyperparameters()

Expand All @@ -225,7 +225,7 @@ def hyperparameters(self):
The fit() method, that does the model training, calls this method to
find the hyperparameters you specified.
"""
return self.hyperparam_dict
return self._hyperparameters

def training_image_uri(self):
"""Returns the docker image to use for training.
Expand Down Expand Up @@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self):
# Check if all the required hyperparameters are set. If there is a default value
# for one, set it.
for name, definition in self.hyperparameter_definitions.items():
if name not in self.hyperparam_dict:
if name not in self._hyperparameters:
spec = definition["spec"]
if "DefaultValue" in spec:
self.hyperparam_dict[name] = spec["DefaultValue"]
self._hyperparameters[name] = spec["DefaultValue"]
elif "IsRequired" in spec and spec["IsRequired"]:
raise ValueError("Required hyperparameter: %s is not set" % name)

Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging

from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
Expand Down Expand Up @@ -158,7 +158,9 @@ def hyperparameters(self):

# remove unset keys.
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
return hyperparameters

def create_model(
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
51 changes: 51 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve environment variables for hosting containers."""

from __future__ import absolute_import

import logging
from typing import Dict

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts

logger = logging.getLogger(__name__)


def retrieve_default(
region=None,
model_id=None,
model_version=None,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.

Args:
region (str): Optional. Region for which to retrieve default environment variables.
(Default: None).
model_id (str): Optional. Model ID of the model for which to
retrieve the default environment variables. (Default: None).
model_version (str): Optional. Version of the model for which to retrieve the
default environment variables. (Default: None).
Returns:
dict: the variables to use for the model.

Raises:
ValueError: If the combination of arguments specified is not supported.
"""
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError(
"Must specify `model_id` and `model_version` when retrieving environment variables."
)

return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
Loading