Skip to content

Commit f76f8a8

Browse files
authored
change: allow specifying model name in create_model() for TensorFlow, SKLearn, and XGBoost (#1397)
1 parent c51df67 commit f76f8a8

File tree

6 files changed

+33
-37
lines changed

6 files changed

+33
-37
lines changed

src/sagemaker/sklearn/estimator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,25 +165,27 @@ def create_model(
165165
# remove unwanted entry_point kwarg
166166
if "entry_point" in kwargs:
167167
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
168-
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}
168+
del kwargs["entry_point"]
169169

170170
# remove image kwarg
171171
if "image" in kwargs:
172172
image = kwargs["image"]
173-
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
173+
del kwargs["image"]
174174
else:
175175
image = None
176176

177177
if "enable_network_isolation" not in kwargs:
178178
kwargs["enable_network_isolation"] = self.enable_network_isolation()
179179

180+
if "name" not in kwargs:
181+
kwargs["name"] = self._current_job_name
182+
180183
return SKLearnModel(
181184
self.model_data,
182185
role,
183186
self.entry_point,
184187
source_dir=self._model_source_dir(),
185188
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
186-
name=self._current_job_name,
187189
container_log_level=self.container_log_level,
188190
code_location=self.code_location,
189191
py_version=self.py_version,

src/sagemaker/tensorflow/estimator.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,12 @@ def create_model(
584584
"""
585585
role = role or self.role
586586

587+
if "image" not in kwargs:
588+
kwargs["image"] = self.image_name
589+
590+
if "name" not in kwargs:
591+
kwargs["name"] = self._current_job_name
592+
587593
if endpoint_type == "tensorflow-serving" or self._script_mode_enabled():
588594
return self._create_tfs_model(
589595
role=role,
@@ -614,18 +620,9 @@ def _create_tfs_model(
614620
**kwargs
615621
):
616622
"""Placeholder docstring"""
617-
# remove image kwarg
618-
if "image" in kwargs:
619-
image = kwargs["image"]
620-
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
621-
else:
622-
image = None
623-
624623
return Model(
625624
model_data=self.model_data,
626625
role=role,
627-
image=(image or self.image_name),
628-
name=self._current_job_name,
629626
container_log_level=self.container_log_level,
630627
framework_version=utils.get_short_version(self.framework_version),
631628
sagemaker_session=self.sagemaker_session,
@@ -648,22 +645,13 @@ def _create_default_model(
648645
**kwargs
649646
):
650647
"""Placeholder docstring"""
651-
# remove image kwarg
652-
if "image" in kwargs:
653-
image = kwargs["image"]
654-
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
655-
else:
656-
image = None
657-
658648
return TensorFlowModel(
659649
self.model_data,
660650
role,
661651
entry_point or self.entry_point,
662652
source_dir=source_dir or self._model_source_dir(),
663653
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
664654
env={"SAGEMAKER_REQUIREMENTS": self.requirements_file},
665-
image=(image or self.image_name),
666-
name=self._current_job_name,
667655
container_log_level=self.container_log_level,
668656
code_location=self.code_location,
669657
py_version=self.py_version,

src/sagemaker/xgboost/estimator.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,16 @@
1515

1616
import logging
1717

18-
from sagemaker.estimator import Framework
18+
from sagemaker.estimator import Framework, _TrainingJob
1919
from sagemaker.fw_registry import default_framework_uri
2020
from sagemaker.fw_utils import (
2121
framework_name_from_image,
2222
framework_version_from_tag,
2323
get_unsupported_framework_version_error,
2424
UploadedCode,
2525
)
26-
27-
2826
from sagemaker.session import Session
29-
30-
from sagemaker.estimator import _TrainingJob
31-
3227
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
33-
3428
from sagemaker.xgboost import defaults
3529
from sagemaker.xgboost.model import XGBoostModel
3630

@@ -154,7 +148,10 @@ def create_model(
154148
# Remove unwanted entry_point kwarg
155149
if "entry_point" in kwargs:
156150
logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"]))
157-
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}
151+
del kwargs["entry_point"]
152+
153+
if "name" not in kwargs:
154+
kwargs["name"] = self._current_job_name
158155

159156
return XGBoostModel(
160157
self.model_data,
@@ -163,7 +160,6 @@ def create_model(
163160
framework_version=self.framework_version,
164161
source_dir=self._model_source_dir(),
165162
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
166-
name=self._current_job_name,
167163
container_log_level=self.container_log_level,
168164
code_location=self.code_location,
169165
py_version=self.py_version,

tests/unit/test_sklearn.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020
from mock import Mock
2121
from mock import patch
2222

23-
from sagemaker.sklearn import defaults
24-
from sagemaker.sklearn import SKLearn
25-
from sagemaker.sklearn import SKLearnPredictor, SKLearnModel
23+
from sagemaker.sklearn import defaults, SKLearn, SKLearnModel, SKLearnPredictor
2624
from sagemaker.fw_utils import UploadedCode
2725

2826
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
@@ -254,13 +252,18 @@ def test_create_model_with_optional_params(sagemaker_session):
254252
new_role = "role"
255253
model_server_workers = 2
256254
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
255+
model_name = "model-name"
257256
model = sklearn.create_model(
258-
role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config
257+
role=new_role,
258+
model_server_workers=model_server_workers,
259+
vpc_config_override=vpc_config,
260+
name=model_name,
259261
)
260262

261263
assert model.role == new_role
262264
assert model.model_server_workers == model_server_workers
263265
assert model.vpc_config == vpc_config
266+
assert model.name == model_name
264267

265268

266269
def test_create_model_with_custom_image(sagemaker_session):

tests/unit/test_tf_estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,17 +322,20 @@ def test_create_model_with_optional_params(sagemaker_session):
322322
new_role = "role"
323323
model_server_workers = 2
324324
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
325+
model_name = "model-name"
325326
model = tf.create_model(
326327
role=new_role,
327328
model_server_workers=model_server_workers,
328329
vpc_config_override=vpc_config,
329330
entry_point=SERVING_SCRIPT_FILE,
331+
name=model_name,
330332
)
331333

332334
assert model.role == new_role
333335
assert model.model_server_workers == model_server_workers
334336
assert model.vpc_config == vpc_config
335337
assert model.entry_point == SERVING_SCRIPT_FILE
338+
assert model.name == model_name
336339

337340

338341
@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")

tests/unit/test_xgboost.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222

2323

2424
from sagemaker.xgboost.defaults import XGBOOST_LATEST_VERSION
25-
from sagemaker.xgboost import XGBoost
26-
from sagemaker.xgboost import XGBoostPredictor, XGBoostModel
25+
from sagemaker.xgboost import XGBoost, XGBoostModel, XGBoostPredictor
2726

2827

2928
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
@@ -242,13 +241,18 @@ def test_create_model_with_optional_params(sagemaker_session):
242241
new_role = "role"
243242
model_server_workers = 2
244243
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
244+
model_name = "model-name"
245245
model = xgboost.create_model(
246-
role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config
246+
role=new_role,
247+
model_server_workers=model_server_workers,
248+
vpc_config_override=vpc_config,
249+
name=model_name,
247250
)
248251

249252
assert model.role == new_role
250253
assert model.model_server_workers == model_server_workers
251254
assert model.vpc_config == vpc_config
255+
assert model.name == model_name
252256

253257

254258
def test_create_model_with_custom_image(sagemaker_session):

0 commit comments

Comments
 (0)