Skip to content

breaking: rename estimator.train_image() to estimator.training_image_uri() #1787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,13 @@ def hyperparameters(self):
"""
return self.hyperparam_dict

def train_image(self):
def training_image_uri(self):
"""Returns the docker image to use for training.

The fit() method, that does the model training, calls this method to
find the image to use for model training.
"""
raise RuntimeError("train_image is never meant to be called on Algorithm Estimators")
raise RuntimeError("training_image_uri is never meant to be called on Algorithm Estimators")

def enable_network_isolation(self):
"""Return True if this Estimator will need network isolation to run.
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
)
self._data_location = data_location

def train_image(self):
def training_image_uri(self):
"""Placeholder docstring"""
return image_uris.retrieve(
self.repo_name, self.sagemaker_session.boto_region_name, version=self.repo_version,
Expand Down
14 changes: 7 additions & 7 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def __init__(
self._enable_network_isolation = enable_network_isolation

@abstractmethod
def train_image(self):
def training_image_uri(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
Expand Down Expand Up @@ -329,7 +329,7 @@ def _ensure_base_job_name(self):
"""Set ``self.base_job_name`` if it is not set already."""
# honor supplied base_job_name or generate it
if self.base_job_name is None:
self.base_job_name = base_name_from_image(self.train_image())
self.base_job_name = base_name_from_image(self.training_image_uri())

def _get_or_create_name(self, name=None):
"""Generate a name based on the base job name or training image if needed.
Expand Down Expand Up @@ -507,7 +507,7 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf

def _compilation_job_name(self):
"""Placeholder docstring"""
base_name = self.base_job_name or base_name_from_image(self.train_image())
base_name = self.base_job_name or base_name_from_image(self.training_image_uri())
return name_from_base("compilation-" + base_name)

def compile_model(
Expand Down Expand Up @@ -1083,7 +1083,7 @@ def start_new(cls, estimator, inputs, experiment_config):
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
train_args["algorithm_arn"] = estimator.algorithm_arn
else:
train_args["image_uri"] = estimator.train_image()
train_args["image_uri"] = estimator.training_image_uri()

if estimator.debugger_rule_configs:
train_args["debugger_rule_configs"] = estimator.debugger_rule_configs
Expand Down Expand Up @@ -1350,7 +1350,7 @@ def __init__(
enable_network_isolation=enable_network_isolation,
)

def train_image(self):
def training_image_uri(self):
"""Returns the docker image to use for training.

The fit() method, that does the model training, calls this method to
Expand Down Expand Up @@ -1424,7 +1424,7 @@ def predict_wrapper(endpoint, session):
kwargs["enable_network_isolation"] = self.enable_network_isolation()

return Model(
image_uri or self.train_image(),
image_uri or self.training_image_uri(),
self.model_data,
role,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down Expand Up @@ -1826,7 +1826,7 @@ class constructor

return init_params

def train_image(self):
def training_image_uri(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def create_model(
"An unknown RLFramework enum was passed in. framework: {}".format(self.framework)
)

def train_image(self):
def training_image_uri(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _prepare_job_name_for_tuning(self, job_name=None):
estimator = (
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
)
base_name = base_name_from_image(estimator.train_image())
base_name = base_name_from_image(estimator.training_image_uri())
self._current_job_name = name_from_base(
base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True
)
Expand Down Expand Up @@ -1527,7 +1527,7 @@ def _prepare_training_config(
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
training_config["algorithm_arn"] = estimator.algorithm_arn
else:
training_config["image_uri"] = estimator.train_image()
training_config["image_uri"] = estimator.training_image_uri()

training_config["enable_network_isolation"] = estimator.enable_network_isolation()
training_config[
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
if job_name is not None:
estimator._current_job_name = job_name
else:
base_name = estimator.base_job_name or utils.base_name_from_image(estimator.train_image())
base_name = estimator.base_job_name or utils.base_name_from_image(
estimator.training_image_uri()
)
estimator._current_job_name = utils.name_from_base(base_name)

if estimator.output_path is None:
Expand All @@ -164,7 +166,7 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=

train_config = {
"AlgorithmSpecification": {
"TrainingImage": estimator.train_image(),
"TrainingImage": estimator.training_image_uri(),
"TrainingInputMode": estimator.input_mode,
},
"OutputDataConfig": job_config["output_config"],
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_byo_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
for prediction in result["predictions"]:
assert prediction["score"] is not None

assert estimator.train_image() == image_uri
assert estimator.training_image_uri() == image_uri
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/tensorflow/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def test_hyperparameters_no_model_dir(
assert "model_dir" not in hyperparameters


def test_train_image_custom_image(sagemaker_session):
def test_custom_image(sagemaker_session):
custom_image = "tensorflow:latest"
tf = _build_tf(sagemaker_session, image_uri=custom_image)
assert custom_image == tf.train_image()
assert custom_image == tf.training_image_uri()
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/tensorflow/test_estimator_attach.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_attach(sagemaker_session, tensorflow_training_version, tensorflow_train
assert estimator.hyperparameters() is not None
assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
assert estimator.entry_point == "iris-dnn-classifier.py"
assert estimator.train_image() == training_image
assert estimator.training_image_uri() == training_image


@patch("sagemaker.utils.create_tar_file", MagicMock())
Expand Down Expand Up @@ -207,4 +207,4 @@ def test_attach_custom_image(sagemaker_session):

estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
assert estimator.image_uri == training_image
assert estimator.train_image() == training_image
assert estimator.training_image_uri() == training_image
8 changes: 5 additions & 3 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def test_model_prepare_container_def_no_instance_type_or_image(chainer_version,
assert expected_msg in str(e)


def test_train_image_default(sagemaker_session, chainer_version, chainer_py_version):
def test_training_image_default(sagemaker_session, chainer_version, chainer_py_version):
chainer = Chainer(
entry_point=SCRIPT_PATH,
role=ROLE,
Expand All @@ -426,7 +426,9 @@ def test_train_image_default(sagemaker_session, chainer_version, chainer_py_vers
py_version=chainer_py_version,
)

assert _get_full_cpu_image_uri(chainer_version, chainer_py_version) == chainer.train_image()
assert (
_get_full_cpu_image_uri(chainer_version, chainer_py_version) == chainer.training_image_uri()
)


def test_attach(sagemaker_session, chainer_version, chainer_py_version):
Expand Down Expand Up @@ -545,7 +547,7 @@ def test_attach_custom_image(sagemaker_session):

estimator = Chainer.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
assert estimator.image_uri == training_image
assert estimator.train_image() == training_image
assert estimator.training_image_uri() == training_image


@patch("sagemaker.chainer.estimator.python_deprecation_warning")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
class DummyFramework(Framework):
_framework_name = "dummy"

def train_image(self):
def training_image_uri(self):
return IMAGE_URI

def create_model(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("factorization-machines", REGION) == fm.train_image()
assert image_uris.retrieve("factorization-machines", REGION) == fm.training_image_uri()


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_ipinsights.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("ipinsights", REGION) == ipinsights.train_image()
assert image_uris.retrieve("ipinsights", REGION) == ipinsights.training_image_uri()


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def sagemaker_session():
class DummyFramework(Framework):
_framework_name = "dummy"

def train_image(self):
def training_image_uri(self):
return IMAGE_NAME

def create_model(self, role=None, model_server_workers=None):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("kmeans", REGION) == kmeans.train_image()
assert image_uris.retrieve("kmeans", REGION) == kmeans.training_image_uri()


@pytest.mark.parametrize("required_hyper_parameters, value", [("k", "string")])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_all_hyperparameters_classifier(sagemaker_session):

def test_image(sagemaker_session):
knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("knn", REGION) == knn.train_image()
assert image_uris.retrieve("knn", REGION) == knn.training_image_uri()


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("lda", REGION) == lda.train_image()
assert image_uris.retrieve("lda", REGION) == lda.training_image_uri()


@pytest.mark.parametrize("required_hyper_parameters, value", [("num_topics", "string")])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_linear_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("linear-learner", REGION) == lr.train_image()
assert image_uris.retrieve("linear-learner", REGION) == lr.training_image_uri()


@pytest.mark.parametrize("required_hyper_parameters, value", [("predictor_type", 0)])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def test_attach_custom_image(sagemaker_session):

estimator = MXNet.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
assert estimator.image_uri == training_image
assert estimator.train_image() == training_image
assert estimator.training_image_uri() == training_image


def test_estimator_script_mode_dont_launch_parameter_server(sagemaker_session):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_ntm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("ntm", REGION) == ntm.train_image()
assert image_uris.retrieve("ntm", REGION) == ntm.training_image_uri()


@pytest.mark.parametrize("required_hyper_parameters, value", [("num_topics", "string")])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_object2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
object2vec = Object2Vec(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("object2vec", REGION) == object2vec.train_image()
assert image_uris.retrieve("object2vec", REGION) == object2vec.training_image_uri()


@pytest.mark.parametrize("required_hyper_parameters, value", [("epochs", "string")])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("pca", REGION) == pca.train_image()
assert image_uris.retrieve("pca", REGION) == pca.training_image_uri()


@pytest.mark.parametrize("required_hyper_parameters, value", [("num_components", "string")])
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def test_attach_custom_image(sagemaker_session):
estimator = PyTorch.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
assert estimator.latest_training_job.job_name == "neo"
assert estimator.image_uri == training_image
assert estimator.train_image() == training_image
assert estimator.training_image_uri() == training_image


@patch("sagemaker.pytorch.estimator.python_deprecation_warning")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_randomcutforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_all_hyperparameters(sagemaker_session):

def test_image(sagemaker_session):
randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
assert image_uris.retrieve("randomcutforest", REGION) == randomcutforest.train_image()
assert image_uris.retrieve("randomcutforest", REGION) == randomcutforest.training_image_uri()


@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)])
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_deploy_ray(sagemaker_session, ray_tensorflow_version):


@patch("sagemaker.image_uris.retrieve")
def test_train_image(retrieve_image_uri, sagemaker_session, ray_tensorflow_version):
def test_training_image_uri(retrieve_image_uri, sagemaker_session, ray_tensorflow_version):
toolkit = RLToolkit.RAY
framework = RLFramework.TENSORFLOW

Expand All @@ -408,13 +408,13 @@ def test_train_image(retrieve_image_uri, sagemaker_session, ray_tensorflow_versi
instance_type=CPU,
image_uri=image,
)
assert image == rl.train_image()
assert image == rl.training_image_uri()
retrieve_image_uri.assert_not_called()

rl = _rl_estimator(
sagemaker_session, toolkit, ray_tensorflow_version, framework, instance_type=CPU
)
assert retrieve_image_uri.return_value == rl.train_image()
assert retrieve_image_uri.return_value == rl.training_image_uri()

retrieve_image_uri.assert_called_with(
"ray-tensorflow", REGION, version=ray_tensorflow_version, instance_type=CPU
Expand Down Expand Up @@ -540,7 +540,7 @@ def test_attach_custom_image(sagemaker_session):
estimator = RLEstimator.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
assert estimator.latest_training_job.job_name == "neo"
assert estimator.image_uri == training_image
assert estimator.train_image() == training_image
assert estimator.training_image_uri() == training_image


def test_wrong_framework_format(sagemaker_session):
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _create_train_job(version):
}


def test_train_image(sagemaker_session, sklearn_version):
def test_training_image_uri(sagemaker_session, sklearn_version):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
sklearn = SKLearn(
Expand All @@ -155,7 +155,7 @@ def test_train_image(sagemaker_session, sklearn_version):
source_dir=source_dir,
)

assert _get_full_cpu_image_uri(sklearn_version) == sklearn.train_image()
assert _get_full_cpu_image_uri(sklearn_version) == sklearn.training_image_uri()


def test_create_model(sagemaker_session, sklearn_version):
Expand Down Expand Up @@ -525,7 +525,7 @@ def test_attach_custom_image(sagemaker_session):

estimator = SKLearn.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
assert estimator.image_uri == training_image
assert estimator.train_image() == training_image
assert estimator.training_image_uri() == training_image


def test_estimator_py2_raises(sagemaker_session, sklearn_version):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def test_fit_multi_estimators(sagemaker_session):
assert training_config_one["objective_type"] == "Minimize"
assert training_config_one["objective_metric_name"] == OBJECTIVE_METRIC_NAME
assert training_config_one["input_config"] is None
assert training_config_one["image_uri"] == estimator_one.train_image()
assert training_config_one["image_uri"] == estimator_one.training_image_uri()
assert training_config_one["metric_definitions"] == METRIC_DEFINITIONS
assert (
training_config_one["static_hyperparameters"]["sagemaker_estimator_module"]
Expand All @@ -403,7 +403,7 @@ def test_fit_multi_estimators(sagemaker_session):
assert training_config_two["objective_metric_name"] == OBJECTIVE_METRIC_NAME_TWO
assert len(training_config_two["input_config"]) == 1
assert training_config_two["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == INPUTS
assert training_config_two["image_uri"] == estimator_two.train_image()
assert training_config_two["image_uri"] == estimator_two.training_image_uri()
assert training_config_two["metric_definitions"] is None
assert training_config_two["static_hyperparameters"]["mini_batch_size"] == "4000"
_assert_parameter_ranges(
Expand Down
Loading