Skip to content

infra: add Model unit tests for prepare_container_def and _create_sagemaker_model #1421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/model/test_framework_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_prepare_container_def_with_network_isolation(time, sagemaker_session):
@patch("os.path.isdir", MagicMock(return_value=True))
@patch("os.listdir", MagicMock(return_value=["blah.py"]))
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
def test_create_no_defaults(sagemaker_session, tmpdir):
def test_prepare_container_def_no_model_defaults(sagemaker_session, tmpdir):
model = DummyFrameworkModel(
sagemaker_session,
source_dir="sd",
Expand Down
111 changes: 106 additions & 5 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)

INSTANCE_COUNT = 2
INSTANCE_TYPE = "c4.4xlarge"
INSTANCE_TYPE = "ml.c4.4xlarge"
ROLE = "some-role"

BASE_PRODUCTION_VARIANT = {
Expand All @@ -43,17 +43,119 @@ def sagemaker_session():
return Mock()


@patch("sagemaker.production_variant")
def test_prepare_container_def():
env = {"FOO": "BAR"}
model = Model(MODEL_DATA, MODEL_IMAGE, env=env)

container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium")

expected = {"Image": MODEL_IMAGE, "Environment": env, "ModelDataUrl": MODEL_DATA}
assert expected == container_def


@patch("sagemaker.model.Model.prepare_container_def")
@patch("sagemaker.utils.name_from_image")
def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session):
def test_create_sagemaker_model(name_from_image, prepare_container_def, sagemaker_session):
name_from_image.return_value = MODEL_NAME

container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
prepare_container_def.return_value = container_def

model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
model._create_sagemaker_model(INSTANCE_TYPE)

prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
name_from_image.assert_called_with(MODEL_IMAGE)

sagemaker_session.create_model.assert_called_with(
MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=None
)


@patch("sagemaker.utils.name_from_image", Mock())
@patch("sagemaker.model.Model.prepare_container_def")
def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemaker_session):
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)

accelerator_type = "ml.eia.medium"
model._create_sagemaker_model(INSTANCE_TYPE, accelerator_type=accelerator_type)

prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=accelerator_type)


@patch("sagemaker.model.Model.prepare_container_def")
@patch("sagemaker.utils.name_from_image")
def test_create_sagemaker_model_tags(name_from_image, prepare_container_def, sagemaker_session):
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
prepare_container_def.return_value = container_def

name_from_image.return_value = MODEL_NAME

model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)

tags = {"Key": "foo", "Value": "bar"}
model._create_sagemaker_model(INSTANCE_TYPE, tags=tags)

sagemaker_session.create_model.assert_called_with(
MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=tags
)


@patch("sagemaker.model.Model.prepare_container_def")
@patch("sagemaker.utils.name_from_image")
def test_create_sagemaker_model_optional_model_params(
name_from_image, prepare_container_def, sagemaker_session
):
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
prepare_container_def.return_value = container_def

vpc_config = {"Subnets": ["123"], "SecurityGroupIds": ["456", "789"]}

model = Model(
MODEL_DATA,
MODEL_IMAGE,
name=MODEL_NAME,
role=ROLE,
vpc_config=vpc_config,
enable_network_isolation=True,
sagemaker_session=sagemaker_session,
)
model._create_sagemaker_model(INSTANCE_TYPE)

name_from_image.assert_not_called()

sagemaker_session.create_model.assert_called_with(
MODEL_NAME,
ROLE,
container_def,
vpc_config=vpc_config,
enable_network_isolation=True,
tags=None,
)


@patch("sagemaker.session.Session")
@patch("sagemaker.local.LocalSession")
def test_create_sagemaker_model_creates_correct_session(local_session, session):
model = Model(MODEL_DATA, MODEL_IMAGE)
model._create_sagemaker_model("local")
assert model.sagemaker_session == local_session.return_value

model = Model(MODEL_DATA, MODEL_IMAGE)
model._create_sagemaker_model("ml.m5.xlarge")
assert model.sagemaker_session == session.return_value


@patch("sagemaker.production_variant")
@patch("sagemaker.model.Model.prepare_container_def")
@patch("sagemaker.utils.name_from_image")
def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session):
name_from_image.return_value = MODEL_NAME
production_variant.return_value = BASE_PRODUCTION_VARIANT

container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
prepare_container_def.return_value = container_def

model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

Expand Down Expand Up @@ -223,7 +325,7 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):

@patch("sagemaker.session.Session")
@patch("sagemaker.local.LocalSession")
def test_deploy_creates_correct_session(local_session, session, tmpdir):
def test_deploy_creates_correct_session(local_session, session):
# We expect a LocalSession when deploying to instance_type = 'local'
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1)
Expand Down Expand Up @@ -356,7 +458,6 @@ def test_model_create_transformer_network_isolation(create_sagemaker_model, sage

@patch("sagemaker.session.Session")
@patch("sagemaker.local.LocalSession")
@patch("sagemaker.fw_utils.tar_and_upload_dir", Mock())
def test_transformer_creates_correct_session(local_session, session):
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
transformer = model.transformer(instance_count=1, instance_type="local")
Expand Down