Skip to content

change: allow specifying model name in create_model() for Chainer, MXNet, PyTorch, and RL #1396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,15 @@ def create_model(
if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return ChainerModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,15 @@ def create_model(
if "image" not in kwargs:
kwargs["image"] = image_name or self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return MXNetModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,15 @@ def create_model(
if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return PyTorchModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def create_model(
base_args = dict(
model_data=self.model_data,
role=role or self.role,
image=kwargs["image"] if "image" in kwargs else self.image_name,
name=self._current_job_name,
image=kwargs.get("image", self.image_name),
name=kwargs.get("name", self._current_job_name),
container_log_level=self.container_log_level,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,19 +322,22 @@ def test_create_model_with_optional_params(sagemaker_session):
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
model_name = "model-name"
model = chainer.create_model(
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
entry_point=SERVING_SCRIPT_FILE,
env=ENV,
name=model_name,
)

assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.entry_point == SERVING_SCRIPT_FILE
assert model.env == ENV
assert model.name == model_name


def test_create_model_with_custom_image(sagemaker_session):
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,19 +227,22 @@ def test_create_model_with_optional_params(sagemaker_session):
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
model_name = "model-name"
model = mx.create_model(
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
entry_point=SERVING_SCRIPT_FILE,
env=ENV,
name=model_name,
)

assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.entry_point == SERVING_SCRIPT_FILE
assert model.env == ENV
assert model.name == model_name


def test_create_model_with_custom_image(sagemaker_session):
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,22 @@ def test_create_model_with_optional_params(sagemaker_session):
new_role = "role"
model_server_workers = 2
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
model_name = "model-name"
model = pytorch.create_model(
role=new_role,
model_server_workers=model_server_workers,
vpc_config_override=vpc_config,
entry_point=SERVING_SCRIPT_FILE,
env=ENV,
name=model_name,
)

assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config
assert model.entry_point == SERVING_SCRIPT_FILE
assert model.env == ENV
assert model.name == model_name


def test_create_model_with_custom_image(sagemaker_session):
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,15 @@ def test_create_model_with_optional_params(sagemaker_session, rl_coach_mxnet_ver
new_role = "role"
new_entry_point = "deploy_script.py"
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
model_name = "model-name"
model = rl.create_model(
role=new_role, entry_point=new_entry_point, vpc_config_override=vpc_config
role=new_role, entry_point=new_entry_point, vpc_config_override=vpc_config, name=model_name
)

assert model.role == new_role
assert model.vpc_config == vpc_config
assert model.entry_point == new_entry_point
assert model.name == model_name


def test_create_model_with_custom_image(sagemaker_session):
Expand Down
3 changes: 1 addition & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ envlist = black-format,flake8,pylint,twine,sphinx,py27,py36

skip_missing_interpreters = False


[flake8]
max-line-length = 120
exclude =
Expand Down Expand Up @@ -59,7 +58,7 @@ passenv =
# Can be used to specify which tests to run, e.g.: tox -- -s
commands =
coverage run --source sagemaker -m pytest {posargs}
{env:IGNORE_COVERAGE:} coverage report --fail-under=84 --omit */tensorflow/tensorflow_serving/*
{env:IGNORE_COVERAGE:} coverage report --fail-under=85 --omit */tensorflow/tensorflow_serving/*
extras = test

[testenv:flake8]
Expand Down