Skip to content

Commit 9e70fa3

Browse files
committed
fix tests
1 parent 099eac9 commit 9e70fa3

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

src/sagemaker/modules/local_core/local_container.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def train(
209209

210210
if remove_inputs_and_container_artifacts:
211211
shutil.rmtree(os.path.join(self.container_root, "input"))
212+
shutil.rmtree(os.path.join(self.container_root, "shared"))
212213
for host in self.hosts:
213214
shutil.rmtree(os.path.join(self.container_root, host))
214215
for folder in self._temperary_folders:

tests/integ/sagemaker/modules/train/test_local_model_trainer.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def test_single_container_local_mode_s3_data_not_remove_input(modules_sagemaker_
225225
delete_local_path(path)
226226

227227

228-
def test_multi_container_local_mode(modules_sagemaker_session):
228+
def test_multi_container_local_mode_remove_input(modules_sagemaker_session):
229229
with lock.lock(LOCK_PATH):
230230
try:
231231
source_code = SourceCode(
@@ -265,6 +265,68 @@ def test_multi_container_local_mode(modules_sagemaker_session):
265265

266266
model_trainer.train()
267267
assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz"))
268+
269+
finally:
270+
subprocess.run(["docker", "compose", "down", "-v"])
271+
272+
assert not os.path.exists(os.path.join(CWD, "shared"))
273+
assert not os.path.exists(os.path.join(CWD, "input"))
274+
assert not os.path.exists(os.path.join(CWD, "algo-1"))
275+
assert not os.path.exists(os.path.join(CWD, "algo-2"))
276+
277+
directories = [
278+
"compressed_artifacts",
279+
"artifacts",
280+
"model",
281+
"output",
282+
]
283+
284+
for directory in directories:
285+
path = os.path.join(CWD, directory)
286+
delete_local_path(path)
287+
288+
289+
def test_multi_container_local_mode_not_remove_input(modules_sagemaker_session):
290+
with lock.lock(LOCK_PATH):
291+
try:
292+
source_code = SourceCode(
293+
source_dir=SOURCE_DIR,
294+
entry_script="local_training_script.py",
295+
)
296+
297+
distributed = Torchrun(
298+
process_count_per_node=1,
299+
)
300+
301+
compute = Compute(
302+
instance_type="local_cpu",
303+
instance_count=2,
304+
)
305+
306+
train_data = InputData(
307+
channel_name="train",
308+
data_source=os.path.join(SOURCE_DIR, "data/train/"),
309+
)
310+
311+
test_data = InputData(
312+
channel_name="test",
313+
data_source=os.path.join(SOURCE_DIR, "data/test/"),
314+
)
315+
316+
model_trainer = ModelTrainer(
317+
training_image=DEFAULT_CPU_IMAGE,
318+
sagemaker_session=modules_sagemaker_session,
319+
source_code=source_code,
320+
distributed=distributed,
321+
compute=compute,
322+
input_data_config=[train_data, test_data],
323+
base_job_name="local_mode_multi_container",
324+
training_mode=Mode.LOCAL_CONTAINER,
325+
remove_inputs_and_container_artifacts=False,
326+
)
327+
328+
model_trainer.train()
329+
assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz"))
268330
assert os.path.exists(os.path.join(CWD, "algo-1"))
269331
assert os.path.exists(os.path.join(CWD, "algo-2"))
270332

@@ -274,7 +336,11 @@ def test_multi_container_local_mode(modules_sagemaker_session):
274336
"compressed_artifacts",
275337
"artifacts",
276338
"model",
339+
"shared",
340+
"input",
277341
"output",
342+
"algo-1",
343+
"algo-2",
278344
]
279345

280346
for directory in directories:

0 commit comments

Comments
 (0)