Skip to content

Commit a245641

Browse files
authored
Merge branch 'dev' into fix-cond-step
2 parents 7153d6d + 8215dd2 commit a245641

File tree

9 files changed

+149
-46
lines changed

9 files changed

+149
-46
lines changed

doc/overview.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
746746
.. code:: python
747747
748748
from sagemaker.model import Model
749+
from sagemaker.predictor import Predictor
749750
from sagemaker.session import Session
750751
751752
# Create the SageMaker model instance
@@ -755,6 +756,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
755756
   source_dir=script_uri,
756757
   entry_point="inference.py",
757758
   role=Session().get_caller_identity_arn(),
759+
   predictor_cls=Predictor,
758760
)
759761
760762
Save the output from deploying the model to a variable named
@@ -766,12 +768,9 @@ Deployment may take about 5 minutes.
766768

767769
.. code:: python
768770
769-
from sagemaker.predictor import Predictor
770-
771771
predictor = model.deploy(
772772
   initial_instance_count=instance_count,
773773
   instance_type=instance_type,
774-
   predictor_cls=Predictor
775774
)
776775
777776
Because ``catboost`` and ``lightgbm`` rely on the PyTorch Deep Learning Containers

src/sagemaker/jumpstart/cache.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module defines the JumpStartModelsCache class."""
1414
from __future__ import absolute_import
1515
import datetime
16+
from difflib import get_close_matches
1617
from typing import List, Optional
1718
import json
1819
import boto3
@@ -204,14 +205,34 @@ def _get_manifest_key_from_model_id_semantic_version(
204205
sm_version_to_use = sm_version_to_use_list[0]
205206

206207
error_msg = (
207-
f"Unable to find model manifest for {model_id} with version {version} "
208-
f"compatible with your SageMaker version ({sm_version}). "
208+
f"Unable to find model manifest for '{model_id}' with version '{version}' "
209+
f"compatible with your SageMaker version ('{sm_version}'). "
209210
f"Consider upgrading your SageMaker library to at least version "
210-
f"{sm_version_to_use} so you can use version "
211-
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
211+
f"'{sm_version_to_use}' so you can use version "
212+
f"'{model_version_to_use_incompatible_with_sagemaker}' of '{model_id}'."
212213
)
213214
raise KeyError(error_msg)
214-
error_msg = f"Unable to find model manifest for {model_id} with version {version}."
215+
216+
error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
217+
error_msg += (
218+
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html"
219+
" for updated list of models. "
220+
)
221+
222+
other_model_id_version = self._select_version(
223+
"*", versions_incompatible_with_sagemaker
224+
) # all versions here are incompatible with sagemaker
225+
if other_model_id_version is not None:
226+
error_msg += (
227+
f"Consider using model ID '{model_id}' with version "
228+
f"'{other_model_id_version}'."
229+
)
230+
231+
else:
232+
possible_model_ids = [header.model_id for header in manifest.values()]
233+
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
234+
error_msg += f"Did you mean to use model ID '{closest_model_id}'?"
235+
215236
raise KeyError(error_msg)
216237

217238
def _get_file_from_s3(

src/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,5 @@
122122
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
123123

124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125+
126+
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"

src/sagemaker/jumpstart/types.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,12 @@ def from_json(self, json_obj: Dict[str, str]) -> None:
135135
class JumpStartECRSpecs(JumpStartDataHolderType):
136136
"""Data class for JumpStart ECR specs."""
137137

138-
__slots__ = {
138+
__slots__ = [
139139
"framework",
140140
"framework_version",
141141
"py_version",
142142
"huggingface_transformers_version",
143-
}
143+
]
144144

145145
def __init__(self, spec: Dict[str, Any]):
146146
"""Initializes a JumpStartECRSpecs object from its json representation.
@@ -173,7 +173,7 @@ def to_json(self) -> Dict[str, Any]:
173173
class JumpStartHyperparameter(JumpStartDataHolderType):
174174
"""Data class for JumpStart hyperparameter definition in the training container."""
175175

176-
__slots__ = {
176+
__slots__ = [
177177
"name",
178178
"type",
179179
"options",
@@ -183,7 +183,7 @@ class JumpStartHyperparameter(JumpStartDataHolderType):
183183
"max",
184184
"exclusive_min",
185185
"exclusive_max",
186-
}
186+
]
187187

188188
def __init__(self, spec: Dict[str, Any]):
189189
"""Initializes a JumpStartHyperparameter object from its json representation.
@@ -234,12 +234,12 @@ def to_json(self) -> Dict[str, Any]:
234234
class JumpStartEnvironmentVariable(JumpStartDataHolderType):
235235
"""Data class for JumpStart environment variable definitions in the hosting container."""
236236

237-
__slots__ = {
237+
__slots__ = [
238238
"name",
239239
"type",
240240
"default",
241241
"scope",
242-
}
242+
]
243243

244244
def __init__(self, spec: Dict[str, Any]):
245245
"""Initializes a JumpStartEnvironmentVariable object from its json representation.

src/sagemaker/jumpstart/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
from typing import Dict, List, Optional
1718
from urllib.parse import urlparse
1819
from packaging.version import Version
@@ -60,6 +61,14 @@ def get_jumpstart_content_bucket(region: str) -> str:
6061
Raises:
6162
RuntimeError: If JumpStart is not launched in ``region``.
6263
"""
64+
65+
if (
66+
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
67+
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
68+
):
69+
bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
70+
LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
71+
return bucket_override
6372
try:
6473
return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
6574
except KeyError:

src/sagemaker/jumpstart/validators.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _validate_hyperparameter(
4949

5050
if len(hyperparameter_spec) > 1:
5151
raise JumpStartHyperparametersError(
52-
f"Unable to perform validation -- found multiple hyperparameter "
52+
"Unable to perform validation -- found multiple hyperparameter "
5353
f"'{hyperparameter_name}' in model specs."
5454
)
5555

@@ -76,35 +76,35 @@ def _validate_hyperparameter(
7676
if hyperparameter_value not in hyperparameter_spec.options:
7777
raise JumpStartHyperparametersError(
7878
f"Hyperparameter '{hyperparameter_name}' must have one of the following "
79-
f"values: {', '.join(hyperparameter_spec.options)}"
79+
f"values: {', '.join(hyperparameter_spec.options)}."
8080
)
8181

8282
if hasattr(hyperparameter_spec, "min"):
8383
if len(hyperparameter_value) < hyperparameter_spec.min:
8484
raise JumpStartHyperparametersError(
8585
f"Hyperparameter '{hyperparameter_name}' must have length no less than "
86-
f"{hyperparameter_spec.min}"
86+
f"{hyperparameter_spec.min}."
8787
)
8888

8989
if hasattr(hyperparameter_spec, "exclusive_min"):
9090
if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min:
9191
raise JumpStartHyperparametersError(
9292
f"Hyperparameter '{hyperparameter_name}' must have length greater than "
93-
f"{hyperparameter_spec.exclusive_min}"
93+
f"{hyperparameter_spec.exclusive_min}."
9494
)
9595

9696
if hasattr(hyperparameter_spec, "max"):
9797
if len(hyperparameter_value) > hyperparameter_spec.max:
9898
raise JumpStartHyperparametersError(
9999
f"Hyperparameter '{hyperparameter_name}' must have length no greater than "
100-
f"{hyperparameter_spec.max}"
100+
f"{hyperparameter_spec.max}."
101101
)
102102

103103
if hasattr(hyperparameter_spec, "exclusive_max"):
104104
if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max:
105105
raise JumpStartHyperparametersError(
106106
f"Hyperparameter '{hyperparameter_name}' must have length less than "
107-
f"{hyperparameter_spec.exclusive_max}"
107+
f"{hyperparameter_spec.exclusive_max}."
108108
)
109109

110110
# validate numeric types
@@ -125,35 +125,35 @@ def _validate_hyperparameter(
125125
if not hyperparameter_value_str[start_index:].isdigit():
126126
raise JumpStartHyperparametersError(
127127
f"Hyperparameter '{hyperparameter_name}' must be integer type "
128-
"('{hyperparameter_value}')."
128+
f"('{hyperparameter_value}')."
129129
)
130130

131131
if hasattr(hyperparameter_spec, "min"):
132132
if numeric_hyperparam_value < hyperparameter_spec.min:
133133
raise JumpStartHyperparametersError(
134134
f"Hyperparameter '{hyperparameter_name}' can be no less than "
135-
"{hyperparameter_spec.min}."
135+
f"{hyperparameter_spec.min}."
136136
)
137137

138138
if hasattr(hyperparameter_spec, "max"):
139139
if numeric_hyperparam_value > hyperparameter_spec.max:
140140
raise JumpStartHyperparametersError(
141141
f"Hyperparameter '{hyperparameter_name}' can be no greater than "
142-
"{hyperparameter_spec.max}."
142+
f"{hyperparameter_spec.max}."
143143
)
144144

145145
if hasattr(hyperparameter_spec, "exclusive_min"):
146146
if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min:
147147
raise JumpStartHyperparametersError(
148148
f"Hyperparameter '{hyperparameter_name}' must be greater than "
149-
"{hyperparameter_spec.exclusive_min}."
149+
f"{hyperparameter_spec.exclusive_min}."
150150
)
151151

152152
if hasattr(hyperparameter_spec, "exclusive_max"):
153153
if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max:
154154
raise JumpStartHyperparametersError(
155155
f"Hyperparameter '{hyperparameter_name}' must be less than "
156-
"{hyperparameter_spec.exclusive_max}."
156+
f"{hyperparameter_spec.exclusive_max}."
157157
)
158158

159159

0 commit comments

Comments
 (0)