Skip to content

Commit cef7c5c

Browse files
authored
fix: allow custom image when calling deploy or create_model with various frameworks (#1347)
1 parent d8d64da commit cef7c5c

File tree

12 files changed

+82
-7
lines changed

12 files changed

+82
-7
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def create_model(
215215
py_version=self.py_version,
216216
framework_version=self.framework_version,
217217
model_server_workers=model_server_workers,
218-
image=self.image_name,
218+
image=kwargs["image"] if "image" in kwargs else self.image_name,
219219
sagemaker_session=self.sagemaker_session,
220220
vpc_config=self.get_vpc_config(vpc_config_override),
221221
dependencies=(dependencies or self.dependencies),

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def create_model(
210210
code_location=self.code_location,
211211
py_version=self.py_version,
212212
framework_version=self.framework_version,
213-
image=(image_name or self.image_name),
213+
image=kwargs["image"] if "image" in kwargs else (image_name or self.image_name),
214214
model_server_workers=model_server_workers,
215215
sagemaker_session=self.sagemaker_session,
216216
vpc_config=self.get_vpc_config(vpc_config_override),

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def create_model(
175175
code_location=self.code_location,
176176
py_version=self.py_version,
177177
framework_version=self.framework_version,
178-
image=self.image_name,
178+
image=kwargs["image"] if "image" in kwargs else self.image_name,
179179
model_server_workers=model_server_workers,
180180
sagemaker_session=self.sagemaker_session,
181181
vpc_config=self.get_vpc_config(vpc_config_override),

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def create_model(
218218
base_args = dict(
219219
model_data=self.model_data,
220220
role=role or self.role,
221-
image=self.image_name,
221+
image=kwargs["image"] if "image" in kwargs else self.image_name,
222222
name=self._current_job_name,
223223
container_log_level=self.container_log_level,
224224
sagemaker_session=self.sagemaker_session,

src/sagemaker/sklearn/estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,13 @@ def create_model(
167167
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
168168
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}
169169

170+
# remove image kwarg
171+
if "image" in kwargs:
172+
image = kwargs["image"]
173+
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
174+
else:
175+
image = None
176+
170177
return SKLearnModel(
171178
self.model_data,
172179
role,
@@ -179,7 +186,7 @@ def create_model(
179186
py_version=self.py_version,
180187
framework_version=self.framework_version,
181188
model_server_workers=model_server_workers,
182-
image=self.image_name,
189+
image=image or self.image_name,
183190
sagemaker_session=self.sagemaker_session,
184191
vpc_config=self.get_vpc_config(vpc_config_override),
185192
enable_network_isolation=self.enable_network_isolation(),

src/sagemaker/tensorflow/estimator.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,10 +601,17 @@ def _create_tfs_model(
601601
**kwargs
602602
):
603603
"""Placeholder docstring"""
604+
# remove image kwarg
605+
if "image" in kwargs:
606+
image = kwargs["image"]
607+
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
608+
else:
609+
image = None
610+
604611
return Model(
605612
model_data=self.model_data,
606613
role=role,
607-
image=self.image_name,
614+
image=(image or self.image_name),
608615
name=self._current_job_name,
609616
container_log_level=self.container_log_level,
610617
framework_version=utils.get_short_version(self.framework_version),
@@ -628,14 +635,21 @@ def _create_default_model(
628635
**kwargs
629636
):
630637
"""Placeholder docstring"""
638+
# remove image kwarg
639+
if "image" in kwargs:
640+
image = kwargs["image"]
641+
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
642+
else:
643+
image = None
644+
631645
return TensorFlowModel(
632646
self.model_data,
633647
role,
634648
entry_point or self.entry_point,
635649
source_dir=source_dir or self._model_source_dir(),
636650
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
637651
env={"SAGEMAKER_REQUIREMENTS": self.requirements_file},
638-
image=self.image_name,
652+
image=(image or self.image_name),
639653
name=self._current_job_name,
640654
container_log_level=self.container_log_level,
641655
code_location=self.code_location,

tests/unit/test_chainer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,3 +642,11 @@ def test_model_empty_framework_version(warning, sagemaker_session):
642642
)
643643
assert model.framework_version == defaults.CHAINER_VERSION
644644
warning.assert_called_with(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
645+
646+
647+
def test_custom_image_estimator_deploy(sagemaker_session):
648+
custom_image = "mycustomimage:latest"
649+
chainer = _chainer_estimator(sagemaker_session)
650+
chainer.fit(inputs="s3://mybucket/train", job_name="new_name")
651+
model = chainer.create_model(image=custom_image)
652+
assert model.image == custom_image

tests/unit/test_mxnet.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,3 +834,17 @@ def test_mx_enable_sm_metrics_if_fw_ver_is_at_least_1_6(sagemaker_session):
834834
framework_version=fw_version,
835835
)
836836
assert mx.enable_sagemaker_metrics
837+
838+
839+
def test_custom_image_estimator_deploy(sagemaker_session):
840+
custom_image = "mycustomimage:latest"
841+
mx = MXNet(
842+
entry_point=SCRIPT_PATH,
843+
role=ROLE,
844+
sagemaker_session=sagemaker_session,
845+
train_instance_count=INSTANCE_COUNT,
846+
train_instance_type=INSTANCE_TYPE,
847+
)
848+
mx.fit(inputs="s3://mybucket/train", job_name="new_name")
849+
model = mx.create_model(image=custom_image)
850+
assert model.image == custom_image

tests/unit/test_pytorch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,3 +602,11 @@ def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
602602
for fw_version in ["1.3", "1.4", "2.0", "2.1"]:
603603
pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version)
604604
assert pytorch.enable_sagemaker_metrics
605+
606+
607+
def test_custom_image_estimator_deploy(sagemaker_session):
608+
custom_image = "mycustomimage:latest"
609+
pytorch = _pytorch_estimator(sagemaker_session)
610+
pytorch.fit(inputs="s3://mybucket/train", job_name="new_name")
611+
model = pytorch.create_model(image=custom_image)
612+
assert model.image == custom_image

tests/unit/test_rl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,3 +618,11 @@ def test_wrong_type_parameters(sagemaker_session):
618618
train_instance_type=INSTANCE_TYPE,
619619
)
620620
assert "combination is not supported." in str(e.value)
621+
622+
623+
def test_custom_image_estimator_deploy(sagemaker_session):
624+
custom_image = "mycustomimage:latest"
625+
rl = _rl_estimator(sagemaker_session)
626+
rl.fit(inputs="s3://mybucket/train", job_name="new_name")
627+
model = rl.create_model(image=custom_image)
628+
assert model.image == custom_image

tests/unit/test_sklearn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,11 @@ def test_model_py2_warning(warning, sagemaker_session):
570570
)
571571
assert model.py_version == "py2"
572572
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
573+
574+
575+
def test_custom_image_estimator_deploy(sagemaker_session):
576+
custom_image = "mycustomimage:latest"
577+
sklearn = _sklearn_estimator(sagemaker_session)
578+
sklearn.fit(inputs="s3://mybucket/train", job_name="new_name")
579+
model = sklearn.create_model(image=custom_image)
580+
assert model.image == custom_image

tests/unit/test_tf_estimator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,3 +1300,11 @@ def test_tf_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
13001300
for fw_version in ["1.15", "1.16", "2.0", "2.1"]:
13011301
tf = _build_tf(sagemaker_session, framework_version=fw_version)
13021302
assert tf.enable_sagemaker_metrics
1303+
1304+
1305+
def test_custom_image_estimator_deploy(sagemaker_session):
1306+
custom_image = "mycustomimage:latest"
1307+
tf = _build_tf(sagemaker_session)
1308+
tf.fit(inputs="s3://mybucket/train", job_name="new_name")
1309+
model = tf.create_model(image=custom_image)
1310+
assert model.image == custom_image

0 commit comments

Comments
 (0)