Skip to content

Commit 015108c

Browse files
authored
fix: make repack_model only removes py file when new entry_point provided (#1352)
* fix: make repack_model only removes py files when new entry_point present
1 parent 4d05b8e commit 015108c

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/sagemaker/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,9 @@ def _create_or_update_code_dir(
530530
"""
531531
code_dir = os.path.join(model_dir, "code")
532532
if os.path.exists(code_dir):
533-
shutil.rmtree(code_dir, ignore_errors=True)
533+
for filename in os.listdir(code_dir):
534+
if filename.endswith(".py"):
535+
os.remove(os.path.join(code_dir, filename))
534536
if source_directory and source_directory.lower().startswith("s3://"):
535537
local_code_path = os.path.join(tmp, "local_code.tar.gz")
536538
download_file_from_url(source_directory, local_code_path, sagemaker_session)
@@ -539,9 +541,12 @@ def _create_or_update_code_dir(
539541
t.extractall(path=code_dir)
540542

541543
elif source_directory:
544+
if os.path.exists(code_dir):
545+
shutil.rmtree(code_dir)
542546
shutil.copytree(source_directory, code_dir)
543547
else:
544-
os.mkdir(code_dir)
548+
if not os.path.exists(code_dir):
549+
os.mkdir(code_dir)
545550
shutil.copy2(inference_script, code_dir)
546551

547552
for dependency in dependencies:

tests/unit/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,35 @@ def test_repack_model_from_file_to_folder(tmp):
614614
}
615615

616616

617+
def test_repack_model_with_inference_code_and_requirements(tmp, fake_s3):
618+
create_file_tree(
619+
tmp,
620+
[
621+
"new-inference.py",
622+
"model-dir/model",
623+
"model-dir/code/old-inference.py",
624+
"model-dir/code/requirements.txt",
625+
],
626+
)
627+
628+
fake_s3.tar_and_upload("model-dir", "s3://fake/location")
629+
630+
sagemaker.utils.repack_model(
631+
os.path.join(tmp, "new-inference.py"),
632+
None,
633+
None,
634+
"s3://fake/location",
635+
"s3://destination-bucket/repacked-model",
636+
fake_s3.sagemaker_session,
637+
)
638+
639+
assert list_tar_files(fake_s3.fake_upload_path, tmp) == {
640+
"/code/requirements.txt",
641+
"/code/new-inference.py",
642+
"/model",
643+
}
644+
645+
617646
class FakeS3(object):
618647
def __init__(self, tmp):
619648
self.tmp = tmp

0 commit comments

Comments
 (0)