Skip to content

Commit 6dc9982

Browse files
laurenyusaurabh3949matthewfollegot
authored
fix: create Session or LocalSession if not specified in Model (#1288)
Co-authored-by: Saurabh Gupta <[email protected]> Co-authored-by: Matthew Follegot <[email protected]>
1 parent 2cce5b8 commit 6dc9982

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

src/sagemaker/model.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,19 @@ def __init__(
115115
self._enable_network_isolation = enable_network_isolation
116116
self.model_kms_key = model_kms_key
117117

118+
def _init_sagemaker_session_if_does_not_exist(self, instance_type):
119+
"""Set ``self.sagemaker_session`` to be a ``LocalSession`` or
120+
``Session`` if it is not already. The type of session object is
121+
determined by the instance type.
122+
"""
123+
if self.sagemaker_session:
124+
return
125+
126+
if instance_type in ("local", "local_gpu"):
127+
self.sagemaker_session = local.LocalSession()
128+
else:
129+
self.sagemaker_session = session.Session()
130+
118131
def prepare_container_def(
119132
self, instance_type, accelerator_type=None
120133
): # pylint: disable=unused-argument
@@ -164,6 +177,8 @@ def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=Non
164177
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
165178
self.name = self.name or utils.name_from_image(container_def["Image"])
166179
enable_network_isolation = self.enable_network_isolation()
180+
181+
self._init_sagemaker_session_if_does_not_exist(instance_type)
167182
self.sagemaker_session.create_model(
168183
self.name,
169184
self.role,
@@ -324,6 +339,7 @@ def compile(
324339
framework = framework.upper()
325340
framework_version = self._get_framework_version() or framework_version
326341

342+
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
327343
config = self._compilation_job_config(
328344
target_instance_family,
329345
input_shape,
@@ -413,11 +429,7 @@ def deploy(
413429
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
414430
is not None. Otherwise, return None.
415431
"""
416-
if not self.sagemaker_session:
417-
if instance_type in ("local", "local_gpu"):
418-
self.sagemaker_session = local.LocalSession()
419-
else:
420-
self.sagemaker_session = session.Session()
432+
self._init_sagemaker_session_if_does_not_exist(instance_type)
421433

422434
if self.role is None:
423435
raise ValueError("Role can not be null for deploying a model")
@@ -514,6 +526,8 @@ def transformer(
514526
volume_kms_key (str): Optional. KMS key ID for encrypting the volume
515527
attached to the ML compute instance (default: None).
516528
"""
529+
self._init_sagemaker_session_if_does_not_exist(instance_type)
530+
517531
self._create_sagemaker_model(instance_type, tags=tags)
518532
if self.enable_network_isolation():
519533
env = None

tests/unit/test_model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,21 @@ def test_model_create_transformer(sagemaker_session):
435435
sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags)
436436

437437

438+
@patch("sagemaker.session.Session")
439+
@patch("sagemaker.local.LocalSession")
440+
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
441+
def test_transformer_creates_correct_session(local_session, session):
442+
model = DummyFrameworkModel(sagemaker_session=None)
443+
transformer = model.transformer(instance_count=1, instance_type="local")
444+
assert model.sagemaker_session == local_session.return_value
445+
assert transformer.sagemaker_session == local_session.return_value
446+
447+
model = DummyFrameworkModel(sagemaker_session=None)
448+
transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge")
449+
assert model.sagemaker_session == session.return_value
450+
assert transformer.sagemaker_session == session.return_value
451+
452+
438453
def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session):
439454
sagemaker_session.sagemaker_client.describe_model_package = Mock(
440455
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE
@@ -561,6 +576,24 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
561576
assert model._is_compiled_model is True
562577

563578

579+
@patch("sagemaker.session.Session")
580+
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
581+
def test_compile_creates_session(session):
582+
session.return_value.boto_region_name = "us-west-2"
583+
584+
model = DummyFrameworkModel(sagemaker_session=None)
585+
model.compile(
586+
target_instance_family="ml_c4",
587+
input_shape={"data": [1, 3, 1024, 1024]},
588+
output_path="s3://output",
589+
role="role",
590+
framework="tensorflow",
591+
job_name="compile-model",
592+
)
593+
594+
assert model.sagemaker_session == session.return_value
595+
596+
564597
def test_check_neo_region(sagemaker_session, tmpdir):
565598
sagemaker_session.wait_for_compilation_job = Mock(
566599
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE

0 commit comments

Comments
 (0)