@@ -191,6 +191,16 @@ def sagemaker_session():
191
191
return sms
192
192
193
193
194
+ @pytest .fixture ()
195
+ def training_job_description (sagemaker_session ):
196
+ returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
197
+ mock_describe_training_job = Mock (
198
+ name = "describe_training_job" , return_value = returned_job_description
199
+ )
200
+ sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
201
+ return returned_job_description
202
+
203
+
194
204
def test_framework_all_init_args (sagemaker_session ):
195
205
f = DummyFramework (
196
206
"my_script.py" ,
@@ -651,13 +661,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session):
651
661
assert train_kwargs ["hyperparameters" ]["sagemaker_enable_cloudwatch_metrics" ]
652
662
653
663
654
- def test_attach_framework (sagemaker_session ):
655
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
656
- returned_job_description ["VpcConfig" ] = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
657
- returned_job_description ["EnableNetworkIsolation" ] = True
658
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
659
- name = "describe_training_job" , return_value = returned_job_description
660
- )
664
+ def test_attach_framework (sagemaker_session , training_job_description ):
665
+ training_job_description ["VpcConfig" ] = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
666
+ training_job_description ["EnableNetworkIsolation" ] = True
661
667
662
668
framework_estimator = DummyFramework .attach (
663
669
training_job_name = "neo" , sagemaker_session = sagemaker_session
@@ -681,29 +687,25 @@ def test_attach_framework(sagemaker_session):
681
687
assert framework_estimator .enable_network_isolation () is True
682
688
683
689
684
- def test_attach_without_hyperparameters (sagemaker_session ):
685
- returned_job_description = RETURNED_JOB_DESCRIPTION . copy ( )
686
- del returned_job_description [ "HyperParameters" ]
690
+ def test_attach_no_logs (sagemaker_session , training_job_description ):
691
+ Estimator . attach ( training_job_name = "job" , sagemaker_session = sagemaker_session )
692
+ sagemaker_session . logs_for_job . assert_not_called ()
687
693
688
- mock_describe_training_job = Mock (
689
- name = "describe_training_job" , return_value = returned_job_description
690
- )
691
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
692
694
695
+ def test_logs (sagemaker_session , training_job_description ):
693
696
estimator = Estimator .attach (training_job_name = "job" , sagemaker_session = sagemaker_session )
694
-
695
- assert estimator . hyperparameters () == {}
697
+ estimator . logs ()
698
+ sagemaker_session . logs_for_job . assert_called_with ( estimator . latest_training_job , wait = True )
696
699
697
700
698
- def test_attach_framework_with_tuning (sagemaker_session ):
699
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
700
- returned_job_description ["HyperParameters" ]["_tuning_objective_metric" ] = "Validation-accuracy"
701
+ def test_attach_without_hyperparameters (sagemaker_session , training_job_description ):
702
+ del training_job_description ["HyperParameters" ]
703
+ estimator = Estimator .attach (training_job_name = "job" , sagemaker_session = sagemaker_session )
704
+ assert estimator .hyperparameters () == {}
701
705
702
- mock_describe_training_job = Mock (
703
- name = "describe_training_job" , return_value = returned_job_description
704
- )
705
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
706
706
707
+ def test_attach_framework_with_tuning (sagemaker_session , training_job_description ):
708
+ training_job_description ["HyperParameters" ]["_tuning_objective_metric" ] = "Validation-accuracy"
707
709
framework_estimator = DummyFramework .attach (
708
710
training_job_name = "neo" , sagemaker_session = sagemaker_session
709
711
)
@@ -723,48 +725,35 @@ def test_attach_framework_with_tuning(sagemaker_session):
723
725
assert framework_estimator .encrypt_inter_container_traffic is False
724
726
725
727
726
- def test_attach_framework_with_model_channel (sagemaker_session ):
728
+ def test_attach_framework_with_model_channel (sagemaker_session , training_job_description ):
727
729
s3_uri = "s3://some/s3/path/model.tar.gz"
728
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
729
- returned_job_description ["InputDataConfig" ] = [
730
+ training_job_description ["InputDataConfig" ] = [
730
731
{
731
732
"ChannelName" : "model" ,
732
733
"InputMode" : "File" ,
733
734
"DataSource" : {"S3DataSource" : {"S3Uri" : s3_uri }},
734
735
}
735
736
]
736
737
737
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
738
- name = "describe_training_job" , return_value = returned_job_description
739
- )
740
-
741
738
framework_estimator = DummyFramework .attach (
742
739
training_job_name = "neo" , sagemaker_session = sagemaker_session
743
740
)
744
741
assert framework_estimator .model_uri is s3_uri
745
742
assert framework_estimator .encrypt_inter_container_traffic is False
746
743
747
744
748
- def test_attach_framework_with_inter_container_traffic_encryption_flag (sagemaker_session ):
749
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
750
- returned_job_description ["EnableInterContainerTrafficEncryption" ] = True
751
-
752
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
753
- name = "describe_training_job" , return_value = returned_job_description
754
- )
755
-
745
+ def test_attach_framework_with_inter_container_traffic_encryption_flag (
746
+ sagemaker_session , training_job_description
747
+ ):
748
+ training_job_description ["EnableInterContainerTrafficEncryption" ] = True
756
749
framework_estimator = DummyFramework .attach (
757
750
training_job_name = "neo" , sagemaker_session = sagemaker_session
758
751
)
759
752
760
753
assert framework_estimator .encrypt_inter_container_traffic is True
761
754
762
755
763
- def test_attach_framework_base_from_generated_name (sagemaker_session ):
764
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
765
- name = "describe_training_job" , return_value = RETURNED_JOB_DESCRIPTION
766
- )
767
-
756
+ def test_attach_framework_base_from_generated_name (sagemaker_session , training_job_description ):
768
757
base_job_name = "neo"
769
758
framework_estimator = DummyFramework .attach (
770
759
training_job_name = utils .name_from_base ("neo" ), sagemaker_session = sagemaker_session
0 commit comments