Skip to content

Commit 781ffd1

Browse files
akorobkoyangaws
authored andcommitted
fix: hyperparameter query failure on script mode estimator attached to complete job (#718)
1 parent 475e051 commit 781ffd1

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,8 +475,10 @@ def _default_s3_path(self, directory, mpi=False):
475475
return '/opt/ml/shared/{}'.format(directory)
476476
elif mpi:
477477
return '/opt/ml/model'
478-
else:
478+
elif self._current_job_name:
479479
return os.path.join(self.output_path, self._current_job_name, directory)
480+
else:
481+
return None
480482

481483
def _script_mode_enabled(self):
482484
return self.py_version == 'py3' or self.script_mode

tests/unit/test_tf_estimator.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,3 +830,48 @@ def test_tf_script_mode_mpi(time, strftime, sagemaker_session):
830830

831831
actual_train_args = sagemaker_session.method_calls[0][2]
832832
assert actual_train_args == expected_train_args
833+
834+
835+
@patch('sagemaker.utils.create_tar_file', MagicMock())
836+
def test_tf_script_mode_attach(sagemaker_session, tf_version):
837+
training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py3-cpu:{}-cpu-py3'.format(tf_version)
838+
rjd = {
839+
'AlgorithmSpecification': {
840+
'TrainingInputMode': 'File',
841+
'TrainingImage': training_image
842+
},
843+
'HyperParameters': {
844+
'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
845+
'sagemaker_program': '"iris-dnn-classifier.py"',
846+
'sagemaker_enable_cloudwatch_metrics': 'false',
847+
'sagemaker_container_log_level': '"logging.INFO"',
848+
'sagemaker_job_name': '"neo"'
849+
},
850+
'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
851+
'ResourceConfig': {
852+
'VolumeSizeInGB': 30,
853+
'InstanceCount': 1,
854+
'InstanceType': 'ml.c4.xlarge'
855+
},
856+
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
857+
'TrainingJobName': 'neo',
858+
'TrainingJobStatus': 'Completed',
859+
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
860+
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
861+
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
862+
863+
estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
864+
assert estimator.latest_training_job.job_name == 'neo'
865+
assert estimator.py_version == 'py3'
866+
assert estimator.framework_version == tf_version
867+
assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
868+
assert estimator.train_instance_count == 1
869+
assert estimator.train_max_run == 24 * 60 * 60
870+
assert estimator.input_mode == 'File'
871+
assert estimator.input_mode == 'File'
872+
assert estimator.base_job_name == 'neo'
873+
assert estimator.output_path == 's3://place/output/neo'
874+
assert estimator.output_kms_key == ''
875+
assert estimator.hyperparameters() is not None
876+
assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
877+
assert estimator.entry_point == 'iris-dnn-classifier.py'

0 commit comments

Comments
 (0)