Commit 4760601

Save training output files in local mode (#177)
1 parent 8738b21 commit 4760601

3 files changed: +39 −13 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ CHANGELOG
 ========
 
 * bug-fix: Change module names to string type in __all__
+* feature: Save training output files in local mode
 * bug-fix: tensorflow-serving-api: SageMaker does not conflict with tensorflow-serving-api module version
 * feature: Local Mode: add support for local training data using file://
 * feature: Updated TensorFlow Serving api protobuf files
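For context, the feature lands in the local-mode training path, which users reach through the SDK's estimators. A minimal sketch, assuming the 1.x-era estimator interface; the script name, IAM role, and data path below are hypothetical placeholders, not taken from this commit:

# Rough usage sketch of local-mode training (1.x-era SageMaker Python SDK).
# entry_point, role, and the file:// path are hypothetical placeholders.
from sagemaker.mxnet import MXNet

estimator = MXNet(entry_point='train.py',        # hypothetical training script
                  role='SageMakerRole',          # hypothetical IAM role
                  train_instance_count=1,
                  train_instance_type='local')   # run training in local Docker containers

# Local training data via a file:// URI (see the Local Mode changelog entry above).
estimator.fit('file:///tmp/my-training-data')

After this commit, anything the training container writes to /opt/ml/output is collected alongside the /opt/ml/model artifacts, where previously only the model directory was copied out, as the image.py diff below shows.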

src/sagemaker/local/image.py

Lines changed: 11 additions & 4 deletions

@@ -141,7 +141,7 @@ def train(self, input_data_config, hyperparameters):
         _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
         _execute_and_stream_output(compose_command)
 
-        s3_model_artifacts = self.retrieve_model_artifacts(compose_data)
+        s3_artifacts = self.retrieve_artifacts(compose_data)
 
         # free up the training data directory as it may contain
         # lots of data downloaded from S3. This doesn't delete any local
@@ -157,7 +157,7 @@ def train(self, input_data_config, hyperparameters):
         # Print our Job Complete line to have a simmilar experience to training on SageMaker where you
         # see this line at the end.
         print('===== Job Complete =====')
-        return s3_model_artifacts
+        return s3_artifacts
 
     def serve(self, primary_container):
         """Host a local endpoint using docker-compose.
@@ -209,7 +209,7 @@ def stop_serving(self):
         # for serving we can delete everything in the container root.
         _delete_tree(self.container_root)
 
-    def retrieve_model_artifacts(self, compose_data):
+    def retrieve_artifacts(self, compose_data):
         """Get the model artifacts from all the container nodes.
 
         Used after training completes to gather the data from all the individual containers. As the
@@ -223,8 +223,13 @@ def retrieve_model_artifacts(self, compose_data):
 
         """
        # Grab the model artifacts from all the Nodes.
-        s3_model_artifacts = os.path.join(self.container_root, 's3_model_artifacts')
+        s3_artifacts = os.path.join(self.container_root, 's3_artifacts')
+        os.mkdir(s3_artifacts)
+
+        s3_model_artifacts = os.path.join(s3_artifacts, 'model')
+        s3_output_artifacts = os.path.join(s3_artifacts, 'output')
         os.mkdir(s3_model_artifacts)
+        os.mkdir(s3_output_artifacts)
 
         for host in self.hosts:
             volumes = compose_data['services'][str(host)]['volumes']
@@ -233,6 +238,8 @@ def retrieve_model_artifacts(self, compose_data):
                 host_dir, container_dir = volume.split(':')
                 if container_dir == '/opt/ml/model':
                     self._recursive_copy(host_dir, s3_model_artifacts)
+                elif container_dir == '/opt/ml/output':
+                    self._recursive_copy(host_dir, s3_output_artifacts)
 
         return s3_model_artifacts
 
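Splicing the hunks together, the renamed retrieve_artifacts method reads roughly as follows after this commit. The lines falling between hunks (the rest of the docstring and the inner loop header over volumes) are not visible in the diff, so they are assumptions; everything else is taken from the hunks above.

    def retrieve_artifacts(self, compose_data):
        """Get the model artifacts from all the container nodes.

        Used after training completes to gather the data from all the
        individual containers. (Remaining docstring text is not visible in
        the diff and is omitted here.)
        """
        # Grab the model artifacts from all the Nodes.
        s3_artifacts = os.path.join(self.container_root, 's3_artifacts')
        os.mkdir(s3_artifacts)

        s3_model_artifacts = os.path.join(s3_artifacts, 'model')
        s3_output_artifacts = os.path.join(s3_artifacts, 'output')
        os.mkdir(s3_model_artifacts)
        os.mkdir(s3_output_artifacts)

        for host in self.hosts:
            volumes = compose_data['services'][str(host)]['volumes']
            for volume in volumes:  # assumed loop header, not shown in the hunks
                host_dir, container_dir = volume.split(':')
                if container_dir == '/opt/ml/model':
                    self._recursive_copy(host_dir, s3_model_artifacts)
                elif container_dir == '/opt/ml/output':
                    self._recursive_copy(host_dir, s3_output_artifacts)

        return s3_model_artifacts

Note that the method still returns only the model directory; the output directory is created next to it under s3_artifacts, which is what the updated unit test below relies on via os.path.dirname.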

tests/unit/test_image.py

Lines changed: 27 additions & 9 deletions

@@ -117,46 +117,64 @@ def test_retrieve_artifacts(LocalSession, tmpdir):
     sagemaker_container.hosts = ['algo-1', 'algo-2'] # avoid any randomness
     sagemaker_container.container_root = str(tmpdir.mkdir('container-root'))
 
-    volume1 = os.path.join(sagemaker_container.container_root, 'algo-1/output/')
-    volume2 = os.path.join(sagemaker_container.container_root, 'algo-2/output/')
-    os.makedirs(volume1)
-    os.makedirs(volume2)
+    volume1 = os.path.join(sagemaker_container.container_root, 'algo-1')
+    volume2 = os.path.join(sagemaker_container.container_root, 'algo-2')
+    os.mkdir(volume1)
+    os.mkdir(volume2)
 
     compose_data = {
         'services': {
             'algo-1': {
-                'volumes': ['%s:/opt/ml/model' % volume1]
+                'volumes': ['%s:/opt/ml/model' % os.path.join(volume1, 'model'),
+                            '%s:/opt/ml/output' % os.path.join(volume1, 'output')]
             },
             'algo-2': {
-                'volumes': ['%s:/opt/ml/model' % volume2]
+                'volumes': ['%s:/opt/ml/model' % os.path.join(volume2, 'model'),
+                            '%s:/opt/ml/output' % os.path.join(volume2, 'output')]
             }
         }
     }
 
     dirs1 = ['model', 'model/data']
     dirs2 = ['model', 'model/data', 'model/tmp']
+    dirs3 = ['output', 'output/data']
+    dirs4 = ['output', 'output/data', 'output/log']
 
     files1 = ['model/data/model.json', 'model/data/variables.csv']
     files2 = ['model/data/model.json', 'model/data/variables2.csv', 'model/tmp/something-else.json']
+    files3 = ['output/data/loss.json', 'output/data/accuracy.json']
+    files4 = ['output/data/loss.json', 'output/data/accuracy2.json', 'output/log/warnings.txt']
 
     expected = ['model', 'model/data/', 'model/data/model.json', 'model/data/variables.csv',
-                'model/data/variables2.csv', 'model/tmp/something-else.json']
+                'model/data/variables2.csv', 'model/tmp/something-else.json', 'output', 'output/data', 'output/log',
+                'output/data/loss.json', 'output/data/accuracy.json', 'output/data/accuracy2.json',
+                'output/log/warnings.txt']
 
     for d in dirs1:
         os.mkdir(os.path.join(volume1, d))
     for d in dirs2:
         os.mkdir(os.path.join(volume2, d))
+    for d in dirs3:
+        os.mkdir(os.path.join(volume1, d))
+    for d in dirs4:
+        os.mkdir(os.path.join(volume2, d))
 
     # create all the files
     for f in files1:
         open(os.path.join(volume1, f), 'a').close()
     for f in files2:
         open(os.path.join(volume2, f), 'a').close()
+    for f in files3:
+        open(os.path.join(volume1, f), 'a').close()
+    for f in files4:
+        open(os.path.join(volume2, f), 'a').close()
 
-    s3_model_artifacts = sagemaker_container.retrieve_model_artifacts(compose_data)
+    s3_model_artifacts = sagemaker_container.retrieve_artifacts(compose_data)
+    s3_artifacts = os.path.dirname(s3_model_artifacts)
 
     for f in expected:
-        assert os.path.exists(os.path.join(s3_model_artifacts, f))
+        assert set(os.listdir(s3_artifacts)) == set(['model', 'output'])
+        assert os.path.exists(os.path.join(s3_artifacts, f))
 
 
 def test_stream_output():
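For a concrete picture of what the test asserts, the layout under the container root after retrieve_artifacts runs would look roughly like this; the tree is inferred from the fixture names and assertions above, not captured from a real run:

# Sketch of the on-disk result the test implies (inferred, illustration only):
#
# <container_root>/s3_artifacts/
#     model/                       <- path returned by retrieve_artifacts()
#         data/model.json
#         data/variables.csv
#         data/variables2.csv
#         tmp/something-else.json
#     output/                      <- new in this commit
#         data/loss.json
#         data/accuracy.json
#         data/accuracy2.json
#         log/warnings.txt
s3_model_artifacts = sagemaker_container.retrieve_artifacts(compose_data)
assert os.path.basename(s3_model_artifacts) == 'model'
assert os.path.basename(os.path.dirname(s3_model_artifacts)) == 's3_artifacts'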

Comments (0)