Skip to content

Commit 9eac2d7

Browse files
author
Chuyang Deng
committed
Update method comments and modify model_dir.
1 parent d7fa2b6 commit 9eac2d7

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/sagemaker_tensorflow_container/training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,10 @@ def main():
193193
hyperparameters = framework.env.read_hyperparameters()
194194
env = framework.training_env(hyperparameters=hyperparameters)
195195

196+
# If the training job is part of the multiple training jobs for tuning, we need to append the training job name to
197+
# model_dir in case they read from/write to the same object
196198
if '_tuning_objective_metric' in hyperparameters:
197-
env.hyperparameters['model_dir'] = env.job_name
199+
env.hyperparameters['model_dir'] = '{}-{}'.format(hyperparameters.get('model_dir'), env.job_name)
198200

199201
s3_utils.configure(env.hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
200202
logger.setLevel(env.log_level)

test/unit/test_training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
254254
train.assert_called_once_with(single_machine_training_env)
255255
configure_s3_env.assert_called_once()
256256

257+
257258
@patch('sagemaker_tensorflow_container.training.logger')
258259
@patch('sagemaker_tensorflow_container.training.train')
259260
@patch('logging.Logger.setLevel')
@@ -280,4 +281,5 @@ def test_main_tunning_model_dir(configure_s3_env, read_hyperparameters, training
280281
training_env.return_value = single_machine_training_env
281282
os.environ['SAGEMAKER_REGION'] = REGION
282283
training.main()
283-
configure_s3_env.assert_called_once_with(single_machine_training_env.job_name, REGION)
284+
expected_model_dir = '{}-{}'.format(MODEL_DIR, single_machine_training_env.job_name)
285+
configure_s3_env.assert_called_once_with(expected_model_dir, REGION)

0 commit comments

Comments
 (0)