@@ -115,6 +115,19 @@ def __init__(
115
115
self ._enable_network_isolation = enable_network_isolation
116
116
self .model_kms_key = model_kms_key
117
117
118
+ def _init_sagemaker_session_if_does_not_exist (self , instance_type ):
119
+ """Set ``self.sagemaker_session`` to be a ``LocalSession`` or
120
+ ``Session`` if it is not already. The type of session object is
121
+ determined by the instance type.
122
+ """
123
+ if self .sagemaker_session :
124
+ return
125
+
126
+ if instance_type in ("local" , "local_gpu" ):
127
+ self .sagemaker_session = local .LocalSession ()
128
+ else :
129
+ self .sagemaker_session = session .Session ()
130
+
118
131
def prepare_container_def (
119
132
self , instance_type , accelerator_type = None
120
133
): # pylint: disable=unused-argument
@@ -164,6 +177,8 @@ def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=Non
164
177
container_def = self .prepare_container_def (instance_type , accelerator_type = accelerator_type )
165
178
self .name = self .name or utils .name_from_image (container_def ["Image" ])
166
179
enable_network_isolation = self .enable_network_isolation ()
180
+
181
+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
167
182
self .sagemaker_session .create_model (
168
183
self .name ,
169
184
self .role ,
@@ -324,6 +339,7 @@ def compile(
324
339
framework = framework .upper ()
325
340
framework_version = self ._get_framework_version () or framework_version
326
341
342
+ self ._init_sagemaker_session_if_does_not_exist (target_instance_family )
327
343
config = self ._compilation_job_config (
328
344
target_instance_family ,
329
345
input_shape ,
@@ -413,11 +429,7 @@ def deploy(
413
429
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
414
430
is not None. Otherwise, return None.
415
431
"""
416
- if not self .sagemaker_session :
417
- if instance_type in ("local" , "local_gpu" ):
418
- self .sagemaker_session = local .LocalSession ()
419
- else :
420
- self .sagemaker_session = session .Session ()
432
+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
421
433
422
434
if self .role is None :
423
435
raise ValueError ("Role can not be null for deploying a model" )
@@ -514,6 +526,8 @@ def transformer(
514
526
volume_kms_key (str): Optional. KMS key ID for encrypting the volume
515
527
attached to the ML compute instance (default: None).
516
528
"""
529
+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
530
+
517
531
self ._create_sagemaker_model (instance_type , tags = tags )
518
532
if self .enable_network_isolation ():
519
533
env = None
0 commit comments