
Save training output files in local mode #177


Merged (6 commits) on May 16, 2018
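
For orientation, this is the local mode workflow the change affects. A minimal usage sketch, using the SDK's MXNet estimator as one example; the entry point, role, and data path are hypothetical placeholders:

    from sagemaker.mxnet import MXNet

    # Hypothetical values throughout. train_instance_type='local' runs the
    # training container on this machine via docker-compose instead of on
    # managed SageMaker instances.
    estimator = MXNet(entry_point='train.py',
                      role='SageMakerRole',
                      train_instance_count=1,
                      train_instance_type='local')

    # Local input via file:// (see the CHANGELOG entry below).
    estimator.fit('file:///tmp/my-training-data')

Before this PR, local mode gathered only the /opt/ml/model contents after training; with it, files the algorithm writes to /opt/ml/output are saved as well.
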
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -7,6 +7,7 @@ CHANGELOG
 ========
 
 * bug-fix: Change module names to string type in __all__
+* feature: Save training output files in local mode
 * bug-fix: tensorflow-serving-api: SageMaker does not conflict with tensorflow-serving-api module version
 * feature: Local Mode: add support for local training data using file://
 * feature: Updated TensorFlow Serving api protobuf files

15 changes: 11 additions & 4 deletions src/sagemaker/local/image.py
@@ -141,7 +141,7 @@ def train(self, input_data_config, hyperparameters):
         _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
         _execute_and_stream_output(compose_command)
 
-        s3_model_artifacts = self.retrieve_model_artifacts(compose_data)
+        s3_artifacts = self.retrieve_artifacts(compose_data)
 
         # free up the training data directory as it may contain
         # lots of data downloaded from S3. This doesn't delete any local
@@ -157,7 +157,7 @@ def train(self, input_data_config, hyperparameters):
         # Print our Job Complete line to have a simmilar experience to training on SageMaker where you
         # see this line at the end.
         print('===== Job Complete =====')
-        return s3_model_artifacts
+        return s3_artifacts
 
     def serve(self, primary_container):
         """Host a local endpoint using docker-compose.
@@ -209,7 +209,7 @@ def stop_serving(self):
         # for serving we can delete everything in the container root.
         _delete_tree(self.container_root)
 
-    def retrieve_model_artifacts(self, compose_data):
+    def retrieve_artifacts(self, compose_data):
        """Get the model artifacts from all the container nodes.
 
         Used after training completes to gather the data from all the individual containers. As the
@@ -223,8 +223,13 @@
 
         """
         # Grab the model artifacts from all the Nodes.
-        s3_model_artifacts = os.path.join(self.container_root, 's3_model_artifacts')
+        s3_artifacts = os.path.join(self.container_root, 's3_artifacts')
+        os.mkdir(s3_artifacts)
+
+        s3_model_artifacts = os.path.join(s3_artifacts, 'model')
+        s3_output_artifacts = os.path.join(s3_artifacts, 'output')
         os.mkdir(s3_model_artifacts)
+        os.mkdir(s3_output_artifacts)
 
         for host in self.hosts:
             volumes = compose_data['services'][str(host)]['volumes']
@@ -233,6 +238,8 @@
                 host_dir, container_dir = volume.split(':')
                 if container_dir == '/opt/ml/model':
                     self._recursive_copy(host_dir, s3_model_artifacts)
+                elif container_dir == '/opt/ml/output':
+                    self._recursive_copy(host_dir, s3_output_artifacts)
 
         return s3_model_artifacts
 
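The new code path calls self._recursive_copy(host_dir, ...) for each mounted directory; that helper predates this PR and is not shown in the diff. A rough sketch of the behavior it needs to have, inferred only from its usage above (this is an assumption, not the repository's implementation):

    import os
    import shutil

    def _recursive_copy(src, dst):
        # Assumed behavior: mirror every subdirectory and file of src into
        # dst, which already exists. Inferred from how retrieve_artifacts()
        # calls it; the real helper lives elsewhere in image.py.
        for root, _, files in os.walk(src):
            rel = os.path.relpath(root, src)
            target = dst if rel == '.' else os.path.join(dst, rel)
            if not os.path.isdir(target):
                os.makedirs(target)
            for name in files:
                shutil.copy(os.path.join(root, name), os.path.join(target, name))

Note that retrieve_artifacts() still returns s3_model_artifacts, so train() and its callers keep receiving the model directory; the gathered /opt/ml/output files sit beside it in the sibling 'output' directory under s3_artifacts.
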
36 changes: 27 additions & 9 deletions tests/unit/test_image.py
@@ -117,46 +117,64 @@ def test_retrieve_artifacts(LocalSession, tmpdir):
     sagemaker_container.hosts = ['algo-1', 'algo-2']  # avoid any randomness
     sagemaker_container.container_root = str(tmpdir.mkdir('container-root'))
 
-    volume1 = os.path.join(sagemaker_container.container_root, 'algo-1/output/')
-    volume2 = os.path.join(sagemaker_container.container_root, 'algo-2/output/')
-    os.makedirs(volume1)
-    os.makedirs(volume2)
+    volume1 = os.path.join(sagemaker_container.container_root, 'algo-1')
+    volume2 = os.path.join(sagemaker_container.container_root, 'algo-2')
+    os.mkdir(volume1)
+    os.mkdir(volume2)
 
     compose_data = {
         'services': {
             'algo-1': {
-                'volumes': ['%s:/opt/ml/model' % volume1]
+                'volumes': ['%s:/opt/ml/model' % os.path.join(volume1, 'model'),
+                            '%s:/opt/ml/output' % os.path.join(volume1, 'output')]
             },
             'algo-2': {
-                'volumes': ['%s:/opt/ml/model' % volume2]
+                'volumes': ['%s:/opt/ml/model' % os.path.join(volume2, 'model'),
+                            '%s:/opt/ml/output' % os.path.join(volume2, 'output')]
             }
         }
     }
 
     dirs1 = ['model', 'model/data']
     dirs2 = ['model', 'model/data', 'model/tmp']
+    dirs3 = ['output', 'output/data']
+    dirs4 = ['output', 'output/data', 'output/log']
 
     files1 = ['model/data/model.json', 'model/data/variables.csv']
     files2 = ['model/data/model.json', 'model/data/variables2.csv', 'model/tmp/something-else.json']
+    files3 = ['output/data/loss.json', 'output/data/accuracy.json']
+    files4 = ['output/data/loss.json', 'output/data/accuracy2.json', 'output/log/warnings.txt']
 
     expected = ['model', 'model/data/', 'model/data/model.json', 'model/data/variables.csv',
-                'model/data/variables2.csv', 'model/tmp/something-else.json']
+                'model/data/variables2.csv', 'model/tmp/something-else.json', 'output', 'output/data', 'output/log',
+                'output/data/loss.json', 'output/data/accuracy.json', 'output/data/accuracy2.json',
+                'output/log/warnings.txt']
 
     for d in dirs1:
         os.mkdir(os.path.join(volume1, d))
     for d in dirs2:
         os.mkdir(os.path.join(volume2, d))
+    for d in dirs3:
+        os.mkdir(os.path.join(volume1, d))
+    for d in dirs4:
+        os.mkdir(os.path.join(volume2, d))
 
     # create all the files
     for f in files1:
         open(os.path.join(volume1, f), 'a').close()
     for f in files2:
         open(os.path.join(volume2, f), 'a').close()
+    for f in files3:
+        open(os.path.join(volume1, f), 'a').close()
+    for f in files4:
+        open(os.path.join(volume2, f), 'a').close()
 
-    s3_model_artifacts = sagemaker_container.retrieve_model_artifacts(compose_data)
+    s3_model_artifacts = sagemaker_container.retrieve_artifacts(compose_data)
+    s3_artifacts = os.path.dirname(s3_model_artifacts)
 
     for f in expected:
-        assert os.path.exists(os.path.join(s3_model_artifacts, f))
+        assert set(os.listdir(s3_artifacts)) == set(['model', 'output'])
+        assert os.path.exists(os.path.join(s3_artifacts, f))
 
 
 def test_stream_output():
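
Taken together, the diff establishes this on-disk contract after a local training run: <container_root>/s3_artifacts/model holds everything the containers wrote to /opt/ml/model, and <container_root>/s3_artifacts/output holds everything written to /opt/ml/output. A short sketch of how a caller can reach both, assuming a sagemaker_container and compose_data set up as in the test above:

    import os

    # retrieve_artifacts() returns the model directory; the output directory
    # is its sibling under s3_artifacts.
    model_dir = sagemaker_container.retrieve_artifacts(compose_data)
    output_dir = os.path.join(os.path.dirname(model_dir), 'output')

    assert os.path.basename(model_dir) == 'model'
    assert sorted(os.listdir(os.path.dirname(model_dir))) == ['model', 'output']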