Skip to content

Commit c6c1f0a

Browse files
author
Payton Staub
committed
Fix source_directory handling
1 parent 7ffed8c commit c6c1f0a

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

src/sagemaker/workflow/_repack_model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,18 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
6868

6969
# copy source_dir to code/
7070
if source_dir:
71-
actual_source_dir_path = os.path.join("/opt/ml/code", source_dir)
72-
if os.path.exists(actual_source_dir_path):
73-
for item in os.listdir(actual_source_dir_path):
74-
s = os.path.join(actual_source_dir_path, item)
75-
d = os.path.join(os.path.join(code_dir, source_dir), item)
76-
if os.path.isdir(s):
77-
shutil.copytree(s, d)
78-
else:
79-
shutil.copy2(s, d)
71+
if os.path.exists(code_dir):
72+
shutil.rmtree(code_dir)
73+
shutil.copytree(source_dir, code_dir)
74+
# actual_source_dir_path = os.path.join("/opt/ml/code", source_dir)
75+
# if os.path.exists(actual_source_dir_path):
76+
# for item in os.listdir(actual_source_dir_path):
77+
# s = os.path.join(actual_source_dir_path, item)
78+
# d = os.path.join(os.path.join(code_dir, source_dir), item)
79+
# if os.path.isdir(s):
80+
# shutil.copytree(s, d)
81+
# else:
82+
# shutil.copy2(s, d)
8083

8184
# copy any dependencies to code/lib/
8285
if dependencies:

tests/unit/sagemaker/workflow/test_repack_model_script.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def test_repack_with_source_dir_and_dependencies(tmp):
155155
assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a"))
156156
assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb"))
157157
assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b"))
158-
assert os.path.exists(os.path.join("/opt/ml/model/code/sourcedir", "foo.py"))
159-
assert os.path.exists(os.path.join("/opt/ml/model/code/sourcedir/some/dir", "a"))
158+
assert os.path.exists(os.path.join("/opt/ml/model/code/", "foo.py"))
159+
assert os.path.exists(os.path.join("/opt/ml/model/code/some/dir", "a"))
160160

161161

162162
def create_file_tree(root, tree):

0 commit comments

Comments
 (0)