Skip to content

Commit 36b4627

Browse files
authored
infra: use Model class for model deployment unit tests (#1418)
This commit also adds a few more unit tests to round out coverage for the deploy() method.
1 parent 36987c4 commit 36b4627

File tree

2 files changed

+243
-140
lines changed

2 files changed

+243
-140
lines changed

tests/unit/sagemaker/model/test_framework_model.py

Lines changed: 0 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
MODEL_DATA = "s3://bucket/model.tar.gz"
2525
MODEL_IMAGE = "mi"
2626
ENTRY_POINT = "blah.py"
27-
INSTANCE_TYPE = "p2.xlarge"
2827
ROLE = "some-role"
2928

3029
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
@@ -172,145 +171,6 @@ def test_create_no_defaults(sagemaker_session, tmpdir):
172171
}
173172

174173

175-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
176-
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
177-
def test_deploy(sagemaker_session, tmpdir):
178-
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
179-
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)
180-
sagemaker_session.endpoint_from_production_variants.assert_called_with(
181-
name=MODEL_NAME,
182-
production_variants=[
183-
{
184-
"InitialVariantWeight": 1,
185-
"ModelName": MODEL_NAME,
186-
"InstanceType": INSTANCE_TYPE,
187-
"InitialInstanceCount": 1,
188-
"VariantName": "AllTraffic",
189-
}
190-
],
191-
tags=None,
192-
kms_key=None,
193-
wait=True,
194-
data_capture_config_dict=None,
195-
)
196-
197-
198-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
199-
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
200-
def test_deploy_endpoint_name(sagemaker_session, tmpdir):
201-
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
202-
model.deploy(endpoint_name="blah", instance_type=INSTANCE_TYPE, initial_instance_count=55)
203-
sagemaker_session.endpoint_from_production_variants.assert_called_with(
204-
name="blah",
205-
production_variants=[
206-
{
207-
"InitialVariantWeight": 1,
208-
"ModelName": MODEL_NAME,
209-
"InstanceType": INSTANCE_TYPE,
210-
"InitialInstanceCount": 55,
211-
"VariantName": "AllTraffic",
212-
}
213-
],
214-
tags=None,
215-
kms_key=None,
216-
wait=True,
217-
data_capture_config_dict=None,
218-
)
219-
220-
221-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
222-
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
223-
def test_deploy_tags(sagemaker_session, tmpdir):
224-
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
225-
tags = [{"ModelName": "TestModel"}]
226-
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags)
227-
sagemaker_session.endpoint_from_production_variants.assert_called_with(
228-
name=MODEL_NAME,
229-
production_variants=[
230-
{
231-
"InitialVariantWeight": 1,
232-
"ModelName": MODEL_NAME,
233-
"InstanceType": INSTANCE_TYPE,
234-
"InitialInstanceCount": 1,
235-
"VariantName": "AllTraffic",
236-
}
237-
],
238-
tags=tags,
239-
kms_key=None,
240-
wait=True,
241-
data_capture_config_dict=None,
242-
)
243-
244-
245-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
246-
@patch("tarfile.open")
247-
@patch("time.strftime", return_value=TIMESTAMP)
248-
def test_deploy_accelerator_type(tfo, time, sagemaker_session):
249-
model = DummyFrameworkModel(sagemaker_session)
250-
model.deploy(
251-
instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE
252-
)
253-
sagemaker_session.endpoint_from_production_variants.assert_called_with(
254-
name=MODEL_NAME,
255-
production_variants=[
256-
{
257-
"InitialVariantWeight": 1,
258-
"ModelName": MODEL_NAME,
259-
"InstanceType": INSTANCE_TYPE,
260-
"InitialInstanceCount": 1,
261-
"VariantName": "AllTraffic",
262-
"AcceleratorType": ACCELERATOR_TYPE,
263-
}
264-
],
265-
tags=None,
266-
kms_key=None,
267-
wait=True,
268-
data_capture_config_dict=None,
269-
)
270-
271-
272-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
273-
@patch("tarfile.open")
274-
@patch("time.strftime", return_value=TIMESTAMP)
275-
def test_deploy_kms_key(tfo, time, sagemaker_session):
276-
key = "some-key-arn"
277-
model = DummyFrameworkModel(sagemaker_session)
278-
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key)
279-
sagemaker_session.endpoint_from_production_variants.assert_called_with(
280-
name=MODEL_NAME,
281-
production_variants=[
282-
{
283-
"InitialVariantWeight": 1,
284-
"ModelName": MODEL_NAME,
285-
"InstanceType": INSTANCE_TYPE,
286-
"InitialInstanceCount": 1,
287-
"VariantName": "AllTraffic",
288-
}
289-
],
290-
tags=None,
291-
kms_key=key,
292-
wait=True,
293-
data_capture_config_dict=None,
294-
)
295-
296-
297-
@patch("sagemaker.session.Session")
298-
@patch("sagemaker.local.LocalSession")
299-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
300-
def test_deploy_creates_correct_session(local_session, session, tmpdir):
301-
# We expect a LocalSession when deploying to instance_type = 'local'
302-
model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir))
303-
model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1)
304-
assert model.sagemaker_session == local_session.return_value
305-
306-
# We expect a real Session when deploying to instance_type != local/local_gpu
307-
model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir))
308-
model.deploy(
309-
endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2
310-
)
311-
assert model.sagemaker_session == session.return_value
312-
313-
314174
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
315175
def test_deploy_update_endpoint(sagemaker_session, tmpdir):
316176
model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)

0 commit comments

Comments
 (0)