Skip to content

Commit 3ba55a3

Browse files
authored
Merge branch 'zwei' into master
2 parents 385ef32 + 4293c26 commit 3ba55a3

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

55 files changed

+903
-606
lines changed

doc/v2.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ To view logs after attaching a training job to an estimator, use :func:`sagemake
203203
until the completion of the Hyperparameter Tuning Job or Batch Transform Job, respectively.
204204
To make the function non-blocking, use ``wait=False``.
205205

206+
XGBoost Predictor
207+
-----------------
208+
209+
The default serializer of ``sagemaker.xgboost.model.XGBoostPredictor`` has been changed from ``NumpySerializer`` to ``LibSVMSerializer``.
210+
211+
206212
Parameter and Class Name Changes
207213
================================
208214

@@ -263,6 +269,8 @@ The following serializer/deserializer classes have been renamed and/or moved:
263269
| ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` |
264270
+--------------------------------------------------------+-------------------------------------------------------+
265271

272+
``sagemaker.serializers.LibSVMSerializer`` has been added in v2.0.
273+
266274
``distributions``
267275
~~~~~~~~~~~~~~~~~
268276

src/sagemaker/algorithm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,13 @@ def hyperparameters(self):
229229
"""
230230
return self.hyperparam_dict
231231

232-
def train_image(self):
232+
def training_image_uri(self):
233233
"""Returns the docker image to use for training.
234234
235235
The fit() method, that does the model training, calls this method to
236236
find the image to use for model training.
237237
"""
238-
raise RuntimeError("train_image is never meant to be called on Algorithm Estimators")
238+
raise RuntimeError("training_image_uri is never meant to be called on Algorithm Estimators")
239239

240240
def enable_network_isolation(self):
241241
"""Return True if this Estimator will need network isolation to run.

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
)
9292
self._data_location = data_location
9393

94-
def train_image(self):
94+
def training_image_uri(self):
9595
"""Placeholder docstring"""
9696
return image_uris.retrieve(
9797
self.repo_name, self.sagemaker_session.boto_region_name, version=self.repo_version,

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
3636
modifiers.training_params.TrainPrefixRemover(),
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
38+
modifiers.training_input.ShuffleConfigModuleRenamer(),
3839
modifiers.serde.SerdeConstructorRenamer(),
3940
]
4041

@@ -51,6 +52,7 @@
5152
modifiers.predictors.PredictorImportFromRenamer(),
5253
modifiers.tfs.TensorFlowServingImportFromRenamer(),
5354
modifiers.training_input.TrainingInputImportFromRenamer(),
55+
modifiers.training_input.ShuffleConfigImportFromRenamer(),
5456
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
5557
modifiers.serde.SerdeImportFromPredictorRenamer(),
5658
]

src/sagemaker/cli/compatibility/v2/modifiers/training_input.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,73 @@ def modify_node(self, node):
100100
if node.module == "sagemaker.session":
101101
node.module = "sagemaker.inputs"
102102
return node
103+
104+
105+
class ShuffleConfigModuleRenamer(Modifier):
106+
"""A class to change ``ShuffleConfig`` usage to use ``sagemaker.inputs.ShuffleConfig``."""
107+
108+
def node_should_be_modified(self, node):
109+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
110+
111+
This looks for the following calls:
112+
113+
- ``sagemaker.session.ShuffleConfig``
114+
- ``session.ShuffleConfig``
115+
116+
Args:
117+
node (ast.Call): a node that represents a function call. For more,
118+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
119+
120+
Returns:
121+
bool: If the ``ast.Call`` instantiates a class of interest.
122+
"""
123+
if isinstance(node.func, ast.Name):
124+
return False
125+
126+
return matching.matches_name_or_namespaces(
127+
node, "ShuffleConfig", ("sagemaker.session", "session")
128+
)
129+
130+
def modify_node(self, node):
131+
"""Modifies the ``ast.Call`` node to call ``sagemaker.inputs.ShuffleConfig``.
132+
133+
Args:
134+
node (ast.Call): a node that represents a ``sagemaker.session.ShuffleConfig``
135+
constructor.
136+
137+
Returns:
138+
ast.Call: the original node, with its namespace changed to use the ``inputs`` module.
139+
"""
140+
_rename_namespace(node, "session")
141+
return node
142+
143+
144+
class ShuffleConfigImportFromRenamer(Modifier):
145+
"""A class to update import statements of ``ShuffleConfig``."""
146+
147+
def node_should_be_modified(self, node):
148+
"""Checks if the import statement imports ``sagemaker.session.ShuffleConfig``.
149+
150+
Args:
151+
node (ast.ImportFrom): a node that represents a ``from ... import ...`` statement.
152+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
153+
154+
Returns:
155+
bool: If the import statement imports ``sagemaker.session.ShuffleConfig``.
156+
"""
157+
return node.module == "sagemaker.session" and any(
158+
name.name == "ShuffleConfig" for name in node.names
159+
)
160+
161+
def modify_node(self, node):
162+
"""Changes the ``ast.ImportFrom`` node's namespace to ``sagemaker.inputs``.
163+
164+
Args:
165+
node (ast.ImportFrom): a node that represents a ``from ... import ...`` statement.
166+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
167+
168+
Returns:
169+
ast.ImportFrom: the original node, with its module modified to ``"sagemaker.inputs"``.
170+
"""
171+
node.module = "sagemaker.inputs"
172+
return node

src/sagemaker/cli/framework_upgrade.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def get_latest_values(existing_content, scope=None):
4141
)
4242

4343
latest_version = list(existing_content["versions"].keys())[-1]
44-
registries = existing_content["versions"][latest_version]["registries"]
45-
py_versions = existing_content["versions"][latest_version]["py_versions"]
46-
repository = existing_content["versions"][latest_version]["repository"]
44+
registries = existing_content["versions"][latest_version].get("registries", None)
45+
py_versions = existing_content["versions"][latest_version].get("py_versions", None)
46+
repository = existing_content["versions"][latest_version].get("repository", None)
4747

4848
return registries, py_versions, repository
4949

@@ -92,8 +92,9 @@ def add_dlc_framework_version(
9292
new_version = {
9393
"registries": registries,
9494
"repository": repository,
95-
"py_versions": py_versions,
9695
}
96+
if py_versions:
97+
new_version["py_versions"] = py_versions
9798
existing_content[scope]["versions"][full_version] = new_version
9899

99100

@@ -128,10 +129,11 @@ def add_algo_version(
128129
existing_content["scope"].append(scope)
129130

130131
new_version = {
131-
"py_versions": py_versions,
132132
"registries": registries,
133133
"repository": repository,
134134
}
135+
if py_versions:
136+
new_version["py_versions"] = py_versions
135137
if tag_prefix:
136138
new_version["tag_prefix"] = tag_prefix
137139
existing_content["versions"][full_version] = new_version
@@ -171,7 +173,8 @@ def add_version(
171173
py_versions (str): Supported Python versions (e.g. "py3,py37").
172174
tag_prefix (str): Algorithm image's tag prefix.
173175
"""
174-
py_versions = py_versions.split(",")
176+
if py_versions:
177+
py_versions = py_versions.split(",")
175178
processors = processors.split(",")
176179
latest_registries, latest_py_versions, latest_repository = get_latest_values(
177180
existing_content, scope

src/sagemaker/estimator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def __init__(
285285
self._enable_network_isolation = enable_network_isolation
286286

287287
@abstractmethod
288-
def train_image(self):
288+
def training_image_uri(self):
289289
"""Return the Docker image to use for training.
290290
291291
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
@@ -329,7 +329,7 @@ def _ensure_base_job_name(self):
329329
"""Set ``self.base_job_name`` if it is not set already."""
330330
# honor supplied base_job_name or generate it
331331
if self.base_job_name is None:
332-
self.base_job_name = base_name_from_image(self.train_image())
332+
self.base_job_name = base_name_from_image(self.training_image_uri())
333333

334334
def _get_or_create_name(self, name=None):
335335
"""Generate a name based on the base job name or training image if needed.
@@ -507,7 +507,7 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
507507

508508
def _compilation_job_name(self):
509509
"""Placeholder docstring"""
510-
base_name = self.base_job_name or base_name_from_image(self.train_image())
510+
base_name = self.base_job_name or base_name_from_image(self.training_image_uri())
511511
return name_from_base("compilation-" + base_name)
512512

513513
def compile_model(
@@ -1083,7 +1083,7 @@ def start_new(cls, estimator, inputs, experiment_config):
10831083
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
10841084
train_args["algorithm_arn"] = estimator.algorithm_arn
10851085
else:
1086-
train_args["image_uri"] = estimator.train_image()
1086+
train_args["image_uri"] = estimator.training_image_uri()
10871087

10881088
if estimator.debugger_rule_configs:
10891089
train_args["debugger_rule_configs"] = estimator.debugger_rule_configs
@@ -1350,7 +1350,7 @@ def __init__(
13501350
enable_network_isolation=enable_network_isolation,
13511351
)
13521352

1353-
def train_image(self):
1353+
def training_image_uri(self):
13541354
"""Returns the docker image to use for training.
13551355
13561356
The fit() method, that does the model training, calls this method to
@@ -1424,7 +1424,7 @@ def predict_wrapper(endpoint, session):
14241424
kwargs["enable_network_isolation"] = self.enable_network_isolation()
14251425

14261426
return Model(
1427-
image_uri or self.train_image(),
1427+
image_uri or self.training_image_uri(),
14281428
self.model_data,
14291429
role,
14301430
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -1826,7 +1826,7 @@ class constructor
18261826

18271827
return init_params
18281828

1829-
def train_image(self):
1829+
def training_image_uri(self):
18301830
"""Return the Docker image to use for training.
18311831
18321832
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does

src/sagemaker/fw_utils.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -49,40 +49,6 @@
4949
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
5050

5151

52-
def is_version_equal_or_higher(lowest_version, framework_version):
53-
"""Determine whether the ``framework_version`` is equal to or higher than
54-
``lowest_version``
55-
56-
Args:
57-
lowest_version (List[int]): lowest version represented in an integer
58-
list
59-
framework_version (str): framework version string
60-
61-
Returns:
62-
bool: Whether or not ``framework_version`` is equal to or higher than
63-
``lowest_version``
64-
"""
65-
version_list = [int(s) for s in framework_version.split(".")]
66-
return version_list >= lowest_version[0 : len(version_list)]
67-
68-
69-
def is_version_equal_or_lower(highest_version, framework_version):
70-
"""Determine whether the ``framework_version`` is equal to or lower than
71-
``highest_version``
72-
73-
Args:
74-
highest_version (List[int]): highest version represented in an integer
75-
list
76-
framework_version (str): framework version string
77-
78-
Returns:
79-
bool: Whether or not ``framework_version`` is equal to or lower than
80-
``highest_version``
81-
"""
82-
version_list = [int(s) for s in framework_version.split(".")]
83-
return version_list <= highest_version[0 : len(version_list)]
84-
85-
8652
def validate_source_dir(script, directory):
8753
"""Validate that the source directory exists and it contains the user script
8854
Args:

0 commit comments

Comments
 (0)