Skip to content

Commit cb20ae4

Browse files
committed
add unit tests
1 parent 3161238 commit cb20ae4

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/sagemaker/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type):
120120
return
121121

122122
if instance_type in ("local", "local_gpu"):
123-
return local.LocalSession()
123+
self.sagemaker_session = local.LocalSession()
124124
else:
125-
return session.Session()
125+
self.sagemaker_session = session.Session()
126126

127127
def prepare_container_def(
128128
self, instance_type, accelerator_type=None
@@ -221,7 +221,7 @@ def _compilation_job_config(
221221
else json.dumps(input_shape),
222222
"Framework": framework,
223223
}
224-
self._init_sagemaker_session_if_does_not_exist(target_instance_type)
224+
225225
role = self.sagemaker_session.expand_role(role)
226226
output_model_config = {
227227
"TargetDevice": target_instance_type,
@@ -336,6 +336,7 @@ def compile(
336336
framework = framework.upper()
337337
framework_version = self._get_framework_version() or framework_version
338338

339+
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
339340
config = self._compilation_job_config(
340341
target_instance_family,
341342
input_shape,
@@ -346,7 +347,6 @@ def compile(
346347
framework,
347348
tags,
348349
)
349-
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
350350
self.sagemaker_session.compile_model(**config)
351351
job_status = self.sagemaker_session.wait_for_compilation_job(job_name)
352352
self.model_data = job_status["ModelArtifacts"]["S3ModelArtifacts"]

tests/unit/test_model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,21 @@ def test_model_create_transformer(sagemaker_session):
435435
sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags)
436436

437437

438+
@patch("sagemaker.session.Session")
439+
@patch("sagemaker.local.LocalSession")
440+
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
441+
def test_transformer_creates_correct_session(local_session, session):
442+
model = DummyFrameworkModel(sagemaker_session=None)
443+
transformer = model.transformer(instance_count=1, instance_type="local")
444+
assert model.sagemaker_session == local_session.return_value
445+
assert transformer.sagemaker_session == local_session.return_value
446+
447+
model = DummyFrameworkModel(sagemaker_session=None)
448+
transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge")
449+
assert model.sagemaker_session == session.return_value
450+
assert transformer.sagemaker_session == session.return_value
451+
452+
438453
def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session):
439454
sagemaker_session.sagemaker_client.describe_model_package = Mock(
440455
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE
@@ -561,6 +576,21 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
561576
assert model._is_compiled_model is True
562577

563578

579+
@patch("sagemaker.session.Session")
580+
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
581+
def test_compile_creates_session(session):
582+
model = DummyFrameworkModel(sagemaker_session=None)
583+
model.compile(
584+
target_instance_family="ml_c4",
585+
input_shape={"data": [1, 3, 1024, 1024]},
586+
output_path="s3://output",
587+
role="role",
588+
framework="tensorflow",
589+
)
590+
591+
assert model.sagemaker_sesion == session.return_value
592+
593+
564594
def test_check_neo_region(sagemaker_session, tmpdir):
565595
sagemaker_session.wait_for_compilation_job = Mock(
566596
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE

0 commit comments

Comments
 (0)