Fix model_dir adjustment for hyperparameter tuning jobs #181
Conversation
@@ -125,7 +124,13 @@ def _wait_until_master_is_down(master):
         return


-def train(env):
+def _cmd_args(env, model_dir):
I am not sure how much value this function provides. I would just fold it into _run_worker.
It's used twice: once in _run_worker and once in _train.
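(For context, a rough sketch of how a helper like this could be shared by both call sites. The names _cmd_args, _run_worker and _train come from this thread, but the bodies below are assumptions, not the PR's actual code.)

```python
# Hypothetical sketch - not the PR's implementation. Assumes env.hyperparameters
# is a dict of the user's hyperparameters and that the entry point accepts
# arguments in --key value form.
def _cmd_args(env, model_dir):
    """Build the script arguments once, using the (possibly adjusted) model_dir."""
    args = ['--model_dir', model_dir]
    for name, value in sorted(env.hyperparameters.items()):
        if name != 'model_dir':
            args.extend(['--{}'.format(name), str(value)])
    return args


def _run_worker(env, model_dir):
    cmd_args = _cmd_args(env, model_dir)  # first call site
    # ... launch the user script on this worker with cmd_args ...


def _train(env, model_dir):
    cmd_args = _cmd_args(env, model_dir)  # second call site
    # ... run single-machine / master training with cmd_args ...
```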
        model_dir = _model_dir_with_training_job(hyperparameters.get('model_dir'), env.job_name)
        logger.info('Appending the training job name to model_dir: {}'.format(model_dir))
    else:
        model_dir = hyperparameters.get('model_dir')
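(A minimal sketch of what a helper like _model_dir_with_training_job might do, assuming it only appends the training job name with a forward slash as described in the notes below; the PR's real implementation may differ.)

```python
# Sketch only - the PR's actual helper may differ.
def _model_dir_with_training_job(model_dir, job_name):
    # Append the training job name so each job in a hyperparameter tuning run
    # gets its own S3 prefix. Join with '/' explicitly because this is an S3
    # path, not a local filesystem path.
    if not model_dir:
        return model_dir
    return '{}/{}'.format(model_dir.rstrip('/'), job_name)
```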
It feels a little awkward for the normal training job case: we get model_dir from hyperparameters and then later set hyperparameters['model_dir'] to the same value. I would just make model_dir optional in _run_worker. YMMV.
We would then have to call framework.env.read_hyperparameters() again (see my note in the description) - not sure if it's particularly expensive to do that?
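(On the cost question: the hyperparameters come from a small JSON file under SageMaker's standard input config directory, so re-reading them is roughly the following. This is a simplified approximation; the real framework.env.read_hyperparameters() also deserializes the individual values, which SageMaker passes as JSON-encoded strings.)

```python
import json

# Simplified approximation of what re-reading the hyperparameters involves:
# parse a small JSON file from the standard SageMaker input config location.
def _reread_hyperparameters():
    with open('/opt/ml/input/config/hyperparameters.json') as f:
        return json.load(f)
```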
How about we just pass the cmd_args to _run_worker instead of model_dir?
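(A sketch of that suggestion: build the argument list once, where model_dir is resolved, and hand the finished list down so _run_worker never needs to know about model_dir. Names and signatures are assumptions from this thread, not the merged code.)

```python
# Sketch of the suggestion - not necessarily the code that was ultimately merged.
def train(env):
    model_dir = ...  # resolved here, with the tuning-job adjustment if needed
    cmd_args = _cmd_args(env, model_dir)
    _run_worker(env, cmd_args)


def _run_worker(env, cmd_args):
    # The worker just launches the entry point with the pre-built arguments
    # and no longer takes model_dir at all.
    ...
```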
ship
Description of changes:
This is an improvement upon #179. This PR also relies on a corresponding change to SageMaker Containers (aws/sagemaker-containers#186), but this PR can be merged independently (i.e. the fix won't take effect, but the resulting image will still work as it does today).
A couple of caveats:
And a few random notes about the code:

- I didn't use os.path.join for S3 paths because S3 uses forward slashes regardless of the OS running the code that generates the path (see the sketch below).
- TrainingEnv.hyperparameters is read-only (code), so I couldn't just overwrite env.hyperparameters['model_dir'] (or, consequently, use env.to_cmd_args()).
- The raw hyperparameters are read with framework.env.read_hyperparameters() because, by the time the hyperparameters are parsed for env.hyperparameters, there is no longer a hyperparameter to indicate that the training job belongs to a hyperparameter tuning job.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
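(To illustrate the os.path.join note above, a small self-contained example; the helper name _join_s3 is made up for this illustration only.)

```python
# os.path.join uses the local OS separator (backslash on Windows), which is
# wrong for S3 keys; joining with '/' explicitly keeps the S3 path correct
# regardless of the OS generating it.
def _join_s3(*parts):
    return '/'.join(p.strip('/') for p in parts)

# _join_s3('s3://bucket/model', 'job-123') -> 's3://bucket/model/job-123'
```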