Skip to content

change: allow specifying model name in create_model() for TensorFlow, SKLearn, and XGBoost #1397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,25 +165,27 @@ def create_model(
# remove unwanted entry_point kwarg
if "entry_point" in kwargs:
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}
del kwargs["entry_point"]

# remove image kwarg
if "image" in kwargs:
image = kwargs["image"]
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
del kwargs["image"]
else:
image = None

if "enable_network_isolation" not in kwargs:
kwargs["enable_network_isolation"] = self.enable_network_isolation()

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return SKLearnModel(
self.model_data,
role,
self.entry_point,
source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
Expand Down
24 changes: 6 additions & 18 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,12 @@ def create_model(
"""
role = role or self.role

if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

if endpoint_type == "tensorflow-serving" or self._script_mode_enabled():
return self._create_tfs_model(
role=role,
Expand Down Expand Up @@ -614,18 +620,9 @@ def _create_tfs_model(
**kwargs
):
"""Placeholder docstring"""
# remove image kwarg
if "image" in kwargs:
image = kwargs["image"]
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
else:
image = None

return Model(
model_data=self.model_data,
role=role,
image=(image or self.image_name),
name=self._current_job_name,
container_log_level=self.container_log_level,
framework_version=utils.get_short_version(self.framework_version),
sagemaker_session=self.sagemaker_session,
Expand All @@ -648,22 +645,13 @@ def _create_default_model(
**kwargs
):
"""Placeholder docstring"""
# remove image kwarg
if "image" in kwargs:
image = kwargs["image"]
kwargs = {k: v for k, v in kwargs.items() if k != "image"}
else:
image = None

return TensorFlowModel(
self.model_data,
role,
entry_point or self.entry_point,
source_dir=source_dir or self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
env={"SAGEMAKER_REQUIREMENTS": self.requirements_file},
image=(image or self.image_name),
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
Expand Down
14 changes: 5 additions & 9 deletions src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,16 @@

import logging

from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, _TrainingJob
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
get_unsupported_framework_version_error,
UploadedCode,
)


from sagemaker.session import Session

from sagemaker.estimator import _TrainingJob

from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

from sagemaker.xgboost import defaults
from sagemaker.xgboost.model import XGBoostModel

Expand Down Expand Up @@ -154,7 +148,10 @@ def create_model(
# Remove unwanted entry_point kwarg
if "entry_point" in kwargs:
logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"]))
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}
del kwargs["entry_point"]

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return XGBoostModel(
self.model_data,
Expand All @@ -163,7 +160,6 @@ def create_model(
framework_version=self.framework_version,
source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
Expand Down
11 changes: 7 additions & 4 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
from mock import Mock
from mock import patch

from sagemaker.sklearn import defaults
from sagemaker.sklearn import SKLearn
from sagemaker.sklearn import SKLearnPredictor, SKLearnModel
from sagemaker.sklearn import defaults, SKLearn, SKLearnModel, SKLearnPredictor
from sagemaker.fw_utils import UploadedCode

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
Expand Down Expand Up @@ -254,13 +252,18 @@ def test_create_model_with_optional_params(sagemaker_session):
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
model_name = "model-name"
model = sklearn.create_model(
role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
name=model_name,
)

assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.name == model_name


def test_create_model_with_custom_image(sagemaker_session):
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,17 +322,20 @@ def test_create_model_with_optional_params(sagemaker_session):
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
model_name = "model-name"
model = tf.create_model(
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
entry_point=SERVING_SCRIPT_FILE,
name=model_name,
)

assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.entry_point == SERVING_SCRIPT_FILE
assert model.name == model_name


@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@


from sagemaker.xgboost.defaults import XGBOOST_LATEST_VERSION
from sagemaker.xgboost import XGBoost
from sagemaker.xgboost import XGBoostPredictor, XGBoostModel
from sagemaker.xgboost import XGBoost, XGBoostModel, XGBoostPredictor


DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
Expand Down Expand Up @@ -242,13 +241,18 @@ def test_create_model_with_optional_params(sagemaker_session):
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
model_name = "model-name"
model = xgboost.create_model(
role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
name=model_name,
)

assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.name == model_name


def test_create_model_with_custom_image(sagemaker_session):
Expand Down