Skip to content

Local mode improvements #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 2, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,16 @@ def train(self, input_data_config, hyperparameters):
shutil.rmtree(data_dir)
# Also free the container config files.
for host in self.hosts:
shutil.rmtree(os.path.join(self.container_root, host))
container_config_path = os.path.join(self.container_root, host)
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you factor out this cleanup logic into a function? We don't know how long it'll stay around, and it might evolve over time, and we should keep it consistent between the various places.

shutil.rmtree(container_config_path)
except OSError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are probably other ways an OSError can be thrown here, right? I don't love catching a fairly generic exception type and swallowing it without showing a stacktrace.

I think we could go two directions here - be more specific about what we catch, or just catch everything and print out this error message with a stacktrace / enough information to debug. I'm guessing there's not a great way to do the first, so maybe we should do the second? Are there any cases where we would want to rethrow an exception because we failed cleanup?

logger.warning("Failed to delete: %s Please remove it manually." % container_config_path)

self._cleanup()
# Print our Job Complete line to have a simmilar experience to training on SageMaker where you
# see this line at the end.
print('===== Job Complete =====')
return s3_model_artifacts

def serve(self, primary_container):
Expand Down Expand Up @@ -162,7 +169,10 @@ def stop_serving(self):
self.container.down()
self._cleanup()
# for serving we can delete everything in the container root.
shutil.rmtree(self.container_root)
try:
shutil.rmtree(self.container_root)
except OSError:
logger.warning("Failed to delete: %s Please remove it manually." % self.container_root)

def retrieve_model_artifacts(self, compose_data):
"""Get the model artifacts from all the container nodes.
Expand All @@ -185,9 +195,9 @@ def retrieve_model_artifacts(self, compose_data):
volumes = compose_data['services'][str(host)]['volumes']

for volume in volumes:
container_dir, host_dir = volume.split(':')
if host_dir == '/opt/ml/model':
self._recursive_copy(container_dir, s3_model_artifacts)
host_dir, container_dir = volume.split(':')
if container_dir == '/opt/ml/model':
self._recursive_copy(host_dir, s3_model_artifacts)

return s3_model_artifacts

Expand Down Expand Up @@ -304,7 +314,7 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
return content

def _compose(self, detached=False):
compose_cmd = 'nvidia-docker-compose' if self.instance_type == "local_gpu" else 'docker-compose'
compose_cmd = 'docker-compose'

command = [
compose_cmd,
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,8 @@ def __init__(self, boto_session=None):
logger.warning("Windows Support for Local Mode is Experimental")
self.sagemaker_client = LocalSagemakerClient(self)
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)

def logs_for_job(self, job_name, wait=False, poll=5):
# override logs_for_job() as it doesn't need to perform any action
# on local mode.
pass