@@ -117,46 +117,64 @@ def test_retrieve_artifacts(LocalSession, tmpdir):
117
117
sagemaker_container .hosts = ['algo-1' , 'algo-2' ] # avoid any randomness
118
118
sagemaker_container .container_root = str (tmpdir .mkdir ('container-root' ))
119
119
120
- volume1 = os .path .join (sagemaker_container .container_root , 'algo-1/output/ ' )
121
- volume2 = os .path .join (sagemaker_container .container_root , 'algo-2/output/ ' )
122
- os .makedirs (volume1 )
123
- os .makedirs (volume2 )
120
+ volume1 = os .path .join (sagemaker_container .container_root , 'algo-1' )
121
+ volume2 = os .path .join (sagemaker_container .container_root , 'algo-2' )
122
+ os .mkdir (volume1 )
123
+ os .mkdir (volume2 )
124
124
125
125
compose_data = {
126
126
'services' : {
127
127
'algo-1' : {
128
- 'volumes' : ['%s:/opt/ml/model' % volume1 ]
128
+ 'volumes' : ['%s:/opt/ml/model' % os .path .join (volume1 , 'model' ),
129
+ '%s:/opt/ml/output' % os .path .join (volume1 , 'output' )]
129
130
},
130
131
'algo-2' : {
131
- 'volumes' : ['%s:/opt/ml/model' % volume2 ]
132
+ 'volumes' : ['%s:/opt/ml/model' % os .path .join (volume2 , 'model' ),
133
+ '%s:/opt/ml/output' % os .path .join (volume2 , 'output' )]
132
134
}
133
135
}
134
136
}
135
137
136
138
dirs1 = ['model' , 'model/data' ]
137
139
dirs2 = ['model' , 'model/data' , 'model/tmp' ]
140
+ dirs3 = ['output' , 'output/data' ]
141
+ dirs4 = ['output' , 'output/data' , 'output/log' ]
138
142
139
143
files1 = ['model/data/model.json' , 'model/data/variables.csv' ]
140
144
files2 = ['model/data/model.json' , 'model/data/variables2.csv' , 'model/tmp/something-else.json' ]
145
+ files3 = ['output/data/loss.json' , 'output/data/accuracy.json' ]
146
+ files4 = ['output/data/loss.json' , 'output/data/accuracy2.json' , 'output/log/warnings.txt' ]
141
147
142
148
expected = ['model' , 'model/data/' , 'model/data/model.json' , 'model/data/variables.csv' ,
143
- 'model/data/variables2.csv' , 'model/tmp/something-else.json' ]
149
+ 'model/data/variables2.csv' , 'model/tmp/something-else.json' , 'output' , 'output/data' , 'output/log' ,
150
+ 'output/data/loss.json' , 'output/data/accuracy.json' , 'output/data/accuracy2.json' ,
151
+ 'output/log/warnings.txt' ]
144
152
145
153
for d in dirs1 :
146
154
os .mkdir (os .path .join (volume1 , d ))
147
155
for d in dirs2 :
148
156
os .mkdir (os .path .join (volume2 , d ))
157
+ for d in dirs3 :
158
+ os .mkdir (os .path .join (volume1 , d ))
159
+ for d in dirs4 :
160
+ os .mkdir (os .path .join (volume2 , d ))
149
161
150
162
# create all the files
151
163
for f in files1 :
152
164
open (os .path .join (volume1 , f ), 'a' ).close ()
153
165
for f in files2 :
154
166
open (os .path .join (volume2 , f ), 'a' ).close ()
167
+ for f in files3 :
168
+ open (os .path .join (volume1 , f ), 'a' ).close ()
169
+ for f in files4 :
170
+ open (os .path .join (volume2 , f ), 'a' ).close ()
155
171
156
- s3_model_artifacts = sagemaker_container .retrieve_model_artifacts (compose_data )
172
+ s3_model_artifacts = sagemaker_container .retrieve_artifacts (compose_data )
173
+ s3_artifacts = os .path .dirname (s3_model_artifacts )
157
174
158
175
for f in expected :
159
- assert os .path .exists (os .path .join (s3_model_artifacts , f ))
176
+ assert set (os .listdir (s3_artifacts )) == set (['model' , 'output' ])
177
+ assert os .path .exists (os .path .join (s3_artifacts , f ))
160
178
161
179
162
180
def test_stream_output ():
0 commit comments