Skip to content

fix: allow custom image when calling deploy or create_model with various frameworks #1347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def create_model(
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
image=self.image_name,
image=kwargs["image"] if "image" in kwargs else self.image_name,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=(dependencies or self.dependencies),
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def create_model(
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
image=(image_name or self.image_name),
image=kwargs["image"] if "image" in kwargs else (image_name or self.image_name),
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def create_model(
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
image=self.image_name,
image=kwargs["image"] if "image" in kwargs else self.image_name,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def create_model(
base_args = dict(
model_data=self.model_data,
role=role or self.role,
image=self.image_name,
image=kwargs["image"] if "image" in kwargs else self.image_name,
name=self._current_job_name,
container_log_level=self.container_log_level,
sagemaker_session=self.sagemaker_session,
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ def create_model(
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}

# remove image kwarg
if "image" in kwargs:
image = kwargs["image"]
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
else:
image = None

return SKLearnModel(
self.model_data,
role,
Expand All @@ -179,7 +186,7 @@ def create_model(
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
image=self.image_name,
image=image or self.image_name,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
enable_network_isolation=self.enable_network_isolation(),
Expand Down
18 changes: 16 additions & 2 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,10 +601,17 @@ def _create_tfs_model(
**kwargs
):
"""Placeholder docstring"""
# remove image kwarg
if "image" in kwargs:
image = kwargs["image"]
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
else:
image = None

return Model(
model_data=self.model_data,
role=role,
image=self.image_name,
image=(image or self.image_name),
name=self._current_job_name,
container_log_level=self.container_log_level,
framework_version=utils.get_short_version(self.framework_version),
Expand All @@ -628,14 +635,21 @@ def _create_default_model(
**kwargs
):
"""Placeholder docstring"""
# remove image kwarg
if "image" in kwargs:
image = kwargs["image"]
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
else:
image = None

return TensorFlowModel(
self.model_data,
role,
entry_point or self.entry_point,
source_dir=source_dir or self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
env={"SAGEMAKER_REQUIREMENTS": self.requirements_file},
image=self.image_name,
image=(image or self.image_name),
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location,
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,3 +642,11 @@ def test_model_empty_framework_version(warning, sagemaker_session):
)
assert model.framework_version == defaults.CHAINER_VERSION
warning.assert_called_with(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)


def test_custom_image_estimator_deploy(sagemaker_session):
custom_image = "mycustomimage:latest"
chainer = _chainer_estimator(sagemaker_session)
chainer.fit(inputs="s3://mybucket/train", job_name="new_name")
model = chainer.create_model(image=custom_image)
assert model.image == custom_image
14 changes: 14 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,3 +834,17 @@ def test_mx_enable_sm_metrics_if_fw_ver_is_at_least_1_6(sagemaker_session):
framework_version=fw_version,
)
assert mx.enable_sagemaker_metrics


def test_custom_image_estimator_deploy(sagemaker_session):
custom_image = "mycustomimage:latest"
mx = MXNet(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
)
mx.fit(inputs="s3://mybucket/train", job_name="new_name")
model = mx.create_model(image=custom_image)
assert model.image == custom_image
8 changes: 8 additions & 0 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,11 @@ def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
for fw_version in ["1.3", "1.4", "2.0", "2.1"]:
pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version)
assert pytorch.enable_sagemaker_metrics


def test_custom_image_estimator_deploy(sagemaker_session):
custom_image = "mycustomimage:latest"
pytorch = _pytorch_estimator(sagemaker_session)
pytorch.fit(inputs="s3://mybucket/train", job_name="new_name")
model = pytorch.create_model(image=custom_image)
assert model.image == custom_image
8 changes: 8 additions & 0 deletions tests/unit/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,11 @@ def test_wrong_type_parameters(sagemaker_session):
train_instance_type=INSTANCE_TYPE,
)
assert "combination is not supported." in str(e.value)


def test_custom_image_estimator_deploy(sagemaker_session):
custom_image = "mycustomimage:latest"
rl = _rl_estimator(sagemaker_session)
rl.fit(inputs="s3://mybucket/train", job_name="new_name")
model = rl.create_model(image=custom_image)
assert model.image == custom_image
8 changes: 8 additions & 0 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,11 @@ def test_model_py2_warning(warning, sagemaker_session):
)
assert model.py_version == "py2"
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)


def test_custom_image_estimator_deploy(sagemaker_session):
custom_image = "mycustomimage:latest"
sklearn = _sklearn_estimator(sagemaker_session)
sklearn.fit(inputs="s3://mybucket/train", job_name="new_name")
model = sklearn.create_model(image=custom_image)
assert model.image == custom_image
8 changes: 8 additions & 0 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,3 +1300,11 @@ def test_tf_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
for fw_version in ["1.15", "1.16", "2.0", "2.1"]:
tf = _build_tf(sagemaker_session, framework_version=fw_version)
assert tf.enable_sagemaker_metrics


def test_custom_image_estimator_deploy(sagemaker_session):
custom_image = "mycustomimage:latest"
tf = _build_tf(sagemaker_session)
tf.fit(inputs="s3://mybucket/train", job_name="new_name")
model = tf.create_model(image=custom_image)
assert model.image == custom_image