@@ -190,6 +190,16 @@ def sagemaker_session():
190
190
return sms
191
191
192
192
193
+ @pytest .fixture ()
194
+ def training_job_description (sagemaker_session ):
195
+ returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
196
+ mock_describe_training_job = Mock (
197
+ name = "describe_training_job" , return_value = returned_job_description
198
+ )
199
+ sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
200
+ return returned_job_description
201
+
202
+
193
203
def test_framework_all_init_args (sagemaker_session ):
194
204
f = DummyFramework (
195
205
"my_script.py" ,
@@ -650,13 +660,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session):
650
660
assert train_kwargs ["hyperparameters" ]["sagemaker_enable_cloudwatch_metrics" ]
651
661
652
662
653
- def test_attach_framework (sagemaker_session ):
654
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
655
- returned_job_description ["VpcConfig" ] = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
656
- returned_job_description ["EnableNetworkIsolation" ] = True
657
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
658
- name = "describe_training_job" , return_value = returned_job_description
659
- )
663
+ def test_attach_framework (sagemaker_session , training_job_description ):
664
+ training_job_description ["VpcConfig" ] = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
665
+ training_job_description ["EnableNetworkIsolation" ] = True
660
666
661
667
framework_estimator = DummyFramework .attach (
662
668
training_job_name = "neo" , sagemaker_session = sagemaker_session
@@ -680,29 +686,25 @@ def test_attach_framework(sagemaker_session):
680
686
assert framework_estimator .enable_network_isolation () is True
681
687
682
688
683
- def test_attach_without_hyperparameters (sagemaker_session ):
684
- returned_job_description = RETURNED_JOB_DESCRIPTION . copy ( )
685
- del returned_job_description [ "HyperParameters" ]
689
+ def test_attach_no_logs (sagemaker_session , training_job_description ):
690
+ Estimator . attach ( training_job_name = "job" , sagemaker_session = sagemaker_session )
691
+ sagemaker_session . logs_for_job . assert_not_called ()
686
692
687
- mock_describe_training_job = Mock (
688
- name = "describe_training_job" , return_value = returned_job_description
689
- )
690
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
691
693
694
+ def test_logs (sagemaker_session , training_job_description ):
692
695
estimator = Estimator .attach (training_job_name = "job" , sagemaker_session = sagemaker_session )
693
-
694
- assert estimator . hyperparameters () == {}
696
+ estimator . logs ()
697
+ sagemaker_session . logs_for_job . assert_called_with ( estimator . latest_training_job , wait = True )
695
698
696
699
697
- def test_attach_framework_with_tuning (sagemaker_session ):
698
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
699
- returned_job_description ["HyperParameters" ]["_tuning_objective_metric" ] = "Validation-accuracy"
700
+ def test_attach_without_hyperparameters (sagemaker_session , training_job_description ):
701
+ del training_job_description ["HyperParameters" ]
702
+ estimator = Estimator .attach (training_job_name = "job" , sagemaker_session = sagemaker_session )
703
+ assert estimator .hyperparameters () == {}
700
704
701
- mock_describe_training_job = Mock (
702
- name = "describe_training_job" , return_value = returned_job_description
703
- )
704
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
705
705
706
+ def test_attach_framework_with_tuning (sagemaker_session , training_job_description ):
707
+ training_job_description ["HyperParameters" ]["_tuning_objective_metric" ] = "Validation-accuracy"
706
708
framework_estimator = DummyFramework .attach (
707
709
training_job_name = "neo" , sagemaker_session = sagemaker_session
708
710
)
@@ -722,48 +724,35 @@ def test_attach_framework_with_tuning(sagemaker_session):
722
724
assert framework_estimator .encrypt_inter_container_traffic is False
723
725
724
726
725
- def test_attach_framework_with_model_channel (sagemaker_session ):
727
+ def test_attach_framework_with_model_channel (sagemaker_session , training_job_description ):
726
728
s3_uri = "s3://some/s3/path/model.tar.gz"
727
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
728
- returned_job_description ["InputDataConfig" ] = [
729
+ training_job_description ["InputDataConfig" ] = [
729
730
{
730
731
"ChannelName" : "model" ,
731
732
"InputMode" : "File" ,
732
733
"DataSource" : {"S3DataSource" : {"S3Uri" : s3_uri }},
733
734
}
734
735
]
735
736
736
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
737
- name = "describe_training_job" , return_value = returned_job_description
738
- )
739
-
740
737
framework_estimator = DummyFramework .attach (
741
738
training_job_name = "neo" , sagemaker_session = sagemaker_session
742
739
)
743
740
assert framework_estimator .model_uri is s3_uri
744
741
assert framework_estimator .encrypt_inter_container_traffic is False
745
742
746
743
747
- def test_attach_framework_with_inter_container_traffic_encryption_flag (sagemaker_session ):
748
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
749
- returned_job_description ["EnableInterContainerTrafficEncryption" ] = True
750
-
751
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
752
- name = "describe_training_job" , return_value = returned_job_description
753
- )
754
-
744
+ def test_attach_framework_with_inter_container_traffic_encryption_flag (
745
+ sagemaker_session , training_job_description
746
+ ):
747
+ training_job_description ["EnableInterContainerTrafficEncryption" ] = True
755
748
framework_estimator = DummyFramework .attach (
756
749
training_job_name = "neo" , sagemaker_session = sagemaker_session
757
750
)
758
751
759
752
assert framework_estimator .encrypt_inter_container_traffic is True
760
753
761
754
762
- def test_attach_framework_base_from_generated_name (sagemaker_session ):
763
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
764
- name = "describe_training_job" , return_value = RETURNED_JOB_DESCRIPTION
765
- )
766
-
755
+ def test_attach_framework_base_from_generated_name (sagemaker_session , training_job_description ):
767
756
base_job_name = "neo"
768
757
framework_estimator = DummyFramework .attach (
769
758
training_job_name = utils .name_from_base ("neo" ), sagemaker_session = sagemaker_session
0 commit comments