Skip to content

Commit 23af3b1

Browse files
authored
breaking: rename estimator.train_image() to estimator.training_image_uri() (#1787)
1 parent 0d06276 commit 23af3b1

28 files changed

+52
-48
lines changed

src/sagemaker/algorithm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,13 @@ def hyperparameters(self):
229229
"""
230230
return self.hyperparam_dict
231231

232-
def train_image(self):
232+
def training_image_uri(self):
233233
"""Returns the docker image to use for training.
234234
235235
The fit() method, that does the model training, calls this method to
236236
find the image to use for model training.
237237
"""
238-
raise RuntimeError("train_image is never meant to be called on Algorithm Estimators")
238+
raise RuntimeError("training_image_uri is never meant to be called on Algorithm Estimators")
239239

240240
def enable_network_isolation(self):
241241
"""Return True if this Estimator will need network isolation to run.

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
)
9292
self._data_location = data_location
9393

94-
def train_image(self):
94+
def training_image_uri(self):
9595
"""Placeholder docstring"""
9696
return image_uris.retrieve(
9797
self.repo_name, self.sagemaker_session.boto_region_name, version=self.repo_version,

src/sagemaker/estimator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def __init__(
285285
self._enable_network_isolation = enable_network_isolation
286286

287287
@abstractmethod
288-
def train_image(self):
288+
def training_image_uri(self):
289289
"""Return the Docker image to use for training.
290290
291291
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
@@ -329,7 +329,7 @@ def _ensure_base_job_name(self):
329329
"""Set ``self.base_job_name`` if it is not set already."""
330330
# honor supplied base_job_name or generate it
331331
if self.base_job_name is None:
332-
self.base_job_name = base_name_from_image(self.train_image())
332+
self.base_job_name = base_name_from_image(self.training_image_uri())
333333

334334
def _get_or_create_name(self, name=None):
335335
"""Generate a name based on the base job name or training image if needed.
@@ -507,7 +507,7 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
507507

508508
def _compilation_job_name(self):
509509
"""Placeholder docstring"""
510-
base_name = self.base_job_name or base_name_from_image(self.train_image())
510+
base_name = self.base_job_name or base_name_from_image(self.training_image_uri())
511511
return name_from_base("compilation-" + base_name)
512512

513513
def compile_model(
@@ -1083,7 +1083,7 @@ def start_new(cls, estimator, inputs, experiment_config):
10831083
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
10841084
train_args["algorithm_arn"] = estimator.algorithm_arn
10851085
else:
1086-
train_args["image_uri"] = estimator.train_image()
1086+
train_args["image_uri"] = estimator.training_image_uri()
10871087

10881088
if estimator.debugger_rule_configs:
10891089
train_args["debugger_rule_configs"] = estimator.debugger_rule_configs
@@ -1350,7 +1350,7 @@ def __init__(
13501350
enable_network_isolation=enable_network_isolation,
13511351
)
13521352

1353-
def train_image(self):
1353+
def training_image_uri(self):
13541354
"""Returns the docker image to use for training.
13551355
13561356
The fit() method, that does the model training, calls this method to
@@ -1424,7 +1424,7 @@ def predict_wrapper(endpoint, session):
14241424
kwargs["enable_network_isolation"] = self.enable_network_isolation()
14251425

14261426
return Model(
1427-
image_uri or self.train_image(),
1427+
image_uri or self.training_image_uri(),
14281428
self.model_data,
14291429
role,
14301430
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -1826,7 +1826,7 @@ class constructor
18261826

18271827
return init_params
18281828

1829-
def train_image(self):
1829+
def training_image_uri(self):
18301830
"""Return the Docker image to use for training.
18311831
18321832
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def create_model(
268268
"An unknown RLFramework enum was passed in. framework: {}".format(self.framework)
269269
)
270270

271-
def train_image(self):
271+
def training_image_uri(self):
272272
"""Return the Docker image to use for training.
273273
274274
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does

src/sagemaker/tuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _prepare_job_name_for_tuning(self, job_name=None):
317317
estimator = (
318318
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
319319
)
320-
base_name = base_name_from_image(estimator.train_image())
320+
base_name = base_name_from_image(estimator.training_image_uri())
321321
self._current_job_name = name_from_base(
322322
base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True
323323
)
@@ -1527,7 +1527,7 @@ def _prepare_training_config(
15271527
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
15281528
training_config["algorithm_arn"] = estimator.algorithm_arn
15291529
else:
1530-
training_config["image_uri"] = estimator.train_image()
1530+
training_config["image_uri"] = estimator.training_image_uri()
15311531

15321532
training_config["enable_network_isolation"] = estimator.enable_network_isolation()
15331533
training_config[

src/sagemaker/workflow/airflow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
149149
if job_name is not None:
150150
estimator._current_job_name = job_name
151151
else:
152-
base_name = estimator.base_job_name or utils.base_name_from_image(estimator.train_image())
152+
base_name = estimator.base_job_name or utils.base_name_from_image(
153+
estimator.training_image_uri()
154+
)
153155
estimator._current_job_name = utils.name_from_base(base_name)
154156

155157
if estimator.output_path is None:
@@ -164,7 +166,7 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
164166

165167
train_config = {
166168
"AlgorithmSpecification": {
167-
"TrainingImage": estimator.train_image(),
169+
"TrainingImage": estimator.training_image_uri(),
168170
"TrainingInputMode": estimator.input_mode,
169171
},
170172
"OutputDataConfig": job_config["output_config"],

tests/integ/test_byo_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,4 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
151151
for prediction in result["predictions"]:
152152
assert prediction["score"] is not None
153153

154-
assert estimator.train_image() == image_uri
154+
assert estimator.training_image_uri() == image_uri

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def test_hyperparameters_no_model_dir(
521521
assert "model_dir" not in hyperparameters
522522

523523

524-
def test_train_image_custom_image(sagemaker_session):
524+
def test_custom_image(sagemaker_session):
525525
custom_image = "tensorflow:latest"
526526
tf = _build_tf(sagemaker_session, image_uri=custom_image)
527-
assert custom_image == tf.train_image()
527+
assert custom_image == tf.training_image_uri()

tests/unit/sagemaker/tensorflow/test_estimator_attach.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_attach(sagemaker_session, tensorflow_training_version, tensorflow_train
9898
assert estimator.hyperparameters() is not None
9999
assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
100100
assert estimator.entry_point == "iris-dnn-classifier.py"
101-
assert estimator.train_image() == training_image
101+
assert estimator.training_image_uri() == training_image
102102

103103

104104
@patch("sagemaker.utils.create_tar_file", MagicMock())
@@ -207,4 +207,4 @@ def test_attach_custom_image(sagemaker_session):
207207

208208
estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
209209
assert estimator.image_uri == training_image
210-
assert estimator.train_image() == training_image
210+
assert estimator.training_image_uri() == training_image

tests/unit/test_chainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def test_model_prepare_container_def_no_instance_type_or_image(chainer_version,
415415
assert expected_msg in str(e)
416416

417417

418-
def test_train_image_default(sagemaker_session, chainer_version, chainer_py_version):
418+
def test_training_image_default(sagemaker_session, chainer_version, chainer_py_version):
419419
chainer = Chainer(
420420
entry_point=SCRIPT_PATH,
421421
role=ROLE,
@@ -426,7 +426,9 @@ def test_train_image_default(sagemaker_session, chainer_version, chainer_py_vers
426426
py_version=chainer_py_version,
427427
)
428428

429-
assert _get_full_cpu_image_uri(chainer_version, chainer_py_version) == chainer.train_image()
429+
assert (
430+
_get_full_cpu_image_uri(chainer_version, chainer_py_version) == chainer.training_image_uri()
431+
)
430432

431433

432434
def test_attach(sagemaker_session, chainer_version, chainer_py_version):
@@ -545,7 +547,7 @@ def test_attach_custom_image(sagemaker_session):
545547

546548
estimator = Chainer.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
547549
assert estimator.image_uri == training_image
548-
assert estimator.train_image() == training_image
550+
assert estimator.training_image_uri() == training_image
549551

550552

551553
@patch("sagemaker.chainer.estimator.python_deprecation_warning")

tests/unit/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
class DummyFramework(Framework):
110110
_framework_name = "dummy"
111111

112-
def train_image(self):
112+
def training_image_uri(self):
113113
return IMAGE_URI
114114

115115
def create_model(

tests/unit/test_fm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_all_hyperparameters(sagemaker_session):
147147

148148
def test_image(sagemaker_session):
149149
fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
150-
assert image_uris.retrieve("factorization-machines", REGION) == fm.train_image()
150+
assert image_uris.retrieve("factorization-machines", REGION) == fm.training_image_uri()
151151

152152

153153
@pytest.mark.parametrize(

tests/unit/test_ipinsights.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_all_hyperparameters(sagemaker_session):
120120

121121
def test_image(sagemaker_session):
122122
ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
123-
assert image_uris.retrieve("ipinsights", REGION) == ipinsights.train_image()
123+
assert image_uris.retrieve("ipinsights", REGION) == ipinsights.training_image_uri()
124124

125125

126126
@pytest.mark.parametrize(

tests/unit/test_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def sagemaker_session():
8383
class DummyFramework(Framework):
8484
_framework_name = "dummy"
8585

86-
def train_image(self):
86+
def training_image_uri(self):
8787
return IMAGE_NAME
8888

8989
def create_model(self, role=None, model_server_workers=None):

tests/unit/test_kmeans.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_all_hyperparameters(sagemaker_session):
112112

113113
def test_image(sagemaker_session):
114114
kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
115-
assert image_uris.retrieve("kmeans", REGION) == kmeans.train_image()
115+
assert image_uris.retrieve("kmeans", REGION) == kmeans.training_image_uri()
116116

117117

118118
@pytest.mark.parametrize("required_hyper_parameters, value", [("k", "string")])

tests/unit/test_knn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_all_hyperparameters_classifier(sagemaker_session):
146146

147147
def test_image(sagemaker_session):
148148
knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
149-
assert image_uris.retrieve("knn", REGION) == knn.train_image()
149+
assert image_uris.retrieve("knn", REGION) == knn.training_image_uri()
150150

151151

152152
@pytest.mark.parametrize(

tests/unit/test_lda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_all_hyperparameters(sagemaker_session):
9696

9797
def test_image(sagemaker_session):
9898
lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
99-
assert image_uris.retrieve("lda", REGION) == lda.train_image()
99+
assert image_uris.retrieve("lda", REGION) == lda.training_image_uri()
100100

101101

102102
@pytest.mark.parametrize("required_hyper_parameters, value", [("num_topics", "string")])

tests/unit/test_linear_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_all_hyperparameters(sagemaker_session):
179179

180180
def test_image(sagemaker_session):
181181
lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
182-
assert image_uris.retrieve("linear-learner", REGION) == lr.train_image()
182+
assert image_uris.retrieve("linear-learner", REGION) == lr.training_image_uri()
183183

184184

185185
@pytest.mark.parametrize("required_hyper_parameters, value", [("predictor_type", 0)])

tests/unit/test_mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def test_attach_custom_image(sagemaker_session):
670670

671671
estimator = MXNet.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
672672
assert estimator.image_uri == training_image
673-
assert estimator.train_image() == training_image
673+
assert estimator.training_image_uri() == training_image
674674

675675

676676
def test_estimator_script_mode_dont_launch_parameter_server(sagemaker_session):

tests/unit/test_ntm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_all_hyperparameters(sagemaker_session):
115115

116116
def test_image(sagemaker_session):
117117
ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
118-
assert image_uris.retrieve("ntm", REGION) == ntm.train_image()
118+
assert image_uris.retrieve("ntm", REGION) == ntm.training_image_uri()
119119

120120

121121
@pytest.mark.parametrize("required_hyper_parameters, value", [("num_topics", "string")])

tests/unit/test_object2vec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_all_hyperparameters(sagemaker_session):
144144

145145
def test_image(sagemaker_session):
146146
object2vec = Object2Vec(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
147-
assert image_uris.retrieve("object2vec", REGION) == object2vec.train_image()
147+
assert image_uris.retrieve("object2vec", REGION) == object2vec.training_image_uri()
148148

149149

150150
@pytest.mark.parametrize("required_hyper_parameters, value", [("epochs", "string")])

tests/unit/test_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_all_hyperparameters(sagemaker_session):
101101

102102
def test_image(sagemaker_session):
103103
pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
104-
assert image_uris.retrieve("pca", REGION) == pca.train_image()
104+
assert image_uris.retrieve("pca", REGION) == pca.training_image_uri()
105105

106106

107107
@pytest.mark.parametrize("required_hyper_parameters, value", [("num_components", "string")])

tests/unit/test_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def test_attach_custom_image(sagemaker_session):
537537
estimator = PyTorch.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
538538
assert estimator.latest_training_job.job_name == "neo"
539539
assert estimator.image_uri == training_image
540-
assert estimator.train_image() == training_image
540+
assert estimator.training_image_uri() == training_image
541541

542542

543543
@patch("sagemaker.pytorch.estimator.python_deprecation_warning")

tests/unit/test_randomcutforest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_all_hyperparameters(sagemaker_session):
107107

108108
def test_image(sagemaker_session):
109109
randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
110-
assert image_uris.retrieve("randomcutforest", REGION) == randomcutforest.train_image()
110+
assert image_uris.retrieve("randomcutforest", REGION) == randomcutforest.training_image_uri()
111111

112112

113113
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)])

tests/unit/test_rl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def test_deploy_ray(sagemaker_session, ray_tensorflow_version):
395395

396396

397397
@patch("sagemaker.image_uris.retrieve")
398-
def test_train_image(retrieve_image_uri, sagemaker_session, ray_tensorflow_version):
398+
def test_training_image_uri(retrieve_image_uri, sagemaker_session, ray_tensorflow_version):
399399
toolkit = RLToolkit.RAY
400400
framework = RLFramework.TENSORFLOW
401401

@@ -408,13 +408,13 @@ def test_train_image(retrieve_image_uri, sagemaker_session, ray_tensorflow_versi
408408
instance_type=CPU,
409409
image_uri=image,
410410
)
411-
assert image == rl.train_image()
411+
assert image == rl.training_image_uri()
412412
retrieve_image_uri.assert_not_called()
413413

414414
rl = _rl_estimator(
415415
sagemaker_session, toolkit, ray_tensorflow_version, framework, instance_type=CPU
416416
)
417-
assert retrieve_image_uri.return_value == rl.train_image()
417+
assert retrieve_image_uri.return_value == rl.training_image_uri()
418418

419419
retrieve_image_uri.assert_called_with(
420420
"ray-tensorflow", REGION, version=ray_tensorflow_version, instance_type=CPU
@@ -540,7 +540,7 @@ def test_attach_custom_image(sagemaker_session):
540540
estimator = RLEstimator.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
541541
assert estimator.latest_training_job.job_name == "neo"
542542
assert estimator.image_uri == training_image
543-
assert estimator.train_image() == training_image
543+
assert estimator.training_image_uri() == training_image
544544

545545

546546
def test_wrong_framework_format(sagemaker_session):

tests/unit/test_sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _create_train_job(version):
140140
}
141141

142142

143-
def test_train_image(sagemaker_session, sklearn_version):
143+
def test_training_image_uri(sagemaker_session, sklearn_version):
144144
container_log_level = '"logging.INFO"'
145145
source_dir = "s3://mybucket/source"
146146
sklearn = SKLearn(
@@ -155,7 +155,7 @@ def test_train_image(sagemaker_session, sklearn_version):
155155
source_dir=source_dir,
156156
)
157157

158-
assert _get_full_cpu_image_uri(sklearn_version) == sklearn.train_image()
158+
assert _get_full_cpu_image_uri(sklearn_version) == sklearn.training_image_uri()
159159

160160

161161
def test_create_model(sagemaker_session, sklearn_version):
@@ -525,7 +525,7 @@ def test_attach_custom_image(sagemaker_session):
525525

526526
estimator = SKLearn.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
527527
assert estimator.image_uri == training_image
528-
assert estimator.train_image() == training_image
528+
assert estimator.training_image_uri() == training_image
529529

530530

531531
def test_estimator_py2_raises(sagemaker_session, sklearn_version):

tests/unit/test_tuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def test_fit_multi_estimators(sagemaker_session):
386386
assert training_config_one["objective_type"] == "Minimize"
387387
assert training_config_one["objective_metric_name"] == OBJECTIVE_METRIC_NAME
388388
assert training_config_one["input_config"] is None
389-
assert training_config_one["image_uri"] == estimator_one.train_image()
389+
assert training_config_one["image_uri"] == estimator_one.training_image_uri()
390390
assert training_config_one["metric_definitions"] == METRIC_DEFINITIONS
391391
assert (
392392
training_config_one["static_hyperparameters"]["sagemaker_estimator_module"]
@@ -403,7 +403,7 @@ def test_fit_multi_estimators(sagemaker_session):
403403
assert training_config_two["objective_metric_name"] == OBJECTIVE_METRIC_NAME_TWO
404404
assert len(training_config_two["input_config"]) == 1
405405
assert training_config_two["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == INPUTS
406-
assert training_config_two["image_uri"] == estimator_two.train_image()
406+
assert training_config_two["image_uri"] == estimator_two.training_image_uri()
407407
assert training_config_two["metric_definitions"] is None
408408
assert training_config_two["static_hyperparameters"]["mini_batch_size"] == "4000"
409409
_assert_parameter_ranges(

0 commit comments

Comments
 (0)