@@ -225,7 +225,7 @@ def test_single_container_local_mode_s3_data_not_remove_input(modules_sagemaker_
225
225
delete_local_path (path )
226
226
227
227
228
- def test_multi_container_local_mode (modules_sagemaker_session ):
228
+ def test_multi_container_local_mode_remove_input (modules_sagemaker_session ):
229
229
with lock .lock (LOCK_PATH ):
230
230
try :
231
231
source_code = SourceCode (
@@ -265,6 +265,68 @@ def test_multi_container_local_mode(modules_sagemaker_session):
265
265
266
266
model_trainer .train ()
267
267
assert os .path .exists (os .path .join (CWD , "compressed_artifacts/model.tar.gz" ))
268
+
269
+ finally :
270
+ subprocess .run (["docker" , "compose" , "down" , "-v" ])
271
+
272
+ assert not os .path .exists (os .path .join (CWD , "shared" ))
273
+ assert not os .path .exists (os .path .join (CWD , "input" ))
274
+ assert not os .path .exists (os .path .join (CWD , "algo-1" ))
275
+ assert not os .path .exists (os .path .join (CWD , "algo-2" ))
276
+
277
+ directories = [
278
+ "compressed_artifacts" ,
279
+ "artifacts" ,
280
+ "model" ,
281
+ "output" ,
282
+ ]
283
+
284
+ for directory in directories :
285
+ path = os .path .join (CWD , directory )
286
+ delete_local_path (path )
287
+
288
+
289
+ def test_multi_container_local_mode_not_remove_input (modules_sagemaker_session ):
290
+ with lock .lock (LOCK_PATH ):
291
+ try :
292
+ source_code = SourceCode (
293
+ source_dir = SOURCE_DIR ,
294
+ entry_script = "local_training_script.py" ,
295
+ )
296
+
297
+ distributed = Torchrun (
298
+ process_count_per_node = 1 ,
299
+ )
300
+
301
+ compute = Compute (
302
+ instance_type = "local_cpu" ,
303
+ instance_count = 2 ,
304
+ )
305
+
306
+ train_data = InputData (
307
+ channel_name = "train" ,
308
+ data_source = os .path .join (SOURCE_DIR , "data/train/" ),
309
+ )
310
+
311
+ test_data = InputData (
312
+ channel_name = "test" ,
313
+ data_source = os .path .join (SOURCE_DIR , "data/test/" ),
314
+ )
315
+
316
+ model_trainer = ModelTrainer (
317
+ training_image = DEFAULT_CPU_IMAGE ,
318
+ sagemaker_session = modules_sagemaker_session ,
319
+ source_code = source_code ,
320
+ distributed = distributed ,
321
+ compute = compute ,
322
+ input_data_config = [train_data , test_data ],
323
+ base_job_name = "local_mode_multi_container" ,
324
+ training_mode = Mode .LOCAL_CONTAINER ,
325
+ remove_inputs_and_container_artifacts = False ,
326
+ )
327
+
328
+ model_trainer .train ()
329
+ assert os .path .exists (os .path .join (CWD , "compressed_artifacts/model.tar.gz" ))
268
330
assert os .path .exists (os .path .join (CWD , "algo-1" ))
269
331
assert os .path .exists (os .path .join (CWD , "algo-2" ))
270
332
@@ -274,7 +336,11 @@ def test_multi_container_local_mode(modules_sagemaker_session):
274
336
"compressed_artifacts" ,
275
337
"artifacts" ,
276
338
"model" ,
339
+ "shared" ,
340
+ "input" ,
277
341
"output" ,
342
+ "algo-1" ,
343
+ "algo-2" ,
278
344
]
279
345
280
346
for directory in directories :
0 commit comments