Skip to content

Commit 7ff630f

Browse files
authored
Local mode improvements (#117)
* Change local mode GPU docker engine Switch local mode to use nvidia-docker2 + docker-compose instead of nvidia-docker + nvidia-docker-compose. Also get rid of the log output for billable seconds in local mode as it doesn't make sense. Finally, let the users know if we can't cleanup a container directory but not fail the training job.
1 parent 6184b22 commit 7ff630f

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

src/sagemaker/local/image.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,16 @@ def train(self, input_data_config, hyperparameters):
117117
# free up the training data directory as it may contain
118118
# lots of data downloaded from S3. This doesn't delete any local
119119
# data that was just mounted to the container.
120-
shutil.rmtree(data_dir)
120+
_delete_tree(data_dir)
121121
# Also free the container config files.
122122
for host in self.hosts:
123-
shutil.rmtree(os.path.join(self.container_root, host))
123+
container_config_path = os.path.join(self.container_root, host)
124+
_delete_tree(container_config_path)
124125

125126
self._cleanup()
127+
# Print our Job Complete line to have a simmilar experience to training on SageMaker where you
128+
# see this line at the end.
129+
print('===== Job Complete =====')
126130
return s3_model_artifacts
127131

128132
def serve(self, primary_container):
@@ -162,7 +166,7 @@ def stop_serving(self):
162166
self.container.down()
163167
self._cleanup()
164168
# for serving we can delete everything in the container root.
165-
shutil.rmtree(self.container_root)
169+
_delete_tree(self.container_root)
166170

167171
def retrieve_model_artifacts(self, compose_data):
168172
"""Get the model artifacts from all the container nodes.
@@ -185,9 +189,9 @@ def retrieve_model_artifacts(self, compose_data):
185189
volumes = compose_data['services'][str(host)]['volumes']
186190

187191
for volume in volumes:
188-
container_dir, host_dir = volume.split(':')
189-
if host_dir == '/opt/ml/model':
190-
self._recursive_copy(container_dir, s3_model_artifacts)
192+
host_dir, container_dir = volume.split(':')
193+
if container_dir == '/opt/ml/model':
194+
self._recursive_copy(host_dir, s3_model_artifacts)
191195

192196
return s3_model_artifacts
193197

@@ -304,7 +308,7 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
304308
return content
305309

306310
def _compose(self, detached=False):
307-
compose_cmd = 'nvidia-docker-compose' if self.instance_type == "local_gpu" else 'docker-compose'
311+
compose_cmd = 'docker-compose'
308312

309313
command = [
310314
compose_cmd,
@@ -480,6 +484,21 @@ def _create_config_file_directories(root, host):
480484
os.makedirs(os.path.join(root, host, d))
481485

482486

487+
def _delete_tree(path):
488+
try:
489+
shutil.rmtree(path)
490+
except OSError as exc:
491+
# on Linux, when docker writes to any mounted volume, it uses the container's user. In most cases
492+
# this is root. When the container exits and we try to delete them we can't because root owns those
493+
# files. We expect this to happen, so we handle EACCESS. Any other error we will raise the
494+
# exception up.
495+
if exc.errno == errno.EACCES:
496+
logger.warning("Failed to delete: %s Please remove it manually." % path)
497+
else:
498+
logger.error("Failed to delete: %s" % path)
499+
raise
500+
501+
483502
def _aws_credentials(session):
484503
try:
485504
creds = session.get_credentials()

src/sagemaker/local/local_session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,8 @@ def __init__(self, boto_session=None):
166166
logger.warning("Windows Support for Local Mode is Experimental")
167167
self.sagemaker_client = LocalSagemakerClient(self)
168168
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
169+
170+
def logs_for_job(self, job_name, wait=False, poll=5):
171+
# override logs_for_job() as it doesn't need to perform any action
172+
# on local mode.
173+
pass

0 commit comments

Comments
 (0)