@@ -830,3 +830,48 @@ def test_tf_script_mode_mpi(time, strftime, sagemaker_session):
830
830
831
831
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
832
832
assert actual_train_args == expected_train_args
833
+
834
+
835
+ @patch ('sagemaker.utils.create_tar_file' , MagicMock ())
836
+ def test_tf_script_mode_attach (sagemaker_session , tf_version ):
837
+ training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py3-cpu:{}-cpu-py3' .format (tf_version )
838
+ rjd = {
839
+ 'AlgorithmSpecification' : {
840
+ 'TrainingInputMode' : 'File' ,
841
+ 'TrainingImage' : training_image
842
+ },
843
+ 'HyperParameters' : {
844
+ 'sagemaker_submit_directory' : '"s3://some/sourcedir.tar.gz"' ,
845
+ 'sagemaker_program' : '"iris-dnn-classifier.py"' ,
846
+ 'sagemaker_enable_cloudwatch_metrics' : 'false' ,
847
+ 'sagemaker_container_log_level' : '"logging.INFO"' ,
848
+ 'sagemaker_job_name' : '"neo"'
849
+ },
850
+ 'RoleArn' : 'arn:aws:iam::366:role/SageMakerRole' ,
851
+ 'ResourceConfig' : {
852
+ 'VolumeSizeInGB' : 30 ,
853
+ 'InstanceCount' : 1 ,
854
+ 'InstanceType' : 'ml.c4.xlarge'
855
+ },
856
+ 'StoppingCondition' : {'MaxRuntimeInSeconds' : 24 * 60 * 60 },
857
+ 'TrainingJobName' : 'neo' ,
858
+ 'TrainingJobStatus' : 'Completed' ,
859
+ 'OutputDataConfig' : {'KmsKeyId' : '' , 'S3OutputPath' : 's3://place/output/neo' },
860
+ 'TrainingJobOutput' : {'S3TrainingJobOutput' : 's3://here/output.tar.gz' }}
861
+ sagemaker_session .sagemaker_client .describe_training_job = Mock (name = 'describe_training_job' , return_value = rjd )
862
+
863
+ estimator = TensorFlow .attach (training_job_name = 'neo' , sagemaker_session = sagemaker_session )
864
+ assert estimator .latest_training_job .job_name == 'neo'
865
+ assert estimator .py_version == 'py3'
866
+ assert estimator .framework_version == tf_version
867
+ assert estimator .role == 'arn:aws:iam::366:role/SageMakerRole'
868
+ assert estimator .train_instance_count == 1
869
+ assert estimator .train_max_run == 24 * 60 * 60
870
+ assert estimator .input_mode == 'File'
871
+ assert estimator .input_mode == 'File'
872
+ assert estimator .base_job_name == 'neo'
873
+ assert estimator .output_path == 's3://place/output/neo'
874
+ assert estimator .output_kms_key == ''
875
+ assert estimator .hyperparameters () is not None
876
+ assert estimator .source_dir == 's3://some/sourcedir.tar.gz'
877
+ assert estimator .entry_point == 'iris-dnn-classifier.py'
0 commit comments