Skip to content

Commit 02252a2

Browse files
authored
infra: add Model unit tests for prepare_container_def and _create_sagemaker_model (aws#1421)
1 parent 36b4627 commit 02252a2

File tree

2 files changed

+107
-6
lines changed

2 files changed

+107
-6
lines changed

tests/unit/sagemaker/model/test_framework_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_prepare_container_def_with_network_isolation(time, sagemaker_session):
146146
@patch("os.path.isdir", MagicMock(return_value=True))
147147
@patch("os.listdir", MagicMock(return_value=["blah.py"]))
148148
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
149-
def test_create_no_defaults(sagemaker_session, tmpdir):
149+
def test_prepare_container_def_no_model_defaults(sagemaker_session, tmpdir):
150150
model = DummyFrameworkModel(
151151
sagemaker_session,
152152
source_dir="sd",

tests/unit/sagemaker/model/test_model.py

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
2727

2828
INSTANCE_COUNT = 2
29-
INSTANCE_TYPE = "c4.4xlarge"
29+
INSTANCE_TYPE = "ml.c4.4xlarge"
3030
ROLE = "some-role"
3131

3232
BASE_PRODUCTION_VARIANT = {
@@ -43,17 +43,119 @@ def sagemaker_session():
4343
return Mock()
4444

4545

46-
@patch("sagemaker.production_variant")
46+
def test_prepare_container_def():
47+
env = {"FOO": "BAR"}
48+
model = Model(MODEL_DATA, MODEL_IMAGE, env=env)
49+
50+
container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium")
51+
52+
expected = {"Image": MODEL_IMAGE, "Environment": env, "ModelDataUrl": MODEL_DATA}
53+
assert expected == container_def
54+
55+
4756
@patch("sagemaker.model.Model.prepare_container_def")
4857
@patch("sagemaker.utils.name_from_image")
49-
def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session):
58+
def test_create_sagemaker_model(name_from_image, prepare_container_def, sagemaker_session):
59+
name_from_image.return_value = MODEL_NAME
60+
61+
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
62+
prepare_container_def.return_value = container_def
63+
64+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
65+
model._create_sagemaker_model(INSTANCE_TYPE)
66+
67+
prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
68+
name_from_image.assert_called_with(MODEL_IMAGE)
69+
70+
sagemaker_session.create_model.assert_called_with(
71+
MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=None
72+
)
73+
74+
75+
@patch("sagemaker.utils.name_from_image", Mock())
76+
@patch("sagemaker.model.Model.prepare_container_def")
77+
def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemaker_session):
78+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
79+
80+
accelerator_type = "ml.eia.medium"
81+
model._create_sagemaker_model(INSTANCE_TYPE, accelerator_type=accelerator_type)
82+
83+
prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=accelerator_type)
84+
85+
86+
@patch("sagemaker.model.Model.prepare_container_def")
87+
@patch("sagemaker.utils.name_from_image")
88+
def test_create_sagemaker_model_tags(name_from_image, prepare_container_def, sagemaker_session):
89+
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
90+
prepare_container_def.return_value = container_def
91+
5092
name_from_image.return_value = MODEL_NAME
5193

94+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
95+
96+
tags = {"Key": "foo", "Value": "bar"}
97+
model._create_sagemaker_model(INSTANCE_TYPE, tags=tags)
98+
99+
sagemaker_session.create_model.assert_called_with(
100+
MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=tags
101+
)
102+
103+
104+
@patch("sagemaker.model.Model.prepare_container_def")
105+
@patch("sagemaker.utils.name_from_image")
106+
def test_create_sagemaker_model_optional_model_params(
107+
name_from_image, prepare_container_def, sagemaker_session
108+
):
52109
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
53110
prepare_container_def.return_value = container_def
54111

112+
vpc_config = {"Subnets": ["123"], "SecurityGroupIds": ["456", "789"]}
113+
114+
model = Model(
115+
MODEL_DATA,
116+
MODEL_IMAGE,
117+
name=MODEL_NAME,
118+
role=ROLE,
119+
vpc_config=vpc_config,
120+
enable_network_isolation=True,
121+
sagemaker_session=sagemaker_session,
122+
)
123+
model._create_sagemaker_model(INSTANCE_TYPE)
124+
125+
name_from_image.assert_not_called()
126+
127+
sagemaker_session.create_model.assert_called_with(
128+
MODEL_NAME,
129+
ROLE,
130+
container_def,
131+
vpc_config=vpc_config,
132+
enable_network_isolation=True,
133+
tags=None,
134+
)
135+
136+
137+
@patch("sagemaker.session.Session")
138+
@patch("sagemaker.local.LocalSession")
139+
def test_create_sagemaker_model_creates_correct_session(local_session, session):
140+
model = Model(MODEL_DATA, MODEL_IMAGE)
141+
model._create_sagemaker_model("local")
142+
assert model.sagemaker_session == local_session.return_value
143+
144+
model = Model(MODEL_DATA, MODEL_IMAGE)
145+
model._create_sagemaker_model("ml.m5.xlarge")
146+
assert model.sagemaker_session == session.return_value
147+
148+
149+
@patch("sagemaker.production_variant")
150+
@patch("sagemaker.model.Model.prepare_container_def")
151+
@patch("sagemaker.utils.name_from_image")
152+
def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session):
153+
name_from_image.return_value = MODEL_NAME
55154
production_variant.return_value = BASE_PRODUCTION_VARIANT
56155

156+
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
157+
prepare_container_def.return_value = container_def
158+
57159
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
58160
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
59161

@@ -223,7 +325,7 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):
223325

224326
@patch("sagemaker.session.Session")
225327
@patch("sagemaker.local.LocalSession")
226-
def test_deploy_creates_correct_session(local_session, session, tmpdir):
328+
def test_deploy_creates_correct_session(local_session, session):
227329
# We expect a LocalSession when deploying to instance_type = 'local'
228330
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
229331
model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1)
@@ -356,7 +458,6 @@ def test_model_create_transformer_network_isolation(create_sagemaker_model, sage
356458

357459
@patch("sagemaker.session.Session")
358460
@patch("sagemaker.local.LocalSession")
359-
@patch("sagemaker.fw_utils.tar_and_upload_dir", Mock())
360461
def test_transformer_creates_correct_session(local_session, session):
361462
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
362463
transformer = model.transformer(instance_count=1, instance_type="local")

0 commit comments

Comments
 (0)