@@ -435,6 +435,21 @@ def test_model_create_transformer(sagemaker_session):
435
435
sagemaker .model .Model ._create_sagemaker_model .assert_called_with (instance_type , tags = tags )
436
436
437
437
438
+ @patch ("sagemaker.session.Session" )
439
+ @patch ("sagemaker.local.LocalSession" )
440
+ @patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
441
+ def test_transformer_creates_correct_session (local_session , session ):
442
+ model = DummyFrameworkModel (sagemaker_session = None )
443
+ transformer = model .transformer (instance_count = 1 , instance_type = "local" )
444
+ assert model .sagemaker_session == local_session .return_value
445
+ assert transformer .sagemaker_session == local_session .return_value
446
+
447
+ model = DummyFrameworkModel (sagemaker_session = None )
448
+ transformer = model .transformer (instance_count = 1 , instance_type = "ml.m5.xlarge" )
449
+ assert model .sagemaker_session == session .return_value
450
+ assert transformer .sagemaker_session == session .return_value
451
+
452
+
438
453
def test_model_package_enable_network_isolation_with_no_product_id (sagemaker_session ):
439
454
sagemaker_session .sagemaker_client .describe_model_package = Mock (
440
455
return_value = DESCRIBE_MODEL_PACKAGE_RESPONSE
@@ -561,6 +576,21 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
561
576
assert model ._is_compiled_model is True
562
577
563
578
579
+ @patch ("sagemaker.session.Session" )
580
+ @patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
581
+ def test_compile_creates_session (session ):
582
+ model = DummyFrameworkModel (sagemaker_session = None )
583
+ model .compile (
584
+ target_instance_family = "ml_c4" ,
585
+ input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
586
+ output_path = "s3://output" ,
587
+ role = "role" ,
588
+ framework = "tensorflow" ,
589
+ )
590
+
591
+ assert model .sagemaker_sesion == session .return_value
592
+
593
+
564
594
def test_check_neo_region (sagemaker_session , tmpdir ):
565
595
sagemaker_session .wait_for_compilation_job = Mock (
566
596
return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
0 commit comments